 Useful links:
 - [Flax_basics](https://flax.readthedocs.io/en/latest/guides/flax_basics.html)
 - [Training-Loop-in-JAX](https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-FLAX--VmlldzoyMzA4ODEy)
 - [Kaggle example](https://www.kaggle.com/code/nilaychauhan/digit-recognizer-using-jax-flax/notebook)

In [None]:
# %env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu
# %env JAX_DISABLE_JIT=1
# %env JAX_DEBUG_NANS=1

In [None]:
import jax
from jax import random, lax, jit
from flax import linen as nn
from typing import Sequence, Callable
import optax
from flax.training import train_state
import jax.numpy as jnp
from collections import defaultdict
import matplotlib.pylab as plt
import numpy as np

In [None]:
# Split the root PRNG key into two subkeys; key2 seeds the model
# initialization below (key1 is not used by the cells visible here before it
# is reassigned near the end of the notebook).
key1, key2 = random.split(random.PRNGKey(0), 2)

def generate_x(shape, key=None, high=5 * jnp.pi):
    """Sample inputs uniformly from the interval ``[0, high)``.

    Args:
        shape: shape of the returned array, e.g. ``(batch, 1)``.
        key: optional JAX PRNG key. When given, sampling uses ``jax.random``
            and is reproducible. When omitted (the default, matching the
            original behaviour), NumPy's global, unseeded generator is used,
            so results differ from run to run.
        high: upper bound of the sampling interval (default ``5 * pi``).

    Returns:
        A ``jnp.ndarray`` of the requested shape.
    """
    if key is None:
        return high * jnp.asarray(np.random.uniform(size=shape))
    return high * jax.random.uniform(key, shape)

def target(x):
    """Ground-truth function the network learns: ``f(x) = x * sin(x)``."""
    return x * jnp.sin(x)

# Small example batch of 8 scalar inputs, shape (8, 1); it is used below to
# initialize the model and to smoke-test a forward pass.
x = generate_x((8, 1))
x.shape

In [None]:
class Net(nn.Module):
    """Multi-layer perceptron built from ``nn.Dense`` layers.

    The activation is applied after every layer except the last, so the final
    layer's output is returned unchanged.

    Attributes:
        features: output width of each Dense layer, in order; the last entry
            is the output dimension of the network.
        activation: elementwise nonlinearity applied between layers. Defaults
            to ``nn.relu``, the previously hard-coded choice.
    """
    features: Sequence[int]
    activation: Callable = nn.relu

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # No activation on the output layer.
            if i != last:
                x = self.activation(x)
        return x

In [None]:
# Build the network (Dense widths 16 -> 32 -> 64 -> 1) and initialize its
# parameters from key2, using the example batch `x` to infer input shapes.
model = Net(features=[16, 32, 64, 1])
params = model.init(key2, x)['params']
print(type(params))  # container type of the parameter pytree
# Shape of every parameter array. jax.tree_map is deprecated; use the
# jax.tree_util spelling.
jax.tree_util.tree_map(lambda p: p.shape, params)  # Check the parameters

In [None]:
# Forward pass on the example batch. With the (8, 1) batch above and a final
# Dense layer of width 1, the output is expected to be (8, 1).
y = model.apply({'params': params}, x)
y.shape

In [None]:
def init_train_state(
    model,
    random_key,
    shape,
    learning_rate,
    tx=None,
) -> train_state.TrainState:
    """Initialize ``model`` and wrap its parameters in a Flax ``TrainState``.

    Args:
        model: Flax module to initialize.
        random_key: PRNG key used for parameter initialization.
        shape: shape of the dummy all-ones input used to trace ``model`` and
            create its parameters.
        learning_rate: Adam learning rate; ignored when ``tx`` is given.
        tx: optional optax optimizer to use instead of the default
            ``optax.adam(learning_rate)``, e.g. ``optax.sgd(0.0001, 0.9)``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial parameters and
        the chosen optimizer.
    """
    # Tracing the model on a dummy batch creates the parameters.
    variables = model.init(random_key, jnp.ones(shape))

    optimizer = optax.adam(learning_rate) if tx is None else tx

    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=variables['params'],
    )

# Training state for the single model: parameters initialized from key2 on a
# dummy (16, 1) batch, learning rate 0.001.
state = init_train_state(
    model, key2, (16, 1), 0.001
)
# state

In [None]:
def mse_loss(*, logits, labels):
    """Mean squared error between predictions ``logits`` and targets ``labels``.

    Both arguments are keyword-only. The two arrays are broadcast against each
    other before averaging, so callers should pass matching shapes.
    """
    residual = labels - logits
    return jnp.mean(residual ** 2)

In [None]:
def compute_metrics(*, logits, labels):
    """Summarize regression quality of ``logits`` against ``labels``.

    Returns:
        A dict with:
        - ``'loss'``: mean squared error, as computed by ``mse_loss``;
        - ``'rmse'``: root mean squared error, ``sqrt(loss)``;
        - ``'accuracy'``: the same value as ``'rmse'``. It is kept for
          backward compatibility. Despite the name it is an error measure
          (lower is better), not a classification accuracy.
    """
    loss = mse_loss(logits=logits, labels=labels)
    rmse = jnp.sqrt(loss)
    return {
        'loss': loss,
        'rmse': rmse,
        'accuracy': rmse,
    }

In [None]:
@jit
def train_step(
    state: train_state.TrainState,
    batch: Sequence[jnp.ndarray],
):
    """Run one optimizer step on ``batch`` and report metrics.

    Args:
        state: current training state (parameters, optimizer, ``apply_fn``).
        batch: ``(x, y)`` pair of inputs and regression targets.

    Returns:
        ``(new_state, metrics)``. ``metrics`` is computed from the predictions
        made with the parameters *before* this update.
    """
    x, y = batch

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, x)
        loss = mse_loss(logits=logits, labels=y)
        # Return the predictions as auxiliary output so metrics can reuse them.
        return loss, logits

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=y)

    return state, metrics

