<a href="https://colab.research.google.com/github/gauravjain14/All-about-JAX/blob/main/Flax_Basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
!pip install --upgrade -q pip jax jaxlib 
!pip install --upgrade git+https://github.com/google/flax.git

  Building wheel for jax (setup.py) ... done
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/google/flax.git
  Cloning https://github.com/google/flax.git to /tmp/pip-req-build-bpiwk30f
  Running command git clone --filter=blob:none --quiet https://github.com/google/flax.git /tmp/pip-req-build-bpiwk30f
  Resolved https://github.com/google/flax.git to commit 6ae22681ef6f6c004140c3759e7175533bda55bd
  Preparing metadata (setup.py) ... done

In [7]:
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

## Linear regression with Flax 

In [8]:
# Create one dense layer instance.
# We only specify the output size of the model. The input size is
# inferred at initialization from the input, which fixes the kernel shape.
model = nn.Dense(features=5)  # `features` is the number of output features

### Model parameters and initialization

In [13]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # generate dummy input
params = model.init(key2, x)
jax.tree_util.tree_map(lambda x: x.shape, params)

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [14]:
# The params are returned as an immutable FrozenDict, in keeping with JAX's functional style
try:
  params['new_key'] = jnp.ones((2,2))
except ValueError as e:
  print("Error: ", e)

Error:  FrozenDict is immutable.


#### How do we evaluate a model given a set of parameters?

We call `model.apply(params, x)`, passing the parameters first and the input second.

Note: printing the model output with `print` requires copying the array from the device to the host.

In [17]:
model_output = model.apply(params, x)
model_output
# print(model_output)

DeviceArray([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 ,
             -1.1271116 ], dtype=float32)

## Gradient Descent

Now that we know how to initialize parameters and apply them to a model, let's use the same immutable parameters to run **gradient descent**.

If gradient descent is unfamiliar, feel free to take a look at [this link](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#gradient-descent).

In [40]:
## Let's begin by setting up the problem dimensions
n_samples = 20
x_dim = 10
y_dim = 5

## Initialize parameters (W) and bias (b)
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# New here: store the parameters in a pytree
true_params = freeze({'params': {'bias':b, 'kernel': W}})
# We can look at what this pytree looks like
true_params

FrozenDict({
    params: {
        bias: DeviceArray([-1.4581939, -2.047044 ,  2.0473392,  1.1684095, -0.9758364],            dtype=float32),
        kernel: DeviceArray([[ 1.0247566 ,  0.18528093,  0.03387944, -0.86629736,
                       0.34718114],
                     [ 1.7656006 ,  0.99169755,  1.1657897 ,  1.1106981 ,
                      -0.08589564],
                     [-1.1820309 ,  0.29050717,  1.436301  ,  0.15073189,
                      -1.3651401 ],
                     [-1.1463748 , -0.16064964,  0.04578291,  1.3267074 ,
                       0.08830649],
                     [ 0.15840754,  1.3908992 , -1.3764939 ,  0.4419787 ,
                      -2.2242246 ],
                     [ 0.5943986 ,  0.8191525 ,  0.32800463,  0.51409715,
                       0.92392564],
                     [-0.32272884,  1.7835051 ,  1.0902369 , -0.5799917 ,
                       0.9487662 ],
                     [ 0.97157586, -1.2998172 ,  0.3205269 ,  0.806568  ,
      

In [28]:
# Generate input samples with additional noise
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))

# Let's see what the shapes of these x and y samples look like
print('x shape: ', x_samples.shape, '; y shape: ', y_samples.shape)

x shape:  (20, 10) ; y shape:  (20, 5)


Now that we have initialized the inputs and the parameters, let's define the loss function and set up gradient descent.
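
For reference, the objective and the update rule can be written out as follows. The notation is introduced here for illustration only: $n$ is the number of samples, $\alpha$ the learning rate, and $\theta = (W, b)$ the parameters. The factor $\tfrac{1}{2}$ corresponds to the division by `2.0` in the loss code.

$$
L(W, b) = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{2} \left\lVert y_i - \left(W^\top x_i + b\right) \right\rVert^2,
\qquad
\theta \leftarrow \theta - \alpha \, \nabla_\theta L(\theta)
$$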

In [42]:
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y - pred, y - pred) / 2.0

  # Vectorize the function above and average the loss over all samples.
  # For the uninitiated: jax.vmap maps squared_error over the leading (batch)
  # axis of x_batched and y_batched, returning one loss value per sample.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
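
To see what `jax.vmap` does in isolation, here is a minimal sketch that compares it with an explicit Python loop on a tiny batch. The function `squared_norm_half` and the arrays `xs` and `ys` are hypothetical stand-ins for `squared_error` and the real samples; no model is involved.

```python
import jax
import jax.numpy as jnp

def squared_norm_half(x, y):
    # Half the squared distance between two vectors.
    diff = y - x
    return jnp.inner(diff, diff) / 2.0

xs = jnp.arange(6.0).reshape(3, 2)  # 3 samples of dimension 2
ys = jnp.ones((3, 2))

# vmap maps the per-sample function over the leading axis of both inputs.
batched = jax.vmap(squared_norm_half)(xs, ys)

# The same result with an explicit loop over samples.
looped = jnp.stack([squared_norm_half(x, y) for x, y in zip(xs, ys)])

print(batched.shape, jnp.allclose(batched, looped))
print(jnp.mean(batched, axis=0))  # scalar average, as in mse
```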

In [43]:
## Now applying Gradient Descent
learning_rate = 0.3  # Gradient step size
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))

# jax.value_and_grad returns a function that computes both the loss value
# and its gradient with respect to the first argument (params)
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639793
Loss step 0:  35.343876
Loss step 10:  0.514347
Loss step 20:  0.11384159
Loss step 30:  0.039326735
Loss step 40:  0.019916208
Loss step 50:  0.014209136
Loss step 60:  0.012425654
Loss step 70:  0.01185039
Loss step 80:  0.011661784
Loss step 90:  0.011599409
Loss step 100:  0.011578695
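
The two building blocks of the loop, `jax.value_and_grad` and `jax.tree_util.tree_map`, can also be checked by hand on a toy problem. The sketch below uses a made-up quadratic loss over a small dict pytree; `toy_loss`, `params0`, and `lr` are hypothetical names, unrelated to the regression model above.

```python
import jax
import jax.numpy as jnp

def toy_loss(params):
    # Simple quadratic whose gradient is 2 * params, leaf by leaf.
    return jnp.sum(params['w'] ** 2) + params['b'] ** 2

params0 = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}

# Loss value and gradients, with the gradients in the same pytree structure.
loss0, grads0 = jax.value_and_grad(toy_loss)(params0)

# One gradient descent step applied to every leaf of the pytree.
lr = 0.1
params1 = jax.tree_util.tree_map(lambda p, g: p - lr * g, params0, grads0)

print(loss0, grads0)
print(params1)
```

Because the gradients have the same structure as the parameters, `tree_map` can pair each parameter with its gradient without any manual bookkeeping.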
