# Training Loop

## Training an MLP model

Install the required packages

In [None]:
%%capture
%pip install flax wandb tensorboardX tiktoken

Define the imports

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as skdata
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax
import flax
from flax import linen as nn

Load and preprocess the Iris dataset.

In [None]:
iris = skdata.load_iris()
X = iris.data  # shape: (150, 4)
y = iris.target  # Labels: 0, 1, 2

# Split BEFORE scaling: fitting the scaler on the full dataset would let the
# test rows' mean/std leak into the preprocessing of the training data.
# Note: X itself stays unscaled; later cells only read X.shape[1].
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1337,
)

# Fit the standardisation on the training split only, then apply it to both.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert every split to JAX arrays so train and test inputs share one type.
X_train = jnp.array(X_train)
y_train = jnp.array(y_train)
X_test = jnp.array(X_test)
y_test = jnp.array(y_test)

Define a model

In [None]:
class MLPClassifierSmall(nn.Module):
    """Small fully-connected classifier.

    Three ReLU hidden layers of widths 8, 16 and 8, followed by a linear
    output layer with one unit per class.

    Attributes:
        num_classes: width of the output layer (number of logits per sample).
    """
    num_classes: int

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # Hidden stack: each layer is a Dense projection followed by ReLU.
        for width in (8, 16, 8):
            x = nn.relu(nn.Dense(width)(x))
        # Output head: raw (unnormalised) logits; no softmax is applied here.
        return nn.Dense(self.num_classes)(x)

Finally, run a script!

In [None]:
# Hyper-parameters
num_epochs, batch_size = 100, 16
learning_rate = 1e-3
num_classes = 3
input_features = X.shape[-1]  # number of feature columns per sample

# Instantiate the classifier; Flax builds the parameter tree by running a
# dummy single-row batch of the right width through the model.
rng = jax.random.PRNGKey(seed=0)
model = MLPClassifierSmall(num_classes=num_classes)
params = model.init(rng, jnp.ones(shape=(1, input_features)))

# Adam optimiser and its initial state for these parameters
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# Define the loss function
def loss_fn(params, x, y):
    """Mean softmax cross-entropy of the model's logits on ``x`` against labels ``y``.

    ``y`` holds integer class ids; they are one-hot encoded to ``num_classes``
    columns before the cross-entropy is taken.
    """
    targets = jax.nn.one_hot(y, num_classes)
    logits = model.apply(params, x)
    return optax.softmax_cross_entropy(logits, targets).mean()

@jax.jit
def accuracy(params, x, y):
    """Fraction of rows of ``x`` whose highest-scoring class equals the label in ``y``."""
    predicted = model.apply(params, x).argmax(axis=1)
    return (predicted == y).mean()


# A single update step
@jax.jit
def update(params, opt_state, x, y):
    """Take one optimiser step on the batch ``(x, y)``.

    Returns:
        (new_params, new_opt_state, loss), where ``loss`` is the batch loss
        evaluated at the parameters *before* the step.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, x, y)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

num_train = X_train.shape[0]
num_test = X_test.shape[0]

train_losses = []
test_losses = []

# Compiled loss for evaluating the held-out split once per epoch.
test_loss_fn = jax.jit(loss_fn)

print(f"Accuracy before training: {accuracy(params, X_test, y_test)}")

# Training loop!
for epoch in range(num_epochs):
    # Shuffle training data with a fresh key each epoch. Permuting with the
    # same `rng` every time (it is also the init key) would replay the
    # identical batch order in every epoch.
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, num_train)
    X_train_shuffled = X_train[permutation]
    y_train_shuffled = y_train[permutation]

    epoch_train_loss = 0.0

    # Process training batches
    for i in range(0, num_train, batch_size):
        batch_x = X_train_shuffled[i:i+batch_size]
        batch_y = y_train_shuffled[i:i+batch_size]
        params, opt_state, loss = update(params, opt_state, batch_x, batch_y)
        # Weight by the batch's size: the final batch may be smaller.
        epoch_train_loss += loss * batch_x.shape[0]

    epoch_train_loss /= num_train
    train_losses.append(float(epoch_train_loss))

    # Held-out loss, so the plot below can compare train vs. test.
    test_losses.append(float(test_loss_fn(params, X_test, y_test)))

print(f"Accuracy after training: {accuracy(params, X_test, y_test)}")

# Plot training vs testing loss.
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(8, 5))
plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, test_losses, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training vs Test Loss")
plt.legend()
plt.show()

## Tracking

#### tqdm - your friendly neighborhood progress bar

`tqdm` is a lightweight Python package that provides **fast, extensible progress bars** for loops and iterative processes—**extremely useful in ML workflows** for tracking training, data loading, and hyperparameter tuning.  

✅ **Real-time feedback** → See how long each epoch/batch takes.  
✅ **ETA estimation** → Know how much time is left for training.  
✅ **Seamless integration** → Works with **loops, DataLoaders, and multiprocessing**.  
✅ **Minimal performance overhead** → Negligible impact on computation time.

In [None]:
#  ------ ONE LINE OF CODE HERE ------
from tqdm.notebook import tqdm

# Hyper-parameters (same values as the first run)
num_epochs, batch_size = 100, 16
learning_rate = 1e-3
num_classes = 3
input_features = X.shape[-1]  # number of feature columns per sample

# Fresh model + parameters, initialised from a dummy one-row batch
rng = jax.random.PRNGKey(seed=0)
model = MLPClassifierSmall(num_classes=num_classes)
params = model.init(rng, jnp.ones(shape=(1, input_features)))

# Adam optimiser state for the new parameters
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# Define the loss function
def loss_fn(params, x, y):
    """Average softmax cross-entropy between model logits and one-hot labels.

    ``y`` contains integer class ids in ``[0, num_classes)``.
    """
    one_hot_labels = jax.nn.one_hot(y, num_classes)
    return optax.softmax_cross_entropy(model.apply(params, x), one_hot_labels).mean()

@jax.jit
def accuracy(params, x, y):
    """Share of samples whose arg-max logit matches the true label (a float in [0, 1])."""
    hits = model.apply(params, x).argmax(axis=1) == y
    return hits.mean()


# A single update step
@jax.jit
def update(params, opt_state, x, y):
    """Single compiled training step: loss + gradients, then an optimiser update.

    Returns the updated parameters, the updated optimiser state and the
    pre-step batch loss.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    step, new_opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, step), new_opt_state, loss

