# Model

**Note:** make sure the notebook runtime has a GPU selected (e.g., in Colab: *Runtime → Change runtime type → GPU*).

## Anatomy of the Model

Install the required packages

In [None]:
%%capture
%pip install flax

Define the imports

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from typing import Any, Callable

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

### Model signature

$$ f(w; x) = \hat{y} $$
The parameters come first in the signature because JAX transformations such as `jax.grad` differentiate with respect to the first argument by default.

In [None]:
# Linear Regression: y_hat = w . x, where x carries a leading constant 1 so that w[0] is the bias
np.random.seed(1337)


def predict(p, x):
    """Linear model: contract the parameter vector `p` with the feature vector `x`."""
    return p.T @ x


# Five weights: a bias weight followed by one weight per feature
params = np.random.standard_normal(5)

# A single example: the constant 1 for the bias, then the four feature values
x = np.array([1, 2, 3, 7, 2])

# Model output for this example
y = predict(params, x)

y

### MLP model signature

A one-layer dense network: $a = \mathrm{ReLU}(W x + b)$

In [None]:
from IPython.core.display import HTML

url = "https://www.researchgate.net/publication/221079407/figure/fig1/AS:651187686744067@1532266651725/One-layer-neural-network-and-nomenclature-employed.png"
display(HTML(f'<img src="{url}" width="500px">'))  # Adjust width as needed


In [None]:
def predict(W, b, x):
    """Single dense layer followed by ReLU: returns max(0, W @ x + b)."""
    pre_activation = W @ x + b
    return np.maximum(pre_activation, 0)


# Layer sizes
input_dim, output_dim = 4, 1

# One input example with four features
x = np.array([2, 3, 7, 2])

# Randomly initialised parameters, drawn from NumPy's global RNG
W = np.random.randn(output_dim, input_dim)
b = np.random.randn(output_dim)

y = predict(W, b, x)

y

### MLP in JAX/Flax

Flax Model API:

1️⃣ **Define the model** (`nn.Module`, (optionally) with `setup()`)  
2️⃣ **Initialize parameters** (`model.init()`)  
3️⃣ **Run inference** (`model.apply()`)  


In [None]:
import jax.numpy as jnp
import jax

# Define the predict function using JAX
def predict(W, b, x):
    """One dense layer followed by ReLU, written with jax.numpy.

    Args:
        W: weight matrix (in this cell, shape (output_dim, input_dim)).
        b: bias vector (in this cell, shape (output_dim,)).
        x: input vector (in this cell, shape (input_dim,)).

    Returns:
        ReLU(W @ x + b).
    """
    z = W @ x + b  # Linear transformation
    a = jnp.maximum(0, z)  # ReLU activation
    return a

# Define input and output dimensions
input_dim = 4
output_dim = 1

# Initialize weights and biases with random values.
# A JAX PRNG key must not be reused: drawing W and b from the same key would make
# their values statistically dependent. Split it into independent subkeys instead.
key = jax.random.PRNGKey(0)  # JAX requires an explicit PRNG key for randomness
key_W, key_b = jax.random.split(key)
W = jax.random.normal(key_W, (output_dim, input_dim))  # Random weights
b = jax.random.normal(key_b, (output_dim,))  # Random biases

# Define input vector
x = jnp.array([2, 3, 7, 2])

# Predict
y = predict(W, b, x)
print(y)

In [None]:
import flax.linen as nn

class SimpleNN(nn.Module):
    """A single Dense layer followed by ReLU, with the layer declared in `setup()`.

    Attributes:
        output_dim: number of output neurons of the Dense layer.
    """

    output_dim: int

    def setup(self):
        self.dense = nn.Dense(self.output_dim)

    def __call__(self, x):
        # Affine map, then element-wise ReLU
        return nn.relu(self.dense(x))


# Layer sizes
input_dim, output_dim = 4, 1

# A single input example with four features
x = jnp.array([2, 3, 7, 2])

# 1. Define: build the module (it holds no parameters yet)
model = SimpleNN(output_dim=output_dim)

# 2. Initialize: create parameters by tracing a dummy input of the right size
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones(input_dim))

# 3. Inference: run the forward pass with explicit parameters
y = model.apply(params, x)
print(y)

## Bookkeeping

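In **Flax**, model parameters (`params`) are a nested dictionary of arrays (a plain `dict` in recent Flax versions; older versions returned a frozen dictionary, `FrozenDict`). They can be **saved and loaded** with `flax.serialization.to_bytes()` and `flax.serialization.from_bytes()`, which convert the parameter tree to and from msgpack bytes. Those bytes can be written to a file directly or wrapped with `pickle`, as below. Only `pickle.load` files you trust, because unpickling can execute arbitrary code.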

**1️⃣ Save Model Weights to a File**
```python
import flax
import pickle

# Save params to a file (binary format)
with open("model_params.pkl", "wb") as f:
    pickle.dump(flax.serialization.to_bytes(params), f)
```

**2️⃣ Load Model Weights from a File**
```python
# Load params from file
with open("model_params.pkl", "rb") as f:
    params_loaded = flax.serialization.from_bytes(params, pickle.load(f))

print("Loaded Parameters:", params_loaded)
```


In [None]:
import pickle

# Serialize the params pytree to bytes with Flax, then pickle those bytes to disk.
with open("model_params.pkl", "wb") as param_file:
    payload = flax.serialization.to_bytes(params)
    pickle.dump(payload, param_file)

In [None]:
# Unpickle the serialized bytes and restore them into a pytree shaped like `params`.
# NOTE: pickle.load can execute arbitrary code, so only load files you wrote yourself.
with open("model_params.pkl", "rb") as param_file:
    raw_bytes = pickle.load(param_file)
    params_loaded = flax.serialization.from_bytes(params, raw_bytes)

In [None]:
# Run inference again, this time with the parameters restored from disk
y = model.apply(params_loaded, x)

# The save/load round trip should be lossless: restored params must reproduce the original output.
assert jnp.allclose(y, model.apply(params, x)), "restored params give a different output"
print(y)

## Attention and Transformer

### Mini Transformer in Flax

In [None]:
class NanoLM(nn.Module):
    """Small causal (decoder-only) Transformer language model.

    Attributes:
        vocab_size: number of token ids; size of the token-embedding table and of the output logits.
        num_layers: number of pre-LayerNorm attention + MLP blocks.
        num_heads: number of attention heads.
        head_size: passed to attention as ``qkv_features`` (see the NOTE in ``__call__``).
        dropout_rate: dropout rate for attention and the MLP, applied only when ``training`` is True.
        embed_size: model width; every residual branch must output this width.
        block_size: maximum supported sequence length (size of the position-embedding table).
    """
    vocab_size: int
    num_layers: int = 6
    num_heads: int = 8
    head_size: int = 32
    dropout_rate: float = 0.2
    embed_size: int = 256
    block_size: int = 64

    @nn.compact
    def __call__(self, x, training: bool = True):
        """Map integer token ids to next-token logits.

        Args:
            x: integer token ids whose axis 1 is the sequence axis
                (presumably shape (batch, seq_len); TODO confirm for other callers).
            training: if True, dropout is active, so ``apply`` needs an RNG for dropout.

        Returns:
            Logits with a trailing axis of size ``vocab_size``.

        Raises:
            ValueError: if the sequence is longer than ``block_size``.
        """
        seq_len = x.shape[1]
        # Position ids 0..seq_len-1 index a table with block_size rows. Reject longer
        # sequences explicitly instead of indexing past the end of the table.
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )

        # Token embedding + learned absolute position embedding (created in this order,
        # so the parameter names Embed_0 / Embed_1 are unchanged).
        token_emb = nn.Embed(self.vocab_size, self.embed_size)(x)
        pos_emb = nn.Embed(self.block_size, self.embed_size)(jnp.arange(seq_len))
        x = token_emb + pos_emb

        # Lower-triangular mask: position i may attend only to positions <= i.
        # It is the same for every layer, so build it once outside the loop.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

        for _ in range(self.num_layers):
            x_norm = nn.LayerNorm()(x)

            # Residual self-attention branch. out_features must equal embed_size for the
            # residual sum to be shape-compatible; the previous head_size * num_heads only
            # matched because 32 * 8 == 256 with the default settings.
            x = x + nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                # NOTE(review): qkv_features is the total q/k/v width split across heads,
                # so each head gets head_size // num_heads dims (4 with the defaults).
                # Confirm head_size * num_heads was not intended.
                qkv_features=self.head_size,
                out_features=self.embed_size,
                dropout_rate=self.dropout_rate,
            )(
                x_norm,
                x_norm,
                mask=causal_mask,
                deterministic=not training,
            )

            # Residual position-wise MLP branch (pre-LayerNorm, 4x expansion).
            x = x + nn.Sequential([
                nn.Dense(4 * self.embed_size),
                nn.relu,
                nn.Dropout(self.dropout_rate, deterministic=not training),
                nn.Dense(self.embed_size),
            ])(nn.LayerNorm()(x))

        x = nn.LayerNorm()(x)

        # Project to vocabulary logits
        return nn.Dense(self.vocab_size)(x)

In [None]:
# Model initialization
key = jax.random.PRNGKey(1337)
mini_transformer = NanoLM(vocab_size=100)

# Example input: batch of token sequences (batch_size=1, seq_len=10)
x = jnp.ones((1, 10), dtype=jnp.int32)

# Initialize parameters in inference mode. With the default training=True, the dropout
# layers would need a dropout RNG, but only a params key is passed here.
params = mini_transformer.init(key, x, training=False)

# Forward pass without dropout; pass `training` by keyword rather than as a bare positional bool
y = mini_transformer.apply(params, x, training=False)

# One logit vector per input position: (batch_size, seq_len, vocab_size)
assert y.shape == (*x.shape, mini_transformer.vocab_size)

y.shape