# Prepare DataLoaders

In [11]:
import time
from typing import Dict, Any
from types import SimpleNamespace
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
import mlx.core as mx
import mlx.data as dx

# Convert the Hugging Face dataset to the custom format
def huggingface_to_array_of_dict(dataset):    
    return [{"image": np.array(image).copy(), "label": label}
            for label, image in zip(dataset['label'], dataset['image'])]

# Convert the Hugging Face dataset to a stream of batches
def hf_dataset_to_mlx_stream(dataset, cfg: Dict, shuffle=False):
    buffer = dx.buffer_from_vector(huggingface_to_array_of_dict(dataset))
    if shuffle:
        buffer = buffer.shuffle()
    
    return (
        buffer
        .to_stream()
        .key_transform("image", lambda x: x.astype("float32").reshape(-1) / 255)        
        .batch(cfg.batch_size)  # the value doesn't matter for now
        .prefetch(cfg.prefetch_size, cfg.num_threads) # fetch batches in background threads
    )

# Load the MNIST dataset from Hugging Face
# We’ll use only the training set for now, splitting it into two parts:
# - Training Set: 80% for training our model.
# - Validation Set: 20% for checking the model’s accuracy at each epoch.
# - Test Set: kept for totally unseen data at the final stage
ds = load_dataset("ylecun/mnist")
split_ds = ds['train'].train_test_split(test_size=0.2)
ds = {
    'train': split_ds['train'],
    'val': split_ds['test'],
    'test': ds['test']
}

cfg = SimpleNamespace()
cfg.batch_size = 4
cfg.prefetch_size = 4
cfg.num_threads = 4

# Transform the dataset to streams
train_stream = hf_dataset_to_mlx_stream(ds['train'], cfg, shuffle=True)

# Iterate on the batches
train_stream.reset()
for batch_counter, batch in enumerate(train_stream):
    (X, y) = mx.array(batch['image']), mx.array(batch['label'])

    print('The image has shape ', X[0].shape)    
    print('The image should display a ', y[0].item())
    break


The image has shape  (784,)
The image should display a  5


# Define Architecture

In [None]:
import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        self.layers = [
            nn.Linear(n_inputs, n_hidden),
            nn.Linear(n_hidden, n_outputs)
        ]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = nn.relu(layer(x))
        # we don't do ReLU for last layer
        return self.layers[-1](x)

# Training Helper Classes

### Metric Class (Not That Interesting Just Move On)

In [None]:
import time
import pandas as pd
from IPython.display import display, clear_output, HTML
import mlx.core as mx

class Metrics:
    def __init__(self):
        self.epochs = {}
        self.metrics_df = pd.DataFrame(columns=["epoch", "tr_loss", "val_loss", "accuracy", "samples/s", "time"])

    def start_epoch(self, epoch):
        self.start_epoch_timer = time.perf_counter()
        self.epochs[epoch] = {}

    def start_batch(self, epoch):
        self.start_batch_timer = time.perf_counter()
        self.epochs[epoch]['batch_losses'] = []
        self.epochs[epoch]['batch_samples_per_sec'] = []

    def finish_batch(self, epoch, batch_size, loss):
        self.epochs[epoch]['batch_losses'].append(loss.item())
        self.epochs[epoch]['batch_samples_per_sec'].append(batch_size / (time.perf_counter() - self.start_batch_timer))
        # print(self.epochs[epoch]['batch_samples_per_sec'])

    def finish_epoch(self, epoch, val_acc, val_loss):
        self.epochs[epoch]['tr_loss'] = mx.mean(mx.array(self.epochs[epoch]['batch_losses']))
        self.epochs[epoch]['val_loss'] = val_loss
        self.epochs[epoch]['accuracy'] = val_acc
        self.epochs[epoch]['throughput'] = mx.mean(mx.array(self.epochs[epoch]['batch_samples_per_sec']))
        self.epochs[epoch]['time'] = time.perf_counter() - self.start_epoch_timer

    def print_metrics(self):
        epoch = len(self.epochs) - 1
        data = self.epochs[epoch]
        data = {
            'epoch': epoch,
            'tr_loss': f"{data['tr_loss'].item():.6f}",
            'val_loss': f"{data['val_loss'].item():.6f}",
            'accuracy': f"{data['accuracy'].item():.6f}",
            'samples/s': f"{int(data['throughput'].item())}",
            'time': f"{int(divmod(data['time'], 60)[0]):02}:{int(divmod(data['time'], 60)[1]):02}"
        }
        new_row = pd.DataFrame([data])
        self.metrics_df = pd.concat([self.metrics_df, new_row], ignore_index=True)
        clear_output(wait=True)
        display(HTML(self.metrics_df.to_html(index=False)))

### Learner Class

In [14]:
import mlx.core as mx
from functools import partial
from typing import Any, Dict, Tuple