num_train = X_train.shape[0]
num_test = X_test.shape[0]

train_losses = []
test_losses = []

print(f"Accuracy before training: {accuracy(params, X_test, y_test)}")

# Training loop!
#  ------ AND HERE ------
for epoch in tqdm(range(num_epochs)):
    # Shuffle training data with a fresh key each epoch; permuting with the
    # same `rng` every time would replay the identical batch order.
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, num_train)
    X_train_shuffled = X_train[permutation]
    y_train_shuffled = y_train[permutation]

    epoch_train_loss = 0.0

    # Process training batches
    for i in range(0, num_train, batch_size):
        batch_x = X_train_shuffled[i:i+batch_size]
        batch_y = y_train_shuffled[i:i+batch_size]
        params, opt_state, loss = update(params, opt_state, batch_x, batch_y)
        # Weight by the batch's size: the final batch may be smaller.
        epoch_train_loss += loss * batch_x.shape[0]

    epoch_train_loss /= num_train
    train_losses.append(float(epoch_train_loss))

print(f"Accuracy after training: {accuracy(params, X_test, y_test)}")

#### TensorBoard and Weights & Biases

**TensorBoard** is a visualization toolkit originally developed for TensorFlow but now widely used across ML frameworks. It helps you monitor, debug, and optimize your models. Here's why and how to use it.


**How to use TensorBoard:**

1. **Logging:**  
   In your training loop, log scalar values (like loss and accuracy), histograms, images, or even model graphs. For example, using tensorboardX’s SummaryWriter (or similar for other frameworks):

   ```python
   from tensorboardX import SummaryWriter
   writer = SummaryWriter(log_dir="mle4r")
   
   # Log a scalar value (convert JAX arrays to float if needed)
   writer.add_scalar("Loss/Train", float(train_loss), epoch)
   writer.add_scalar("Accuracy/Test", float(test_acc), epoch)
   ```

2. **Launching TensorBoard:**  
   From the command line, run:
   ```bash
   tensorboard --logdir=mle4r
   ```
   Then open the provided URL in a browser to view your metrics.

**Weights & Biases (W&B)** is a powerful experiment tracking and collaboration tool for machine learning. It helps you log metrics, visualize training curves, manage hyperparameters, and compare different runs, all in one centralized dashboard.

**How to use W&B:**

1. **Initialization:**  
   At the beginning of your training script, initialize a run with your project name and configuration. For example:
   
   ```python
   import wandb

   # Initialize a new run
   wandb.init(
       project="YOUR_PROJECT",
       config={
           "num_epochs": 100,
           "batch_size": 16,
           "learning_rate": 1e-3,
           "num_classes": 3,
           "input_features": X.shape[1],  # assuming X is defined
       },
   )
   ```

