# Model

## Anatomy of the Model

Install the required packages

In [None]:
%%capture
%pip install flax

Define the imports

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from typing import Any, Callable

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

### Model signature

$$ f(w; x) = \hat{y} $$
We put the parameters first in the argument list to match the signature that JAX will expect later.

In [None]:
# Linear Regression
# Seed NumPy's legacy global RNG. Nothing in this cell draws random numbers,
# so the seed only matters for code that uses np.random afterwards.
np.random.seed(1337)

def predict(w, x):
    """Linear model f(w; x): sum of the element-wise products of w and x.

    For two 1-D arrays of equal length this is the dot product w.T @ x.
    """
    products = w * x
    return np.sum(products)

# One weight per input entry, all set to 1
params = np.ones(shape=5)

# Input vector: a leading 1 followed by the feature values 2, 3, 7, 2
x = np.array([1, 2, 3, 7, 2])

# Model output; the bare name on the last line makes the notebook display it
y = predict(params, x)
y

### MLP model signature

1-layer Dense network

In [None]:
def predict(W, b, x):
    """One dense layer followed by ReLU: max(0, W @ x + b)."""
    pre_activation = W @ x + b  # affine map
    return np.maximum(0, pre_activation)  # clamp negatives to zero

# Layer dimensions
input_dim = 4   # number of input features
output_dim = 1  # number of output neurons

# Parameters: every weight and every bias starts at 1
W = np.ones(shape=(output_dim, input_dim))
b = np.ones(shape=(output_dim,))

# Input vector with input_dim entries
x = np.array([2, 3, 7, 2])

# Forward pass; the bare name on the last line makes the notebook display it
y = predict(W, b, x)
y

### MLP in JAX and Flax

JAX as NumPy on steroids!
But beware: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [None]:
import jax.numpy as jnp
import jax

def predict(W, b, x):
    """One dense layer followed by ReLU, written with jax.numpy: max(0, W @ x + b)."""
    pre_activation = W @ x + b  # affine map
    return jnp.maximum(0, pre_activation)  # clamp negatives to zero

# Layer dimensions
input_dim, output_dim = 4, 1

# All-ones parameters and the same input vector as the NumPy version
W = jnp.ones(shape=(output_dim, input_dim))
b = jnp.ones(shape=(output_dim,))
x = jnp.array([2, 3, 7, 2])

# Forward pass
y = predict(W, b, x)
print(y)

Flax - a library for Neural Networks in JAX

Flax Model API:

1️⃣ **Define the model** (`nn.Module`, (optionally) with `setup()`)  
2️⃣ **Initialize parameters** (`model.init()`)  
3️⃣ **Run inference** (`model.apply()`)  


In [None]:
import flax.linen as nn

class VanillaMLP(nn.Module):
    """Single dense layer followed by a ReLU activation.

    Attributes:
        output_dim: number of output units of the dense layer.
    """
    output_dim: int

    def setup(self):
        """Create the dense layer with an all-ones kernel and bias."""
        # Use the library's all-ones initializer instead of hand-written lambdas.
        # All-ones weights are used just to match the init of the previous models.
        self.dense = nn.Dense(
            self.output_dim,
            kernel_init=nn.initializers.ones,
            bias_init=nn.initializers.ones,
        )

    def __call__(self, x):
        """Apply the dense layer to x and pass the result through ReLU."""
        z = self.dense(x)  # Linear transformation
        return nn.relu(z)  # ReLU activation

# Layer dimensions
input_dim, output_dim = 4, 1

# 1) Define the model
model = VanillaMLP(output_dim=output_dim)

# 2) Initialize the parameters, using a dummy all-ones input with input_dim entries
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((input_dim,)))

# 3) Run inference on the actual input
x = jnp.array([2, 3, 7, 2])
y = model.apply(params, x)

print(y)

## Bookkeeping

In **Flax**, model parameters (`params`) are stored as a nested dictionary (a **`FrozenDict`** in older Flax releases, a plain `dict` in newer ones), which can be **saved and loaded** using Flax's serialization utilities `flax.serialization.to_bytes()` and `flax.serialization.from_bytes()`, or `pickle`/`json` for more flexibility.

**1️⃣ Save Model Weights to a File**
```python
import flax
import pickle

# Save params to a file (binary format)
with open("model_params.pkl", "wb") as f:
    pickle.dump(flax.serialization.to_bytes(params), f)
```

**2️⃣ Load Model Weights from a File**
```python
# Load params from file
with open("model_params.pkl", "rb") as f:
    params_loaded = flax.serialization.from_bytes(params, pickle.load(f))

print("Loaded Parameters:", params_loaded)
```


In [None]:
import pickle

# Serialize the parameters with Flax, pickle the serialized result, and
# write the pickle to disk.
with open("model_params.pkl", "wb") as f:
    f.write(pickle.dumps(flax.serialization.to_bytes(params)))

In [None]:
# NOTE(review): unpickling can execute arbitrary code, so only load files you
# created yourself.
# The existing `params` is passed as the target whose structure from_bytes restores into.
with open("model_params.pkl", "rb") as f:
    params_loaded = flax.serialization.from_bytes(params, pickle.loads(f.read()))

In [None]:
# Run inference again, this time with the parameters restored from disk;
# the printed value can be compared with the output of the earlier run.
y = model.apply(params_loaded, x)
print(y)

## Pre-trained Models

In [None]:
%%capture
%pip install transformers

In [None]:
from IPython.display import display, Image
import requests
from PIL import Image as PILImage
from io import BytesIO

url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Fetch the image. The timeout keeps the cell from hanging on a stalled
# connection, and raise_for_status() turns an HTTP error into a clear exception
# instead of letting PIL fail on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()

# Decode the downloaded bytes and display the image
img = PILImage.open(BytesIO(response.content))
display(img)

Model: ViT https://huggingface.co/docs/transformers/en/model_doc/vit

Trained on ImageNet: https://paperswithcode.com/dataset/imagenet

In [None]:
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

# torch is needed here for no_grad(); importing it in this cell keeps the cell self-contained
import torch

# Get the image from the web. The timeout prevents an indefinite hang and
# raise_for_status() surfaces HTTP errors before PIL tries to decode the body.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Load preprocessor and model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Run the inference engine. Gradients are not needed for prediction, so
# autograd tracking is disabled to save memory.
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

In [None]:
# Show the type of the object returned by the model call
type(outputs)

In [None]:
# Pull the logits out of the model output and display their shape
# (presumably one row per input image and one column per class — confirm from the output)
logits = outputs.logits
logits.shape

In [None]:
# Top-1: position of the largest logit along the last axis, mapped to its label
predicted_class_idx = logits.argmax(dim=-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

In [None]:
import torch

# Ten largest logits along the last axis, together with their class indices
top_10 = torch.topk(logits, k=10)

# Row 0 holds the results for the single image that was passed in
top_10_scores = top_10.values[0].tolist()
top_10_indices = top_10.indices[0].tolist()

# Print a ranked list; the scores shown are the raw logits
print("Top-10 Predicted Classes:")
rank = 0
for idx, score in zip(top_10_indices, top_10_scores):
    rank += 1
    print(f"{rank}. {model.config.id2label[idx]} ({score:.4f})")