<a href="https://colab.research.google.com/github/gauravjain14/All-about-JAX/blob/main/Flax_Basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Environment setup (Colab). `%pip` is used instead of `!pip` so the packages
# are installed into the environment of the running kernel.
# NOTE: nothing is pinned here, so a fresh run may pull newer releases than the
# ones that produced the outputs recorded in this notebook.
%pip install --upgrade -q pip jax jaxlib
%pip install --upgrade git+https://github.com/google/flax.git

[K     |████████████████████████████████| 2.0 MB 4.5 MB/s 
[K     |████████████████████████████████| 1.0 MB 25.5 MB/s 
[K     |████████████████████████████████| 72.0 MB 1.3 MB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/google/flax.git
  Cloning https://github.com/google/flax.git to /tmp/pip-req-build-flhgtbd2
  Running command git clone --filter=blob:none --quiet https://github.com/google/flax.git /tmp/pip-req-build-flhgtbd2
  Resolved https://github.com/google/flax.git to commit fb253d05f3ae93942c407f1131a816e225b179ae
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.1/145.1 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (

In [2]:
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

## Linear regression with Flax 

In [3]:
# Create a single dense (fully connected) layer instance.
# We only specify the output size of the layer (`features`). The input size is
# not given here: the kernel shape printed after `model.init` below takes its
# first dimension from the input that is passed to `init`.
model = nn.Dense(features=5) # `features` is the size of the layer's output

### Follow this up with Model parameters and Initialization

In [4]:
# Split one root key into two keys: one for the dummy input, one for the
# parameter initialization.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # dummy input with 10 entries
# `init` takes a PRNG key and an example input and returns the parameter pytree
params = model.init(key2, x)
# Replace every leaf of the pytree by its shape so the structure is easy to read
jax.tree_util.tree_map(lambda x: x.shape, params)



FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [5]:
# The returned params are a FrozenDict, i.e. immutable, which keeps with the
# functional nature of JAX. The assignment below is expected to fail; the error
# is caught and printed because the failure itself is the demonstration.
try:
  params['new_key'] = jnp.ones((2,2))
except ValueError as e:
  print("Error: ", e)

Error:  FrozenDict is immutable.


#### How do we evaluate a model given a set of parameters?

We execute `model.apply(parameters, input)`

Note: It seems that calling `print` on the model output copies the array from the device to the host.

In [6]:
# Forward pass: evaluate the layer on `x` with the parameters created by `init`
model_output = model.apply(params, x)
# A bare last expression displays the array's repr as the cell output
model_output
# print(model_output)

DeviceArray([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 ,
             -1.1271116 ], dtype=float32)

## Gradient Descent

Now that we know how to initialize parameters and apply them to a model, let's learn how to use the same "immutable" parameters and execute **Gradient Descent**. 

If you are unsure about **Gradient Descent**, feel free to take a look at [this link](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#gradient-descent).

In [7]:
## Let's begin by setting up the initial problem dimensions
n_samples = 20
x_dim = 10
y_dim = 5

## Draw the "true" kernel (W) and bias (b) that will generate the data
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# This is a first - store the parameters in a pytree. The layout
# {'params': {'bias', 'kernel'}} mirrors the structure returned by model.init
# above, so this pytree can be used wherever `params` is used.
true_params = freeze({'params': {'bias':b, 'kernel': W}})
# We can look at what this pytree looks like
true_params

FrozenDict({
    params: {
        bias: DeviceArray([-1.4581939, -2.047044 ,  2.0473392,  1.1684095, -0.9758364],            dtype=float32),
        kernel: DeviceArray([[ 1.0247566 ,  0.18528093,  0.03387944, -0.86629736,
                       0.34718114],
                     [ 1.7656006 ,  0.99169755,  1.1657897 ,  1.1106981 ,
                      -0.08589564],
                     [-1.1820309 ,  0.29050717,  1.436301  ,  0.15073189,
                      -1.3651401 ],
                     [-1.1463748 , -0.16064964,  0.04578291,  1.3267074 ,
                       0.08830649],
                     [ 0.15840754,  1.3908992 , -1.3764939 ,  0.4419787 ,
                      -2.2242246 ],
                     [ 0.5943986 ,  0.8191525 ,  0.32800463,  0.51409715,
                       0.92392564],
                     [-0.32272884,  1.7835051 ,  1.0902369 , -0.5799917 ,
                       0.9487662 ],
                     [ 0.97157586, -1.2998172 ,  0.3205269 ,  0.806568  ,
      

In [8]:
# Generate input samples, and targets from the true W and b plus noise.
# NOTE(review): k1 was already consumed above to draw W, and it is split again
# here. JAX's PRNG guidance is to never reuse a key once it has been used;
# consider deriving these two keys from a fresh split instead (doing so would
# change the numbers recorded in the outputs below).
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
# y = x . W + b + 0.1 * noise, with noise of shape (n_samples, y_dim)
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))

# Let's see what these x and y samples look like
print('x shape: ', x_samples.shape, '; y shape: ', y_samples.shape)

x shape:  (20, 10) ; y shape:  (20, 5)


Now that we have initialized the inputs and the parameters, let's define the loss function and set up the gradient descent

In [9]:
@jax.jit
def mse(params, x_batched, y_batched):
  """Batch mean of the halved squared error of the notebook-level `model`.

  Args:
    params: parameter pytree handed to `model.apply`.
    x_batched: inputs stacked along the leading axis.
    y_batched: targets stacked along the leading axis, matching `x_batched`.

  Returns:
    `||y - model(x)||^2 / 2` for each sample, averaged over the leading axis.
  """
  def half_squared_residual(x, y):
    # Loss of a single (x, y) pair: half the squared norm of the residual.
    residual = y - model.apply(params, x)
    return jnp.inner(residual, residual) / 2.0

  # jax.vmap maps the single-pair loss over the leading axis of both arrays,
  # which yields one loss value per sample; those are then averaged.
  per_sample_loss = jax.vmap(half_squared_residual)(x_batched, y_batched)
  return jnp.mean(per_sample_loss, axis=0)

In [10]:
## Now applying Gradient Descent
learning_rate = 0.3  # Gradient step size
# Reference point: the loss obtained with the parameters that generated the data
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))

# Using the magic wand here - jax.value_and_grad wraps mse so that a single
# call returns both the loss value and its gradient (used below as the
# gradient with respect to `params`, the first argument)
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  """Take one plain gradient-descent step on every leaf of `params`.

  Args:
    params: pytree of current parameter values.
    learning_rate: step size that multiplies each gradient.
    grads: pytree of gradients with the same structure as `params`.

  Returns:
    A new pytree holding `param - learning_rate * grad` for every leaf.
  """
  def descend(param, grad):
    return param - learning_rate * grad

  return jax.tree_util.tree_map(descend, params, grads)

# NOTE: `params` is rebound on every step, so re-running this cell continues
# from the already-updated parameters instead of starting over.
for i in range(101):
  # Perform one gradient update. loss_val is the loss *before* this update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  # Report every 10th step only
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639793
Loss step 0:  35.343876
Loss step 10:  0.514347
Loss step 20:  0.11384159
Loss step 30:  0.039326735
Loss step 40:  0.019916208
Loss step 50:  0.014209136
Loss step 60:  0.012425654
Loss step 70:  0.01185039
Loss step 80:  0.011661784
Loss step 90:  0.011599409
Loss step 100:  0.011578695


## Optimizing with Optax

Optax is an optimization package that provides gradient transformations, ways to change hyperparameters over time, ways to apply updates to different parts of the parameter tree, and more.

In [11]:
import optax
# SGD with the same step size as the hand-written update above
tx = optax.sgd(learning_rate=learning_rate)
# The optimizer state is created from the current `params`. Note that these
# are the parameters already updated by the loop above, which is why the loss
# printed by the next cell starts out small.
opt_state = tx.init(params)
# Same loss-and-gradient function as before
loss_grad_fn = jax.value_and_grad(mse)

  PyTreeDef = type(jax.tree_structure(None))


In [12]:
# Same training loop as before, with the manual update replaced by Optax
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  # Turn the gradients into parameter updates and get the new optimizer state
  updates, opt_state = tx.update(grads, opt_state)
  # Apply the updates to obtain the new parameters
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011577628
Loss step 10:  0.011571462
Loss step 20:  0.011569391
Loss step 30:  0.011568717
Loss step 40:  0.011568484
Loss step 50:  0.011568407
Loss step 60:  0.01156839
Loss step 70:  0.0115683675
Loss step 80:  0.011568379
Loss step 90:  0.011568378
Loss step 100:  0.011568374


## Let's define our own models

In [13]:
class ExplicitMLP(nn.Module):
  """Multi-layer perceptron whose Dense sublayers are declared in `setup`.

  Attributes:
    features: output width of each Dense layer, in order.
  """
  features: Sequence[int]

  def setup(self):
    """Build one Dense sublayer per entry of `features`.

    Sublayers assigned to `self` in setup are named after the attribute and
    remain available to other methods of the module (useful e.g. for separate
    encoder/decoder methods in autoencoders).
    """
    self.layers = [nn.Dense(width) for width in self.features]

  def __call__(self, inputs):
    """Run `inputs` through every layer, applying ReLU after all but the last."""
    final_idx = len(self.layers) - 1
    hidden = inputs
    for idx, dense in enumerate(self.layers):
      hidden = dense(hidden)
      if idx < final_idx:
        hidden = nn.relu(hidden)
    return hidden

In [14]:
## Let's initialize and apply params to the model

key1, key2 = random.split(random.PRNGKey(0), 2)
# Input is defined as a tensor of shape [4, 4]
x = random.uniform(key1, (4,4))
# Instantiate the MLP; each entry of `features` is the output size of one
# layer. This rebinds the notebook-level `model` (which `mse` above reads).
model = ExplicitMLP(features=[3,4,5])

# model.init runs the module (by default its __call__) on the example input to
# create the parameters. Parameter shapes are inferred from that input, which
# is why no input size had to be declared on the layers: compare the kernel
# shapes printed below with the [4, 4] input.
# model.init and model.apply both go through __call__ by default.
params = model.init(key2, x)

# The parameters live outside the module object, so model(x) cannot be called
# directly on this module; the params are passed in explicitly via
# model.apply instead.
y = model.apply(params, x)

print('initialized parameter shapes: \n', jax.tree_util.tree_map(jnp.shape, 
                                                            unfreeze(params)))
print('output: \n', y)

initialized parameter shapes: 
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output: 
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810348 -0.0255094   0.02151717 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


### A segue to underlying APIs

Looking at the model setup in the section [Let's define our own models](https://colab.research.google.com/drive/1_2fZs7dQsAZwFBoKAJd4SpnrRz0gQd1L?authuser=1#scrollTo=Let_s_define_our_own_models), it occurred to me that there are a lot of APIs, such as `lax.dot_general`, that don't really make sense at first glance.

So, this segue is to explore them using quick examples

In [15]:
# A sample kernel of shape [3, 4]; its last column is all zeros
sample_kernel = jnp.array([
    [1, 2, 3, 0],
    [4, 5, 6, 0],
    [7, 8, 9, 0],
])

# A sample input of shape [3, 3], filled with ones
sample_inp = jnp.array([[1, 1, 1]] * 3)

# lax.dot_general is a generalized dot product. Its dimension numbers are
#   ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)):
#   * contracting dimensions are multiplied element-wise and summed away;
#   * batch dimensions are paired up and kept (they may be left empty).
# Contracting dimension 1 of the input with dimension 0 of the kernel, with no
# batch dimensions, is an ordinary matrix product: [3, 3] x [3, 4] -> [3, 4].
print(
    lax.dot_general(
        sample_inp,
        sample_kernel,
        dimension_numbers=(((1,), (0,)), ((), ())),
    )
)

# Here no contracting dimensions are given, so nothing is summed. Instead,
# dimension 0 of both operands is declared a batch dimension; paired batch
# dimensions must have the same size. For every batch index b the result holds
# the outer product of sample_inp[b] and sample_kernel[b], so the output has
# shape [3, 3, 4].
print(
    lax.dot_general(
        sample_inp,
        sample_kernel,
        dimension_numbers=(((), ()), ((0,), (0,))),
    )
)

[[12 15 18  0]
 [12 15 18  0]
 [12 15 18  0]]
[[[1 2 3 0]
  [1 2 3 0]
  [1 2 3 0]]

 [[4 5 6 0]
  [4 5 6 0]
  [4 5 6 0]]

 [[7 8 9 0]
  [7 8 9 0]
  [7 8 9 0]]]


Getting back to working with MLPs and our own custom Model

### Module parameters

Defining our own Dense layer

In [18]:
class SimpleDense(nn.Module):
  """A hand-written dense layer computing `inputs . kernel + bias`.

  Attributes:
    features: size of the last dimension of the output.
    kernel_init: initializer handed to `self.param` for the kernel.
    bias_init: initializer handed to `self.param` for the bias.
  """
  features: int
  # typing.Callable without a signature accepts a callable taking any
  # arguments and returning Any.
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    """Apply the layer to `inputs`.

    Both model.init and model.apply go through __call__ by default; pass the
    "method" argument to apply() to run a different method instead.
    """
    in_features = inputs.shape[-1]
    kernel = self.param('kernel', self.kernel_init,
                        (in_features, self.features))
    # Contract the last axis of the inputs with the first axis of the kernel;
    # no batch dimensions are declared.
    contract_last_with_first = (((inputs.ndim - 1,), (0,)), ((), ()))
    projected = lax.dot_general(inputs, kernel, contract_last_with_first)
    bias = self.param('bias', self.bias_init, (self.features,))
    return projected + bias

In [19]:
# Same key and input setup as for ExplicitMLP: a [4, 4] uniform input
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
print('Initializing the model')
# Rebinds the notebook-level `model`, `params` and `x` used by earlier cells
model = SimpleDense(features=3)
print('Initializing the Parameters')
params = model.init(key2, x)
# Forward pass with the freshly initialized parameters
y = model.apply(params, x)

Initializing the model
Initializing the Parameters


In [38]:
class SimpleDenseMadeComplicated(nn.Module):
  """A hand-written dense layer followed by a stack of `nn.Dense` layers.

  The first entry of `features` sizes the hand-written layer created in
  `__call__`; every remaining entry sizes one `nn.Dense` layer created in
  `setup`.

  NOTE(review): this module combines `setup()` with an `@nn.compact`
  `__call__`. The parameters printed after `init` show that this works in the
  environment the notebook was run in; confirm against the Flax documentation
  before relying on the pattern.

  Attributes:
    features: output sizes, one per layer, e.g. `[3, 4, 5]`.
    kernel_init: initializer handed to `self.param` for the first kernel.
    bias_init: initializer handed to `self.param` for the first bias.
  """
  # A sequence of layer widths (not a single int): it is indexed and sliced
  # below.
  features: Sequence[int]
  # typing.Callable without a signature accepts a callable taking any
  # arguments and returning Any.
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  def setup(self):
    """Create one `nn.Dense` layer for every width after the first.

    setup() doesn't accept any inputs, so the first layer's kernel, whose
    shape depends on the input, is created in __call__ instead.
    """
    self.intermediate_layers = [nn.Dense(feature) for feature in self.features[1:]]

  @nn.compact
  def __call__(self, inputs):
    """Apply the hand-written first layer, then the `nn.Dense` layers."""
    kernel = self.param('kernel',
                        self.kernel_init,
                        (inputs.shape[-1], self.features[0]))
    # Bias of the hand-written first layer. The nn.Dense layers create their
    # own kernel and bias.
    bias = self.param('bias', self.bias_init, (self.features[0],))
    # Contract the last input axis with the first kernel axis; no batch dims.
    x = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)
    x += bias
    # NOTE(review): no activation is applied between the first layer and
    # intermediate_layers[0], unlike ExplicitMLP where ReLU follows every
    # layer but the last - confirm this is intended.
    for i, lyr in enumerate(self.intermediate_layers):
      x = lyr(x)
      # ReLU after every nn.Dense layer except the last one
      if i != len(self.intermediate_layers) - 1:
        x = nn.relu(x)

    return x

In [40]:
# Same key and input setup as before: a [4, 4] uniform input
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))
print('Initializing the model')
# `features` is a list here: the first entry sizes the hand-written layer, the
# remaining entries size the nn.Dense layers
model = SimpleDenseMadeComplicated(features=[3, 4, 5])
print('Initializing the Parameters')
params = model.init(key2, x)
# Prints the full parameter pytree, including all array values
print(params)
y = model.apply(params, x)
print(y)

Initializing the model
Initializing the Parameters
FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.61506   , -0.22728713,  0.6054702 ],
                     [-0.29617992,  1.1232013 , -0.879759  ],
                     [-0.35162622,  0.3806491 ,  0.6893246 ],
                     [-0.1151355 ,  0.04567898, -1.091212  ]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
        intermediate_layers_0: {
            kernel: DeviceArray([[ 0.7151457 ,  0.12737666, -1.0549262 ,  0.08998597],
                         [-0.74000335,  1.1685069 , -0.26688337, -0.24311556],
                         [-0.20404899,  0.4641534 ,  0.03094014,  0.26043552]],            dtype=float32),
            bias: DeviceArray([0., 0., 0., 0.], dtype=float32),
        },
        intermediate_layers_1: {
            kernel: DeviceArray([[ 0.38301077, -0.21687   , -0.44121116,  0.10675155,
                           0.41591075],
                         [-1.0800076 ,  0.32631534, -