2. **Logging Metrics:**  
   In your training loop, log key metrics (like loss, accuracy, etc.) by calling `wandb.log()`. You can log metrics every epoch or even every batch:
   
   ```python
   # Inside your training loop:
   wandb.log({
       "epoch": epoch,
       "train_loss": float(epoch_train_loss),
       "test_loss": float(epoch_test_loss),
       "test_accuracy": float(test_acc)
   })
   ```

3. **Logging Artifacts and Visualizations:**  
   W&B allows you to log model artifacts (like trained weights or model files) and visualizations (images, plots, etc.). For example, you might save a plot of training vs. validation loss or upload the model checkpoint.

4. **Hyperparameter Sweeps:**  
   You can set up sweeps to automatically search through hyperparameter combinations. This helps in automating experiment tracking and finding the best configuration.

5. **Dashboard:**  
   Once your script is running, you can visit your W&B dashboard in a web browser to see real-time charts, compare different runs, and drill down into the details of each experiment.




To use W&B, first create an account, then use your **W&B API key**.

Note: store the key as a Colab secret (Colab → Secrets) named `WANDB_API_KEY`.

In [None]:
import os
from google.colab import userdata

# Copy the Colab secret into the process environment so later shell commands
# and the wandb client can read it without the key appearing in the notebook.
os.environ.update(WANDB_API_KEY=userdata.get('WANDB_API_KEY'))

In [None]:
!wandb login $WANDB_API_KEY

Run the full script

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from tqdm.notebook import tqdm
import wandb
from tensorboardX import SummaryWriter

# HPs
num_epochs = 100
batch_size = 16
learning_rate = 1e-3
num_classes = 3
input_features = X.shape[-1]

# Start a W&B run. config records the hyper-parameters; sync_tensorboard=True
# asks W&B to pick up the TensorBoard logs written by `writer` below.
wandb.init(
    entity="dysco",
    project="mle4r",
    sync_tensorboard=True,
    config=dict(
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        num_classes=num_classes,
        input_features=input_features,
    ),
)

# TensorBoard event files go to ./mle4r (the directory `%tensorboard --logdir=mle4r` reads).
writer = SummaryWriter(log_dir="mle4r")

# Build the classifier and initialise its parameters from a dummy one-row batch
rng = jax.random.PRNGKey(seed=0)
model = MLPClassifierSmall(num_classes=num_classes)
params = model.init(rng, jnp.ones(shape=(1, input_features)))

# Adam optimiser and its initial state
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# Define the loss function
def loss_fn(params, x, y):
    """Mean softmax cross-entropy of the model's predictions on ``x`` versus integer labels ``y``."""
    encoded = jax.nn.one_hot(y, num_classes)
    per_sample = optax.softmax_cross_entropy(model.apply(params, x), encoded)
    return per_sample.mean()

# Accuracy function
@jax.jit
def accuracy(params, x, y):
    """Classification accuracy: mean of (arg-max prediction == label) over the rows of ``x``."""
    scores = model.apply(params, x)
    return (scores.argmax(axis=1) == y).mean()

