# Model

## Anatomy of the Model

Install the required packages

In [None]:
%%capture
%pip install flax

Define the imports

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from typing import Any, Callable

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

### Model signature

$$ f(w; x) = \hat{y} $$
We put the parameters first so that the signature matches what JAX expects later: `jax.grad` differentiates with respect to the first argument by default.

In [None]:
# Linear Regression
np.random.seed(1337)

def predict(w, x):
  # y = w.T @ x
  y = np.sum(w * x)
  return y


params = np.ones(5)

# features for one example; the leading 1 acts as a constant (bias) feature
x = np.array([1] + [2, 3, 7, 2])

# output
y = predict(params, x)

y

### MLP model signature

1-layer Dense network

In [None]:
def predict(W, b, x):
    z = W @ x + b   # Linear transformation
    a = np.maximum(0, z)  # ReLU activation
    return a

input_dim = 4  # Input features
output_dim = 1  # Number of output neurons

x = np.array([2, 3, 7, 2])

W = np.ones((output_dim, input_dim))  # Initialize weights with all 1s
b = np.ones(output_dim, )  # Initialize biases with all 1s

y = predict(W, b, x)

y

### Why JAX?

In ML, we need more than a forward pass:

* Compute **gradients** of a loss
* Run efficiently on **GPU/TPU**
* Batch inputs
* Compile for speed

NumPy alone won't cover this stack.

JAX extends the NumPy programming model with *program transformations*:

| We write | JAX gives |
| -------- | --------- |
| `f(x)`   | `grad(f)` |
| `f(x)`   | `jit(f)`  |
| `f(x)`   | `vmap(f)` |

We don't just evaluate functions - we **transform functions into new functions**.

Mental model: **JAX = NumPy + autodiff + XLA compilation + vectorization** 🚀
But beware of the sharp edges: [https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
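
To make the table concrete, here is a minimal sketch of the three transformations applied to one small function. The function `f` and its inputs are illustrative assumptions, not part of the models defined in this notebook.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar function: f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x

df = jax.grad(f)        # new function computing f'(x) = 2x + 3
f_fast = jax.jit(f)     # same function, compiled by XLA on the first call
f_batch = jax.vmap(f)   # applies f to each element along the leading axis

x0 = 2.0
xs = jnp.array([0.0, 1.0, 2.0])

print(df(x0))       # 7.0
print(f_fast(x0))   # 10.0
print(f_batch(xs))  # elementwise results: 0, 4, 10
```

Each transformation takes a function and returns a function, so they compose: `jax.jit(jax.vmap(jax.grad(f)))` is also valid.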

---


In [None]:
import jax.numpy as jnp
import jax

def predict(W, b, x):
    z = W @ x + b  # Linear transformation
    a = jnp.maximum(0, z)  # ReLU activation
    return a

input_dim = 4
output_dim = 1

W = jnp.ones((output_dim, input_dim))
b = jnp.ones((output_dim, ))

x = jnp.array([2, 3, 7, 2])

y = predict(W, b, x)

print(y)

Now compare with the jitted version.

In [None]:
# Now we jit!
jitted_predict = jax.jit(predict)

In [None]:
import time

# --- setup ---
input_dim, output_dim = 4, 1
W = jnp.ones((output_dim, input_dim))
b = jnp.ones((output_dim,))
x = jnp.array([2, 3, 7, 2])

# --- warmup (compilation happens here) ---
y0 = jitted_predict(W, b, x).block_until_ready()

# --- timing helpers ---
def time_call(f, n=1_000):
    # run once to avoid first-call overhead (except compilation, already done above)
    f(W, b, x).block_until_ready()

    t0 = time.perf_counter()
    for _ in range(n):
        y = f(W, b, x)
    y.block_until_ready()  # sync once at end (important on GPU/async backends)
    t1 = time.perf_counter()
    return (t1 - t0) / n

t_eager = time_call(predict)
t_jit   = time_call(jitted_predict)

print(f"eager avg: {t_eager*1e6:.2f} microsec /call")
print(f"jit   avg: {t_jit*1e6:.2f} microsec /call")
print(f"speedup: {t_eager/t_jit:.1f}x")

The same holds for gradients: `jax.grad` derives the backward pass, so none is written manually.
No gradient accumulation is stored in the parameters.
Just **pure functions + transformations**.
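
As a minimal sketch of what this looks like for the `predict` function above: `jax.grad` turns a scalar loss into a function that returns gradients. The squared-error loss `loss_fn` and the target value `y_target` are assumptions chosen for illustration; they do not appear in the original notebook.

```python
import jax
import jax.numpy as jnp

def predict(W, b, x):
    z = W @ x + b               # linear transformation
    return jnp.maximum(0, z)    # ReLU activation

def loss_fn(W, b, x, y_target):
    # Hypothetical squared-error loss, used only for this illustration
    y_pred = predict(W, b, x)
    return jnp.sum((y_pred - y_target) ** 2)

W = jnp.ones((1, 4))
b = jnp.ones((1,))
x = jnp.array([2.0, 3.0, 7.0, 2.0])
y_target = jnp.array([10.0])    # assumed target value

# Differentiate with respect to arguments 0 and 1 (W and b), then compile
grad_fn = jax.jit(jax.grad(loss_fn, argnums=(0, 1)))
dW, db = grad_fn(W, b, x, y_target)

print(dW)  # gradient with the same shape as W
print(db)  # gradient with the same shape as b
```

The gradients come back as return values with the same shapes as `W` and `b`; nothing is written into the parameters themselves.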

### Why Flax?

JAX gives us transformations.

But building large neural networks directly in raw JAX quickly becomes messy:

* Where do parameters live?
* How do we initialize them?
* How do we manage dropout RNG?
* How do we save / load models?

Flax solves this.

Flax is a **neural network library built on JAX**.

It gives:

* Structured modules (`nn.Module`)
* Parameter initialization via `.init()`
* Explicit parameter passing via `.apply()`
* Clean parameter trees (pytrees)
* Compatibility with `grad`, `jit`, `vmap`


The key design principle: in Flax, a model is still a **pure function**.

Model API:

1️⃣ **Define the model** (`nn.Module`, optionally with `setup()`)  
2️⃣ **Initialize parameters** (`model.init()`)  
3️⃣ **Run inference** (`model.apply()`)  


Gradients are obtained by differentiating a loss built from `.apply()`.

In [None]:
import flax.linen as nn

class VanillaMLP(nn.Module):
    output_dim: int

    @nn.compact
    def __call__(self, x):
        # linear layer
        x = nn.Dense(
            self.output_dim,
            kernel_init=nn.initializers.ones,
            bias_init=nn.initializers.ones,
        )(x)

        # activation function
        x = nn.relu(x)

        return x

In [None]:
input_dim = 4
output_dim = 1

x = jnp.array([2, 3, 7, 2])

model = VanillaMLP(output_dim=output_dim)

key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones(input_dim))

In [None]:
from flax.core import freeze, unfreeze

print(jax.tree_util.tree_map(lambda x: x.shape, params))

In [None]:
y = model.apply(params, x)

print(y)

## Summary: NumPy -> JAX -> Flax

Your NumPy model:

* Computes a forward pass

JAX:

* Makes it differentiable and fast

Flax:

* Makes it scalable and maintainable

Together, they let us build research-grade training pipelines 🚀

## Bookkeeping

In **Flax**, model parameters (`params`) are stored as a nested dictionary pytree (a `FrozenDict` in older Flax releases, a plain `dict` in newer ones). They can be **saved and loaded** with Flax's serialization helpers `flax.serialization.to_bytes()` and `flax.serialization.from_bytes()`, optionally combined with a general-purpose tool such as `pickle`.

**1️⃣ Save Model Weights to a File**
```python
import flax
import pickle

# Save params to a file (binary format)
with open("model_params.pkl", "wb") as f:
    pickle.dump(flax.serialization.to_bytes(params), f)
```

**2️⃣ Load Model Weights from a File**
```python
# Load params from file
with open("model_params.pkl", "rb") as f:
    params_loaded = flax.serialization.from_bytes(params, pickle.load(f))

print("Loaded Parameters:", params_loaded)
```


In [None]:
import pickle

with open("model_params.pkl", "wb") as f:
    pickle.dump(flax.serialization.to_bytes(params), f)

In [None]:
with open("model_params.pkl", "rb") as f:
    params_loaded = flax.serialization.from_bytes(params, pickle.load(f))

In [None]:
# Run inference again
y = model.apply(params_loaded, x)
print(y)

## Pre-trained Models

Optional: this section uses a Hugging Face model in PyTorch, for inference only; training in this course is done in JAX/Flax.

In [None]:
%%capture
%pip install transformers

In [None]:
from IPython.display import display, Image
import requests
from PIL import Image as PILImage
from io import BytesIO

url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Fetch and display the image
response = requests.get(url)
img = PILImage.open(BytesIO(response.content))
display(img)

Model: ViT https://huggingface.co/docs/transformers/en/model_doc/vit

Trained on ImageNet: https://paperswithcode.com/dataset/imagenet

In [None]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

# Get the image from the web
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Load preprocessor and model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Run the inference engine
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)

In [None]:
type(outputs)

In [None]:
logits = outputs.logits
logits.shape

In [None]:
# Top-1
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

In [None]:
import torch

# Get the top-10 predictions
top_10 = torch.topk(logits, 10)

# Extract top indices and their corresponding scores
top_10_indices = top_10.indices[0].tolist()
top_10_scores = top_10.values[0].tolist()

# Display results
print("Top-10 Predicted Classes:")
for rank, (idx, score) in enumerate(zip(top_10_indices, top_10_scores), start=1):
    print(f"{rank}. {model.config.id2label[idx]} ({score:.4f})")
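
The scores printed above are raw logits, not probabilities. A softmax over the class axis turns them into probabilities that sum to 1. The sketch below shows the idea in JAX on a small array of made-up logits; `logits_demo` is a hypothetical stand-in, not output from the ViT model.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits for a batch of 1 image and 4 classes
logits_demo = jnp.array([[2.0, 1.0, 0.1, -1.0]])

# Softmax over the class axis gives probabilities that sum to 1
probs = jax.nn.softmax(logits_demo, axis=-1)

# Top-2 probabilities and their class indices
top_probs, top_idx = jax.lax.top_k(probs, k=2)

for rank, (idx, p) in enumerate(zip(top_idx[0].tolist(), top_probs[0].tolist()), start=1):
    print(f"{rank}. class {idx} ({p:.4f})")
```

Softmax preserves the ordering of the logits, so the ranking is unchanged; only the scale of the scores differs.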