In [None]:
import sys, os

# Make the parent of the current working directory importable (presumably
# where the local helper modules imported below live).
# Guarded so that re-running this cell does not keep appending duplicates.
_PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

from time import process_time
from IPython import display

import jax
import jax.numpy as jnp
from flax import nnx
import optax
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T
import numpy as np

# My own modules
import viz_utils as vu
from plot_lib import set_default
import train_utils as tu

# Apply the defaults provided by plot_lib.set_default (presumably plotting
# style settings; the function is defined outside this notebook -- confirm).
set_default()

In [None]:
# Constants
# Dataset root. Defaults to the original local directory, but can be redirected
# without editing the notebook by setting the MNIST_DATA_DIR environment variable.
DATA_DIR = os.environ.get("MNIST_DATA_DIR", "/Users/mghifary/Work/Code/AI/data")
MODEL_DIR = "models"  # not referenced by the cells visible in this notebook section

BATCH_SIZE = 64
LEARNING_RATE = 1e-3
LAMBDA_L2 = 1e-5  # defined but not referenced by the cells visible in this notebook section

SEED = 42

In [None]:
# The transform pipeline contains only ToTensor(); no normalization step is applied.
transform = T.Compose([T.ToTensor()])

# Both splits share the same root, download flag and transform; only `train` differs.
mnist_kwargs = dict(root=DATA_DIR, download=True, transform=transform)

# Load training and test data using torchvision (easiest way to fetch)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# Number of distinct label values found in the training targets.
train_labels = jnp.asarray(train_data.targets.numpy())
num_classes = int(jnp.unique(train_labels).shape[0])
print(f"num_classes: {num_classes}")

In [None]:
# Create data loaders
# A dedicated generator seeded with SEED makes the training shuffle order
# reproducible across kernel restarts instead of depending on torch's global
# RNG state. Re-running this cell recreates the generator, so it is idempotent.
shuffle_generator = torch.Generator().manual_seed(SEED)
train_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=shuffle_generator,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Peek at a single training batch to confirm tensor shapes and the label dtype.
# X_batch is kept at module level because the sample-preview cell reuses it.
X_batch, y_batch = next(iter(train_dataloader))
print(f"Shape of X [N, C, H, W]: {X_batch.shape}")
print(f"Shape of y: {y_batch.shape}, {y_batch.dtype}")

In [None]:
# Show training samples: the first `num_preview` images of the peeked batch.
num_preview = 48
preview_images = X_batch.numpy()[:num_preview]
grid = vu.set_grid(preview_images, num_cells=num_preview)
vu.show(grid)

In [None]:
class FFNet(nnx.Module):
    """Feed-forward network: flatten -> Linear -> ReLU -> Linear.

    Args:
        din: Number of input features after flattening each sample.
        dhidden: Width of the hidden layer.
        dout: Number of output units; the last layer is a plain Linear, so the
            outputs are returned without any further activation.
        rngs: ``nnx.Rngs`` passed to both Linear layers at construction.
    """

    def __init__(self, din, dhidden, dout, rngs: nnx.Rngs):
        # Layers are built in forward order so `rngs` is consumed in the same
        # sequence as when they are constructed inline.
        hidden_layer = nnx.Linear(din, dhidden, rngs=rngs)
        output_layer = nnx.Linear(dhidden, dout, rngs=rngs)
        self.net = nnx.Sequential(
            # Collapse every non-batch axis into a single feature axis.
            lambda batch: batch.reshape(batch.shape[0], -1),
            hidden_layer,
            nnx.relu,
            output_layer,
        )

    def __call__(self, x):
        """Run `x` through the sequential stack and return its output."""
        outputs = self.net(x)
        return outputs

# Instantiate the network with a seeded RNG stream.
# 784 input features presumably corresponds to flattened 1x28x28 MNIST images -- confirm.
rngs = nnx.Rngs(SEED)
model = FFNet(din=784, dhidden=256, dout=num_classes, rngs=rngs)
nnx.display(model)

In [None]:
# Optimizer: Adam with the configured learning rate, wrapped for the NNX model.
adam_tx = optax.adam(learning_rate=LEARNING_RATE)
optimizer = nnx.Optimizer(model, adam_tx)

# Accuracy plus running-average loss tracker.
# NOTE: the training loop visible in this notebook builds its own per-epoch
# MultiMetric, so `metrics` is not referenced by the cells shown here.
accuracy_metric = nnx.metrics.Accuracy()
loss_metric = nnx.metrics.Average()
metrics = nnx.MultiMetric(accuracy=accuracy_metric, loss=loss_metric)

In [None]:
def evaluate_model(model, dataloader):
    """Evaluate `model` on `dataloader` with a fresh set of metrics.

    A new accuracy/loss ``nnx.MultiMetric`` is created on every call and handed
    to ``tu.evaluate`` (defined in train_utils; presumably it updates the
    metrics in place -- confirm).

    Returns:
        Whatever ``MultiMetric.compute()`` yields; callers in this notebook
        index it with the keys ``'accuracy'`` and ``'loss'``.
    """
    tracker = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        loss=nnx.metrics.Average(),
    )
    tu.evaluate(model, dataloader, tracker)
    results = tracker.compute()
    return results

In [None]:
EPOCHS = 10
for epoch in range(EPOCHS):
    # Use fresh metric accumulators each epoch so the reported training
    # numbers cover this epoch only rather than a running total.
    epoch_metrics = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(),
        loss=nnx.metrics.Average(),
    )

    # tu.train (from train_utils) is expected to update `epoch_metrics` while
    # it trains -- confirm against its implementation. `train_loss` is returned
    # but not used in the report below.
    train_loss, train_time = tu.train(model, train_dataloader, optimizer, epoch_metrics)
    train_results = epoch_metrics.compute()

    test_results = evaluate_model(model, test_dataloader)

    # Assemble the one-line epoch report from its three parts.
    header = f"[Epoch {epoch+1} / {EPOCHS} with training time {train_time:.4f} secs] "
    train_part = f"Train (Acc: {train_results['accuracy']:.4f}, Loss: {train_results['loss']:.4f}), "
    test_part = f"Test (Acc: {test_results['accuracy']:.4f}, Loss: {test_results['loss']:.4f})"
    print(header + train_part + test_part)