@jit
def eval_step(
    state: train_state.TrainState,
    batch: Sequence[jnp.ndarray],
):
    """Compute metrics for ``batch`` without updating ``state``.

    Args:
        state: training state whose current parameters are evaluated.
        batch: ``(x, y)`` pair of inputs and regression targets.

    Returns:
        The metrics dict produced by ``compute_metrics``.
    """
    x, y = batch
    logits = state.apply_fn({'params': state.params}, x)
    return compute_metrics(logits=logits, labels=y)

In [None]:
history = defaultdict(list)

for epoch in range(500):
    # Every mini-batch is freshly sampled, so an "epoch" here is simply
    # 64 optimizer steps on 32 new points each (there is no fixed dataset).
    for _ in range(64):
        x_ = generate_x((32, 1))
        batch = (x_, target(x_))
        state, metrics = train_step(state, batch)

    # After the inner loop, `metrics` describes the epoch's last mini-batch.
    if not epoch % 100:
        print(f"Epoch {epoch}, loss={metrics['loss']}")
    history['epoch'].append(epoch)
    history['metrics_train'].append(metrics)

In [None]:
# Training curve: loss of the last mini-batch of each epoch, plotted per epoch.
plt.plot(history['epoch'], [v['loss'] for v in history['metrics_train']])
plt.xlabel("epoch")
plt.ylabel("loss");

In [None]:
# Evaluate the trained network on 512 fresh points and compare it visually
# with the target function.
x = generate_x((512, 1))
y_pred = model.apply({'params': state.params}, x)
y_true = target(x)
print(mse_loss(logits=y_pred, labels=y_true))  # MSE on the fresh points

plt.scatter(x, y_pred, label="neural network")
plt.scatter(x, y_true, label="target function")
plt.legend()

In [None]:
# Parity plot: each point is (true value, prediction); a perfect model lies on
# the red y = x reference line.
plt.scatter(y_true, y_pred)
# Draw the reference line once instead of plotting through all 512 unsorted
# points, which retraces the same diagonal back and forth.
plt.axline((0, 0), slope=1, color="r")
plt.xlabel("target")
plt.ylabel("prediction");

### Train multiple models

In [None]:
from frozendict import frozendict
from typing import Tuple

@jit
def train_step_multiple_model(
    state: Tuple[train_state.TrainState, ...],
    batch: Tuple[jnp.ndarray, jnp.ndarray],
):
    """Run one joint optimizer step for an additive ensemble of models.

    The ensemble prediction is the *sum* of every member's output. A single
    MSE loss on that sum is differentiated with respect to all members'
    parameters at once, and each member then applies its own gradient through
    its own optimizer.

    Args:
        state: one ``TrainState`` per ensemble member.
        batch: ``(x, y)`` pair of inputs and regression targets.

    Returns:
        ``(new_states, metrics)``. ``new_states`` is a tuple of updated states
        in the same order as ``state``. ``metrics`` describes the summed
        prediction made with the parameters before this update.
    """
    x, y = batch
    # Use a dedicated name so loss_fn never depends on later rebinding.
    states = state

    def loss_fn(params):
        # `params` is a tuple with one parameter pytree per member.
        logits = jnp.array(0.0)
        for s, p in zip(states, params):
            logits += s.apply_fn({'params': p}, x)
        loss = mse_loss(logits=logits, labels=y)
        return loss, logits

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(tuple(s.params for s in states))
    new_states = tuple(s.apply_gradients(grads=g) for s, g in zip(states, grads))
    metrics = compute_metrics(logits=logits, labels=y)

    return new_states, metrics