class Trainer:
    def __init__(self, model, optimizer, cfg, train_stream, val_stream, loss_fn, accuracy_fn):
        self.model = model
        self.optimizer = optimizer
        self.cfg = cfg
        self.train_stream = train_stream
        self.val_stream = val_stream
        self.loss_fn = loss_fn
        self.accuracy_fn = accuracy_fn

    def one_epoch(self, epoch):
        @partial(mx.compile, inputs=self.model.state, outputs=self.model.state)
        def step(X, y):
            loss, grads = nn.value_and_grad(self.model, self.loss_fn)(self.model, X, y)
            # this version doesn't work, seems like we need to pass the model as an argument
            # and do the forward pass inside the loss function
            # preds = self.model(X)
            # loss, grads = nn.value_and_grad(self.model, self.loss_fn)(preds, y)
            self.optimizer.update(self.model, grads)
            return loss # we can't do .item() here because we use mx.compile

        for batch in self.train_stream:
            X = mx.array(batch["image"])
            y = mx.array(batch["label"])
            self.metrics.start_batch(epoch)            

            # interesting stuff
            loss = step(X, y)
            mx.eval(self.model.state, self.optimizer.state)
            
            self.metrics.finish_batch(epoch, X.shape[0], loss)    
        
    def validate_epoch(self) -> Tuple[mx.array, mx.array]: #)returns an array with a single scalar
        accs = []
        losses = []
        for batch in self.val_stream:            
            X = mx.array(batch["image"])
            y = mx.array(batch["label"])
            preds = self.model(X)
            # acc
            acc = self.accuracy_fn(preds, y)
            accs.append(acc.item()) # use .item() to get a scalar
            # loss
            preds = self.model(X)
            loss = self.loss_fn(self.model, X, y)
            losses.append(loss.item())
        acc =  mx.mean(mx.array(accs))
        assert acc.ndim == 0, acc.shape
        loss =  mx.mean(mx.array(loss))
        assert loss.ndim == 0, loss.shape
        return acc, loss


    def fit(self, cfg: Dict[str, Any]):
        self.metrics = Metrics()

        for epoch in range(cfg.epochs):
            self.train_stream.reset()
            self.val_stream.reset()

            self.metrics.start_epoch(epoch)

            self.one_epoch(epoch)
            val_acc, val_loss = self.validate_epoch()
            self.metrics.finish_epoch(epoch, val_acc, val_loss)
            self.metrics.print_metrics()


# Training Blueprint

In [15]:
import mlx.optimizers as optim

# download dataset
# from datasets import load_dataset

# Load the MNIST dataset from Hugging Face
# We’ll use only the training set for now, splitting it into two parts:
# - Training Set: 80% for training our model.
# - Validation Set: 20% for checking the model’s accuracy at each epoch.
# - Test Set: kept for totally unseen data at the final stage
# ds = load_dataset("ylecun/mnist")
# split_ds = ds['train'].train_test_split(test_size=0.2)
# ds = {
#     'train': split_ds['train'],
#     'val': split_ds['test'],
#     'test': ds['test']
# }

# Get the streams (dataloaders in pytorch)
# Eventually modify hf_dataset_to_mlx_stream to be compatible with the shapes/transforms you need
train_stream = hf_dataset_to_mlx_stream(ds['train'], cfg, shuffle=True)
val_stream = hf_dataset_to_mlx_stream(ds['val'], cfg, shuffle=False)

# loss function for SGD
# @mx.compile
def loss_fn(model: nn.Module, X: mx.array, y: mx.array):
    assert X.shape[0] == y.shape[0], (X.shape, y.shape)
    assert X.shape[1] == n_inputs, X.shape    
    preds = model(X)
    assert preds.shape == (X.shape[0], n_classes), preds.shape
    return nn.losses.cross_entropy(preds, y, reduction="mean")

# accuracy function for human metric
@mx.compile
def accuracy_fn(preds: mx.array, y: mx.array) -> mx.array:        
    assert preds.shape[0] == y.shape[0], (preds.shape, y.shape)    
    r = mx.mean(mx.argmax(preds, axis=1) == y) 
    assert r.ndim == 0, r.shape
    return r # we can't do .item() here because we use mx.compile

# HYPER PARAMS!
# Loading Param
cfg.prefetch = 4
cfg.num_threads = 8
cfg.batch_size = 256
# Training Param
cfg.lr = 1e-2
cfg.epochs = 10

# define shapes
train_stream.reset()
one_batch = next(iter(train_stream))
one_batch['image'].shape[-1]
n_inputs = one_batch['image'].shape[-1]
n_hidden = 50
n_classes = len(ds['val'].features['label'].names)

# The training execution
model = MLP(n_inputs, n_hidden, n_classes)
optimizer = optim.SGD(learning_rate=cfg.lr)
trainer = Trainer(model, optimizer, cfg, train_stream, val_stream, loss_fn=loss_fn, accuracy_fn=accuracy_fn)
trainer.fit(cfg)

# sometimes samples/s might not make sense, samples/s might be higher than the whole dataset
# don't worry, it's just because is an extrapolation per sec
# probably your epoch took less than a second

epoch,tr_loss,val_loss,accuracy,samples/s,time
0,1.902779,1.891594,0.656822,29188,00:01
1,1.461295,1.391238,0.750693,17873,00:01
2,0.967882,0.766786,0.806698,77172,00:01
3,0.689295,0.7125,0.831757,118270,00:01
4,0.585121,0.683638,0.848779,32366,00:01
5,0.268904,0.641596,0.859907,11604,00:01
6,0.530931,0.327467,0.872267,78083,00:01
7,0.476316,0.518387,0.88037,95709,00:01
8,0.426353,0.717427,0.879537,88518,00:01
9,0.391617,0.377124,0.884126,61121,00:01
