In [8]:
import os
import gzip
import time
import torch
import pickle
import argparse
import numpy as np
import mlx.nn as nn
import mlx.core as mx
from urllib import request
import mlx.optimizers as optim

# Let's load the dataset

In [6]:
def mnist(save_dir="./data", base_url="http://yann.lecun.com/exdb/mnist/"):
    """Load the MNIST dataset as NumPy arrays, downloading it on first use.

    On the first call the four gzip archives are fetched from ``base_url``
    into ``save_dir``, decoded, and cached together in
    ``<save_dir>/mnist.pkl``. Later calls read only that pickle.

    Args:
        save_dir: Directory holding the downloaded archives and the pickle
            cache. It is created if it does not exist yet.
        base_url: URL prefix the archive file names are appended to. Pass a
            mirror here if the default host is unavailable.

    Returns:
        A tuple ``(train_images, train_labels, test_images, test_labels)``.
        Images are ``float32`` arrays of shape ``(N, 784)`` scaled to
        ``[0, 1]``; labels are ``uint32`` arrays of shape ``(N,)``.
    """
    image_archives = (
        ("training_images", "train-images-idx3-ubyte.gz"),
        ("test_images", "t10k-images-idx3-ubyte.gz"),
    )
    label_archives = (
        ("training_labels", "train-labels-idx1-ubyte.gz"),
        ("test_labels", "t10k-labels-idx1-ubyte.gz"),
    )

    def download_and_save(save_file):
        """Download every archive, decode it, and pickle the arrays."""
        # urlretrieve does not create missing directories, so do it here.
        # An empty save_dir means "current directory" and needs no creation.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        for _, archive in image_archives + label_archives:
            request.urlretrieve(base_url + archive, os.path.join(save_dir, archive))

        arrays = {}
        for key, archive in image_archives:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                # Skip the 16-byte header; every image is 28 * 28 bytes.
                arrays[key] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        for key, archive in label_archives:
            with gzip.open(os.path.join(save_dir, archive), "rb") as f:
                # Skip the 8-byte header; one byte per label follows.
                arrays[key] = np.frombuffer(f.read(), np.uint8, offset=8)

        with open(save_file, "wb") as f:
            pickle.dump(arrays, f)

    save_file = os.path.join(save_dir, "mnist.pkl")
    if not os.path.exists(save_file):
        download_and_save(save_file)

    # SECURITY: pickle.load can execute arbitrary code. This is only safe
    # because the cache is expected to be the file written by
    # download_and_save above; never point save_dir at an untrusted mnist.pkl.
    with open(save_file, "rb") as f:
        data = pickle.load(f)

    def to_unit_floats(pixels):
        """Convert uint8 pixel values to float32 in the range [0, 1]."""
        return pixels.astype(np.float32) / 255.0

    return (
        to_unit_floats(data["training_images"]),
        data["training_labels"].astype(np.uint32),
        to_unit_floats(data["test_images"]),
        data["test_labels"].astype(np.uint32),
    )

# Load the four MNIST arrays once so their shapes can be sanity-checked.
train_x, train_y, test_x, test_y = mnist()

# Images are flattened to 28 * 28 = 784 values per row; labels are 1-D.
assert (60000, 784) == train_x.shape, "Wrong training set size"
assert (60000,) == train_y.shape, "Wrong training set size"
assert (10000, 784) == test_x.shape, "Wrong test set size"
assert (10000,) == test_y.shape, "Wrong test set size"

# Now build a simple MLP model

In [11]:
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with a max(., 0) non-linearity in between.

    The network has ``num_layers + 1`` linear layers: ``num_layers`` of them
    produce ``hidden_dim`` features, and a final one maps to ``output_dim``.
    """

    def __init__(
        self,
        num_layers: int,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
    ):
        super().__init__()
        # Widths at every layer boundary: input, the hidden widths, output.
        dims = [input_dim, *([hidden_dim] * num_layers), output_dim]
        self.layers = [
            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
        ]

    def __call__(self, x):
        """Apply every layer to ``x``; the last layer's output is not clamped."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = mx.maximum(layer(x), 0.0)
        return output_layer(x)


def loss_fn(model, X, y):
    """Return the mean of ``nn.losses.cross_entropy`` for ``model(X)`` vs ``y``."""
    logits = model(X)
    losses = nn.losses.cross_entropy(logits, y)
    return mx.mean(losses)


def eval_fn(model, X, y):
    """Return the fraction of rows where the argmax of ``model(X)`` equals ``y``."""
    predicted = mx.argmax(model(X), axis=1)
    is_correct = predicted == y
    return mx.mean(is_correct)


def batch_iterate(batch_size, X, y):
    """Yield ``(X_batch, y_batch)`` pairs covering the data in a random order.

    A fresh permutation of ``range(y.size)`` is drawn from NumPy's global
    random state on every call; the final batch may be shorter than
    ``batch_size``.
    """
    num_examples = y.size
    shuffled = mx.array(np.random.permutation(num_examples))
    for start in range(0, num_examples, batch_size):
        batch_ids = shuffled[start : start + batch_size]
        yield X[batch_ids], y[batch_ids]


def main():
    """Train an MLP on MNIST with SGD and print test accuracy after each epoch.

    Hyper-parameters are fixed inside the function. For every epoch the
    function prints the test-set accuracy and the elapsed time measured
    around the training loop and the ``eval_fn`` call.
    """
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    # Seed both generators so runs are repeatable: NumPy drives the batch
    # shuffling in batch_iterate, and MLX has its own random state that the
    # NumPy seed does not touch.
    np.random.seed(seed)
    mx.random.seed(seed)

    # Load the data
    train_images, train_labels, test_images, test_labels = map(mx.array, mnist())

    # Load the model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    mx.eval(model.parameters())

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=learning_rate)

    for e in range(num_epochs):
        tic = time.perf_counter()
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            # The loss value itself is not reported; only the gradients are used.
            _loss, grads = loss_and_grad_fn(model, X, y)
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)
        accuracy = eval_fn(model, test_images, test_labels)
        # NOTE(review): accuracy.item() is only called in the print below,
        # after toc is taken. If MLX defers the computation until then, the
        # reported time excludes the evaluation pass -- confirm this is intended.
        toc = time.perf_counter()
        print(
            f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )

In [12]:
# Make the CPU the default MLX device, then run the training loop defined
# in main(); the per-epoch accuracy and timing lines are printed below.
mx.set_default_device(mx.cpu)
main()

Epoch 0: Test accuracy 0.866, Time 0.138 (s)
Epoch 1: Test accuracy 0.896, Time 0.124 (s)
Epoch 2: Test accuracy 0.917, Time 0.126 (s)
Epoch 3: Test accuracy 0.928, Time 0.128 (s)
Epoch 4: Test accuracy 0.935, Time 0.131 (s)
Epoch 5: Test accuracy 0.938, Time 0.128 (s)
Epoch 6: Test accuracy 0.946, Time 0.132 (s)
Epoch 7: Test accuracy 0.941, Time 0.133 (s)
Epoch 8: Test accuracy 0.950, Time 0.133 (s)
Epoch 9: Test accuracy 0.948, Time 0.128 (s)