In [None]:
# Additive ensemble of N independently initialized networks whose outputs are
# summed. Note: this rebinds the globals `model` and `state` to tuples.
N = 2
random_keys = jax.random.split(jax.random.PRNGKey(0), N)
model = tuple(Net(features=[8, 16, 32, 1]) for _ in range(N))
state = tuple(
    init_train_state(net, subkey, (16, 1), 0.001)
    for net, subkey in zip(model, random_keys)
)

In [None]:
history = defaultdict(list)

for epoch in range(500):
    # Same schedule as the single-model loop: 64 freshly sampled mini-batches
    # of 32 points per epoch, but every step updates all ensemble members.
    for _ in range(64):
        x_ = generate_x((32, 1))
        batch = (x_, target(x_))
        state, metrics = train_step_multiple_model(state, batch)

    # `metrics` describes the epoch's last mini-batch. `state` is a tuple here,
    # so per-member parameters live at e.g. state[0].params.
    if not epoch % 100:
        print(f"Epoch {epoch}, loss={metrics['loss']}")
    history['epoch'].append(epoch)
    history['metrics_train'].append(metrics)

In [None]:
# Ensemble training curve: loss of the summed prediction on the last
# mini-batch of each epoch.
plt.plot(history['epoch'], [v['loss'] for v in history['metrics_train']])
plt.xlabel("epoch")
plt.ylabel("loss");

In [None]:
# Evaluate the ensemble (sum of all members' outputs) on 512 fresh points.
x = generate_x((512, 1))
y_pred = sum(s.apply_fn({'params': s.params}, x) for s in state)
y_true = target(x)
print(mse_loss(logits=y_pred, labels=y_true))  # MSE on the fresh points

plt.title("Multiple model training")
plt.scatter(x, y_pred, label="neural network")
plt.scatter(x, y_true, label="target function")
plt.legend()

## Gradient of the model with respect to its input

In [None]:
from jax import grad, jit

In [None]:
def func_model(x, states=None):
    """Return a scalar summary of the ensemble output, for differentiating w.r.t. ``x``.

    Returns the sum over all elements of the summed member predictions. For
    models that treat rows independently (like the Dense MLPs above),
    ``grad(func_model)(x)`` is therefore the per-row derivative of the
    ensemble output with respect to its input.

    Args:
        x: model inputs.
        states: sequence of objects with ``apply_fn`` and ``params`` (the
            ensemble ``TrainState``s). Defaults to the global ``state`` at
            call time. Pass it explicitly when calling through ``jit``: a
            jitted function reads globals only while tracing, so later changes
            to the global ``state`` would otherwise be ignored.
    """
    if states is None:
        states = state
    return jnp.sum(sum(s.apply_fn({'params': s.params}, x) for s in states))


grad_func_model = jit(grad(func_model))

In [None]:
timeit grad_func_model(x)

In [None]:
def fn(x):
    """Scalar test function: the sum of ``sin`` over all elements of ``x``."""
    return jnp.sum(jnp.sin(x))


def scale(x):
    """Affine map ``x -> 2x - 1``, used as the inner function of ``func``."""
    return 2 * x - 1


def func(x):
    """Composition ``fn(scale(x))``.

    By the chain rule its gradient is ``2 * grad(fn)(scale(x))``.
    """
    inner = scale(x)
    return fn(inner)


grad_fn, grad_func = grad(fn), grad(func)

# Check the chain rule on the batch `x`: d/dx fn(2x - 1) == 2 * fn'(2x - 1).
# Compare with a tolerance rather than exact float equality, because the two
# sides may be evaluated through different floating-point operation orders.
jnp.allclose(grad_func(x), 2 * grad_fn(scale(x)))

In [None]:
class SimpleDense(nn.Module):
    """Minimal re-implementation of a dense layer: ``inputs @ kernel + bias``.

    Attributes:
        features: number of output features (size of the last output axis).
        kernel_init: initializer for the ``(inputs.shape[-1], features)`` kernel.
        bias_init: initializer for the ``(features,)`` bias.
    """
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        in_features = inputs.shape[-1]
        # NOTE: parameters are declared in the original order (kernel, then
        # bias) in case initialization depends on declaration order.
        weights = self.param('kernel', self.kernel_init, (in_features, self.features))

        # Contract the last axis of `inputs` with axis 0 of the kernel, with
        # no batch dimensions. For a 2-D kernel this is what jnp.dot(inputs,
        # kernel) computes; dot_general just spells the contraction out.
        contracting_dims = ((inputs.ndim - 1,), (0,))
        batch_dims = ((), ())
        projected = lax.dot_general(inputs, weights, (contracting_dims, batch_dims))

        offset = self.param('bias', self.bias_init, (self.features,))
        return projected + offset

# Smoke test for SimpleDense on a random (4, 4) batch. The kernel is created
# with shape (4, 3) and the bias with shape (3,), so the output is (4, 3).
# Unlike the earlier cells, `params` here is the full variables dict returned
# by `init` (not its 'params' entry), so it is passed to `apply` as-is.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)