<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/Jax-Journey/blob/main/flax_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git


In [2]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

from jax.config import config
config.enable_omnistaging() # Linen requires enabling omnistaging

* Class attributes are defined in the class body, outside any method.
* They are shared by all instances of the class.
* In the syntax below, ```features``` is not a class attribute. Flax modules are Python dataclasses, so ```features``` is a dataclass field: the generated ```__init__()``` sets it, it can differ between objects, and it must be provided when the object is created.
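
The distinction can be sketched with a plain Python dataclass, no Flax required. The class `Layer` and its attribute `kind` are hypothetical names chosen only for illustration.

```python
from dataclasses import dataclass
from typing import ClassVar, Sequence


@dataclass
class Layer:
    # Class attribute: shared by every instance. ClassVar tells the
    # dataclass machinery not to treat it as a per-instance field.
    kind: ClassVar[str] = "mlp"
    # Dataclass field: set per instance by the generated __init__().
    features: Sequence[int]


a = Layer(features=[3, 4, 5])
b = Layer(features=[8, 8])

print(a.features, b.features)  # differs per instance
print(a.kind is b.kind)        # both names point at the same shared object
```

Leaving out `features` when constructing `Layer` raises a `TypeError`, which mirrors the requirement that a Flax module's fields be supplied at creation time.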



In [5]:
class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        '''
        Called automatically by Flax once the module has been constructed, i.e. after __post_init__().
        Register the submodules, variables and parameters the model needs here.
        '''
        self.layers = [nn.Dense(feat) for feat in self.features]
        
    def __call__(self, inputs):
        '''
        Called whenever inputs are passed through model.init() or model.apply().
        Parameters are supplied by Flax, so they do not need to be part of the inputs.
        This method only needs to specify the flow of the computation.
        '''
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i!=len(self.layers)-1:
                x = nn.relu(x)
        return x

* In the above class, ```model.layers``` is not accessible from outside the class: ```setup()``` only runs once the module is bound to its variables, so the layers seem to come into existence only inside ```model.init()``` or ```model.apply()```.

* Below is an example of a neat trick in Flax. To customise how a module is initialised, it might seem at first that the choice of initialisation method has to be passed in and maintained outside the class (as with ```params```). Flax instead recognises that an initialisation method is just a function combined with a random key, so it lets you store the function part inside the class. (This works for functions, but not for shared state.) The function takes the random key, the shape, etc. as input and produces a deterministic output from them, which is used as the initial parameters.
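
A small sketch of that idea, using the initializer factory `jax.nn.initializers.lecun_normal()` directly. The shapes and seeds are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# An initializer is just a function (key, shape) -> array.
init_fn = jax.nn.initializers.lecun_normal()

key = jax.random.PRNGKey(0)
w1 = init_fn(key, (4, 3))
w2 = init_fn(key, (4, 3))                    # same key, same shape
w3 = init_fn(jax.random.PRNGKey(1), (4, 3))  # different key

print(bool(jnp.array_equal(w1, w2)))  # same inputs give the same values
print(bool(jnp.array_equal(w1, w3)))  # a different key gives different values
```

Because the function carries no hidden state, it can safely live inside the module as a field, while the key is supplied from outside at `init()` time.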

In [None]:
key = random.PRNGKey(0)
key1, key2 = random.split(key, 2)
x = random.normal(key1, (4, 4))  # The first dimension is automatically interpreted as the batch dimension; no need for vmap.

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)  # Also initialises all the internal layers.

The ```model.apply()``` call below has to call each sub-layer as specified in the ```__call__``` function above. Before calling each sub-layer, it sets that layer's params and also sets various flags which ensure that ```__call__``` can only be used from inside ```model.apply()``` or ```model.init()```.

In [None]:
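y = model.apply(params, x)  # Can't do y = model((params, x))

# jax.tree_util.tree_map is used because the top-level jax.tree_map alias is gone from recent JAX releases.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))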
print('output shape:\n', y.shape)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output shape:
 (4, 5)


Below is another, easier way to specify the flow of steps in the model. The layers are defined and used in the same place, so we only specify what to pass to each one.

In [None]:
class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)  # No need for init/apply here, as we are inside @nn.compact
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x        

In [None]:
key = random.PRNGKey(0)
key1, key2 = random.split(key, 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP([4, 3, 5])
params = model.init(key2,x)
y = model.apply(params, x)

print('initialised parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output shape:\n', y.shape)

initialised parameter shapes:
 {'params': {'layers_0': {'bias': (4,), 'kernel': (4, 4)}, 'layers_1': {'bias': (3,), 'kernel': (4, 3)}, 'layers_2': {'bias': (5,), 'kernel': (3, 5)}}}
output shape:
 (4, 5)


The compact notation can also define a model from scratch, using only mathematical operations and declaring the model's parameters inline. ```self.param()``` behaves differently depending on whether ```__call__``` was invoked via ```init()``` or ```apply()```.
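
The dual behaviour can be pictured with a toy stand-in. The `ToyScope` class below is hypothetical and is not how Flax is implemented. It only mimics the idea that `param()` runs the initializer during `init()` and looks up the stored value during `apply()`.

```python
class ToyScope:
    """Hypothetical stand-in for the bookkeeping behind self.param()."""

    def __init__(self, params=None):
        # No params supplied means "initialising"; otherwise "applying".
        self.initialising = params is None
        self.params = {} if params is None else params

    def param(self, name, init_fn, shape):
        if self.initialising:
            # init(): create the value and record it under its name.
            self.params[name] = init_fn(shape)
        # In both modes the stored value is returned; during apply()
        # init_fn is never called.
        return self.params[name]


def zeros(shape):
    return [0.0] * shape[0]


def forward(scope, x):
    bias = scope.param("bias", zeros, (len(x),))
    return [xi + bi for xi, bi in zip(x, bias)]


init_scope = ToyScope()
forward(init_scope, [1.0, 2.0])     # "init": creates bias = [0.0, 0.0]
params = init_scope.params

params["bias"] = [10.0, 20.0]       # pretend training changed the bias
print(forward(ToyScope(params), [1.0, 2.0]))  # "apply": [11.0, 22.0]
```

The same `forward` code serves both phases, which is why a compact `__call__` can declare its parameters inline.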

In [None]:
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,
                            (inputs.shape[-1], self.features))
        y = jnp.dot(inputs, kernel)
        bias = self.param('bias',
                          self.bias_init,
                          (self.features, ))
        y = y+bias
        return y        

In [None]:
key = random.PRNGKey(0)
key1, key2 = random.split(key, 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialised parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output shape:\n', y.shape)

initialised parameter shapes:
 {'params': {'bias': (3,), 'kernel': (4, 3)}}
output shape:
 (4, 3)


If the above model is implemented the ```setup()``` way, there is nothing to fill the blank below with, because no input (and hence no input shape) is available inside ```setup()```.

In [None]:
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    def setup(self):
        self.kernel = self.param('kernel',
                                self.kernel_init,
                                (___________, self.features))
        self.bias = self.param('bias',
                               self.bias_init,
                               (self.features, ))
    @nn.compact
    def __call__(self, inputs):
        y = jnp.dot(inputs, self.kernel)+self.bias
        return y        

* The following code shows how to define variables for a model, apart from its parameters.
* Variables, like parameters, are stored in a tree.
* Like parameters, they are handled outside the class.
* To define a variable, specify its full path from the root of the tree: the collection name followed by the variable name. Here that is ```('batch_stats', 'mean')```.
* Thanks to ```@nn.compact```, variables and parameters are initialised and defined only once, but all the operations specified are performed every time ```model.apply()``` is called.

In [None]:
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        is_initialized = self.has_variable('batch_stats', 'mean')
        ra_mean = self.variable('batch_stats', 'mean',   # collection name, then variable name
                                lambda s: jnp.zeros(s),  # initialization function
                                x.shape[1:])             # argument passed to the initialization function
        mean = ra_mean.value
        bias = self.param('bias',
                          lambda rng, shape: jnp.zeros(shape),  # a parameter's init function must take both rng and shape
                          x.shape[1:])
        
        if is_initialized:
            ra_mean.value = self.decay * ra_mean.value\
                            + (1.0-self.decay)*jnp.mean(x, axis=0, keepdims=True)

        return x - ra_mean.value + bias   

* The ```model.apply()``` call is modified below: you must name the mutable collections of the model and receive their updated values in the output.

* The variable ```y``` still holds the value returned by the ```__call__``` function defined above.

* ```model.init()``` returns everything that was initialised, i.e. both variables and params. All of it is passed into the ```apply()``` call, and hence nothing needs to be initialised again in ```__call__```.

* Although ```model.apply()``` returns the updated variables, ```params_n_variables``` still holds the old ones. Because variables are handled outside the class, the variables in ```params_n_variables``` must be updated here as well.
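
The same pattern in miniature, without Flax: a hypothetical pure function `step` returns the output together with a new state, and the caller has to carry that state forward. The numbers are made up for illustration.

```python
def step(state, x, decay=0.99):
    """Pure function: returns the output and a new state, never mutates."""
    batch_mean = sum(x) / len(x)
    new_mean = decay * state["mean"] + (1.0 - decay) * batch_mean
    y = [xi - new_mean for xi in x]
    return y, {"mean": new_mean}


state = {"mean": 0.0}
for _ in range(2):
    y, new_state = step(state, [1.0, 1.0, 1.0])
    state = new_state  # without this line the mean would stay at its old value

print(state["mean"])  # approximately 0.0199 after two steps
```

Dropping the `state = new_state` assignment leaves the running mean stuck at its initial value, which is the situation described in the last bullet.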

In [None]:
key = random.PRNGKey(0)
key1, key2 = random.split(key, 2)
x = random.uniform(key1, (5,))

model = BiasAdderWithRunningMean(decay=0.99)
params_n_variables = model.init(key2, x)
print(params_n_variables)

for i in range(10):
    x = random.normal(random.fold_in(key2, i), (5,))  # derive a fresh key for each step

    y, updated_variables = model.apply(params_n_variables, x, mutable=['batch_stats'])

    old_variables, params = params_n_variables.pop('params')              # the remaining tree is the first output, the popped part (params) is the second
    params_n_variables = freeze({'params': params, **updated_variables})  # build a new tree from the available components

    print(updated_variables)

print('initialised parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output shape:\n', y.shape)

# Optimizers in Flax

The parameters of the model are stored in the optimizer and are available as ```optimizer.target```.

In [None]:
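from flax import optim  # Note: flax.optim has been removed from recent Flax releases in favour of Optax.
optimizer_def = optim.GradientDescent(learning_rate=0.01)
optimizer = optimizer_def.create(params)  # The params are stored inside the optimizer object and need not be handled outside.
loss_grad_fn = jax.value_and_grad(loss)   # `loss` is assumed to be a scalar-valued function of the params, defined elsewhere.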

In [None]:
for i in range(101):
    loss_val, grad = loss_grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
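
For reference, the update that a plain gradient-descent optimizer performs can be written by hand with `jax.value_and_grad` and `jax.tree_util.tree_map`. The quadratic `loss`, the parameter names `w` and `b`, and the data below are made up for illustration. This is a sketch of the update rule, not of the internals of `flax.optim`.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Mean squared error of a linear model (illustrative only).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])  # made-up targets: y = 2 * x
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}

loss_grad_fn = jax.value_and_grad(loss)
learning_rate = 0.01
initial_loss, _ = loss_grad_fn(params, x, y)

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x, y)
    # Gradient descent: move every leaf of the parameter tree against its gradient.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

print(float(initial_loss), float(loss_val))
```

Here the parameters are threaded through the loop explicitly, whereas in the cells above the optimizer object holds them in `optimizer.target`.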