# A single update step
@jax.jit
def update(params, opt_state, x, y):
    """One jitted optimisation step on a mini-batch.

    Returns:
        (new_params, new_opt_state, loss) with ``loss`` computed before the step.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, x, y)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

# Evaluation function (for test loss)
@jax.jit
def eval_step(params, x, y):
    """Compiled, gradient-free evaluation: the mean loss of ``params`` on ``(x, y)``."""
    loss = loss_fn(params, x, y)
    return loss

num_train = X_train.shape[0]
num_test = X_test.shape[0]

train_losses = []
test_losses = []

# Log initial test accuracy
init_test_acc = accuracy(params, X_test, y_test)
print(f"Accuracy before training: {init_test_acc:.4f}")

# Log plain Python floats: the writer wants scalars, not JAX arrays.
writer.add_scalar("Test/Accuracy", float(init_test_acc), 0)

try:
    # Training loop!
    for epoch in tqdm(range(1, num_epochs + 1)):
        # Shuffle training data with a fresh key each epoch; permuting with
        # the same `rng` every time would replay the identical batch order.
        rng, shuffle_rng = jax.random.split(rng)
        permutation = jax.random.permutation(shuffle_rng, num_train)
        X_train_shuffled = X_train[permutation]
        y_train_shuffled = y_train[permutation]

        epoch_train_loss = 0.0

        # Process training batches
        for i in range(0, num_train, batch_size):
            batch_x = X_train_shuffled[i:i+batch_size]
            batch_y = y_train_shuffled[i:i+batch_size]
            params, opt_state, loss = update(params, opt_state, batch_x, batch_y)
            # Weight by the batch's size: the final batch may be smaller.
            epoch_train_loss += loss * batch_x.shape[0]

        epoch_train_loss = float(epoch_train_loss / num_train)
        train_losses.append(epoch_train_loss)

        # Evaluate on test data (loss) with the jitted, gradient-free eval_step
        epoch_test_loss = float(eval_step(params, X_test, y_test))
        test_losses.append(epoch_test_loss)

        # Compute test accuracy
        test_acc = float(accuracy(params, X_test, y_test))

        # Logging to TensorBoard (also picked up by W&B via sync_tensorboard=True)
        writer.add_scalar("Train/Loss", epoch_train_loss, epoch)
        writer.add_scalar("Test/Loss", epoch_test_loss, epoch)
        writer.add_scalar("Test/Accuracy", test_acc, epoch)

        # print(f"Epoch {epoch:03d}: Train Loss: {epoch_train_loss:.4f}, Test Loss: {epoch_test_loss:.4f}, Test Acc: {test_acc:.4f}")

    print(f"Accuracy after training: {accuracy(params, X_test, y_test):.4f}")
finally:
    # Close the TensorBoard writer and finish the W&B run even if training is
    # interrupted, so the next run in this kernel does not inherit an open run.
    writer.close()
    wandb.finish()

You can also analyze runs in TensorBoard directly from **local** data.

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir=mle4r

In [None]:
# !kill 7436

## From Training Loop To Training Script

#### Wrapping HPs as script arguments

`argparse` is a standard Python library that lets you define command line arguments so you can configure your training loop (or any script) without hardcoding hyperparameters. This makes your script flexible and easier to run with different configurations.

In [None]:
import argparse

def parse_args(argv=None):
    """Parse the training hyper-parameters from the command line.

    Each option accepts both the underscore spelling (``--num_epochs``) and the
    hyphenated spelling (``--num-epochs``), which is the style used when this
    notebook invokes ``train_lm.py``.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``num_epochs``, ``batch_size`` and ``learning_rate``.
    """
    parser = argparse.ArgumentParser(description="Train a small MLP classifier.")

    # The first long option determines the attribute name (e.g. `num_epochs`).
    parser.add_argument("--num_epochs", "--num-epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch_size", "--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--learning_rate", "--learning-rate", type=float, default=1e-3, help="Learning rate")

    # !! Colab Fix: ignore unknown arguments (in a notebook, sys.argv holds the
    # kernel launcher's own arguments). Because unknown flags are dropped
    # silently, the hyphenated aliases above matter: without them
    # `--learning-rate 0.01` would be ignored and the default used instead.
    args, _unknown = parser.parse_known_args(argv)

    return args

if __name__ == "__main__":
    args = parse_args()

    print(args.num_epochs)
    print(args.batch_size)
    print(args.learning_rate)

#### Mini-Transformer on tiny_shakespeare

Step 1. Download from GitHub

In [None]:
# %cd /content/

In [None]:
%%capture

# clean previous files
!rm -rf mle4r-winter25

# add new ones - EGN
!git clone https://github.com/cor3bit/mle4r-winter25.git

Step 2. Let's visualize the script

See in Colab Editor

In [None]:
# %cd mle4r-winter25/scripts
# %ls

Step 3. Run the script

In [None]:
!python mle4r-winter25/scripts/train_lm.py --learning-rate 0.001

#### HPOpt: Grid Search with `subprocess` scripting

Run the training script with different values of HPs

In [None]:
import subprocess
from tqdm import tqdm

from collections import deque

# Define the hyper-parameter grid
batch_size_options = [64, 128]
learning_rate_options = [0.001, 0.0001]

# Set to True to echo every run's output line by line in real time.
stream_output = False

# Loop over combinations
for batch_size in batch_size_options:
    for lr in learning_rate_options:
        cmd = [
            "python", "mle4r-winter25/scripts/train_lm.py",
            "--batch-size", str(batch_size),
            "--learning-rate", str(lr),
        ]
        print("Running:", " ".join(cmd))

        # Keep the last lines so a failed run can be diagnosed even when not streaming.
        tail = deque(maxlen=20)
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as process:
            # Always drain the pipe: if nobody reads it, the child blocks once the
            # OS pipe buffer is full, and a bare process.wait() would hang forever.
            for line in process.stdout:
                tail.append(line)
                if stream_output:
                    print(line, end='')
        # Leaving the `with` block waits for the child, so returncode is set here.
        if process.returncode != 0:
            print(f"Run failed (exit code {process.returncode}); last output:")
            print("".join(tail), end='')