# Training Loop

## Training an MLP model

Install the required packages

In [None]:
%%capture
# %%capture suppresses pip's install log in the notebook output.
# NOTE(review): consider pinning a version (e.g. flax==X.Y.Z) for reproducibility.
%pip install flax

Define the imports

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as skdata
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax
import flax
from flax import linen as nn

Load and preprocess the Iris dataset.

In [None]:
iris = skdata.load_iris()
X = iris.data  # shape: (150, 4)
y = iris.target  # Labels: 0, 1, 2

# Split BEFORE scaling: the scaler's mean/std must come from the training split only.
# Fitting it on the full dataset would leak test-set statistics into training.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1337,
)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_raw)
X_test_scaled = scaler.transform(X_test_raw)  # reuse train statistics

# Convert both splits to JAX arrays so training and evaluation inputs share a type.
X_train = jnp.array(X_train_scaled)
y_train = jnp.array(y_train)
X_test = jnp.array(X_test_scaled)
y_test = jnp.array(y_test)

In [None]:
# Sanity-check the split: (num_samples, num_features) for the train and test features.
X_train.shape, X_test.shape

Define a model

In [None]:
class MLPClassifierSmall(nn.Module):
    """Small fully connected classifier mapping feature vectors to class logits.

    Attributes:
        num_classes: Size of the output layer (one logit per class).
        hidden_sizes: Widths of the hidden Dense layers, each followed by a ReLU.
            Defaults to (8, 16, 8), the original fixed architecture.
    """

    num_classes: int
    hidden_sizes: tuple[int, ...] = (8, 16, 8)

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return unnormalized logits with last dimension `num_classes`."""
        for width in self.hidden_sizes:
            x = nn.relu(nn.Dense(width)(x))
        # No activation on the output layer: logits stay unnormalized, and the
        # softmax is expected to be applied by the caller (e.g. inside the loss).
        return nn.Dense(self.num_classes)(x)

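Next we set up the training-loop "state" that is threaded through every update step:
- `params`: model weights
- `opt_state`: optimizer internal state (for Adam, its running moment estimates)
- `rng`: a JAX PRNG key for randomness (here only data shuffling; other models also use it for dropout, etc.)
- step/epoch counters (here just the `epoch` loop variable)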


Finally, set the hyperparameters, define the loss, accuracy, and update functions, and run the training loop.

In [None]:
# Hyperparameters
num_epochs = 100
batch_size = 16
learning_rate = 1e-3
num_classes = 3
input_features = X.shape[1]

# Initialize the model. Split off a dedicated init key so the root key is never
# consumed twice (once by init and again by the shuffling splits below).
rng = jax.random.PRNGKey(0)
rng, init_key = jax.random.split(rng)
model = MLPClassifierSmall(num_classes=num_classes)
params = model.init(init_key, jnp.ones((1, input_features)))

# Set up the optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)


def loss_fn(params, x, y):
    """Mean softmax cross-entropy between the model's logits and integer labels `y`."""
    logits = model.apply(params, x)
    one_hot = jax.nn.one_hot(y, num_classes)
    return optax.softmax_cross_entropy(logits, one_hot).mean()


# Jitted loss for evaluation only (no gradients are taken).
eval_loss = jax.jit(loss_fn)


@jax.jit
def accuracy(params, x, y):
    """Fraction of samples whose argmax logit equals the integer label."""
    logits = model.apply(params, x)
    predicted_classes = jnp.argmax(logits, axis=1)
    return jnp.mean(predicted_classes == y)


@jax.jit
def update(params, opt_state, x, y):
    """Run one optimizer step on a batch; return (params, opt_state, batch_loss)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


num_train = X_train.shape[0]
num_test = X_test.shape[0]

train_losses = []
test_losses = []

print(f"Accuracy before training: {accuracy(params, X_test, y_test)}")

# Training loop!
for epoch in range(num_epochs):
    # Reshuffle the training data each epoch with a fresh subkey.
    rng, perm_key = jax.random.split(rng)
    permutation = jax.random.permutation(perm_key, num_train)
    X_train_shuffled = X_train[permutation]
    y_train_shuffled = y_train[permutation]

    epoch_train_loss = 0.0
    for i in range(0, num_train, batch_size):
        batch_x = X_train_shuffled[i:i + batch_size]
        batch_y = y_train_shuffled[i:i + batch_size]
        params, opt_state, loss = update(params, opt_state, batch_x, batch_y)
        # Weight by batch size: the final batch may be smaller than batch_size.
        epoch_train_loss += loss * batch_x.shape[0]

    train_losses.append(float(epoch_train_loss / num_train))
    # Held-out loss with the end-of-epoch parameters.
    test_losses.append(float(eval_loss(params, X_test, y_test)))

print(f"Accuracy after training: {accuracy(params, X_test, y_test)}")

# Plot training vs test loss.
epochs_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs_axis, train_losses, label="Train Loss")
ax.plot(epochs_axis, test_losses, label="Test Loss")
ax.set(xlabel="Epoch", ylabel="Loss", title="Training vs Test Loss")
ax.legend()
plt.show()