In [2]:
import argparse
import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

from tensorflow.keras.datasets import mnist as mnist_dataset

In [3]:
# Select the GPU as MLX's default device for the rest of the notebook.
mx.set_default_device(mx.gpu)

In [4]:
class MLP(nn.Module):
    """Multi-layer perceptron built from a stack of ``nn.Linear`` layers.

    The stack maps ``input_dim`` -> ``hidden_dim`` (``num_layers`` times)
    -> ``output_dim``. Every layer except the last is followed by a ReLU.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Widths at every layer boundary, from the input to the output.
        widths = [input_dim, *([hidden_dim] * num_layers), output_dim]
        stack = []
        # Pair each width with its successor to get (in, out) per layer.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = stack

    def __call__(self, x):
        """Apply the hidden layers with ReLU, then the final linear layer."""
        *hidden, head = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        # The last layer's output is returned without an activation.
        return head(x)

In [5]:
def loss_fn(model, X, y):
    """Return the mean cross-entropy of ``model(X)`` against the targets ``y``."""
    outputs = model(X)
    return nn.losses.cross_entropy(outputs, y, reduction="mean")

def batch_iterate(batch_size, X, y):
    """Yield ``(X, y)`` mini-batches of up to ``batch_size`` rows in random order.

    A new permutation of ``range(y.size)`` is drawn from NumPy's global RNG
    each time the generator is started; the final batch may be shorter.
    """
    shuffled = mx.array(np.random.permutation(y.size))
    for start in range(0, y.size, batch_size):
        stop = start + batch_size
        batch_ids = shuffled[start:stop]
        yield X[batch_ids], y[batch_ids]

In [6]:
# Run configuration and hyperparameters.
seed = 0  # shared seed for every random number generator used below
num_layers = 2  # number of hidden layers in the MLP
hidden_dim = 32  # width of each hidden layer
num_classes = 10  # size of the model's output layer
batch_size = 256
num_epochs = 10
learning_rate = 1e-1

# Seed both RNGs so that Restart & Run All reproduces the same run:
# NumPy drives the mini-batch shuffling, while MLX's global PRNG is used by
# MLX's own random ops (e.g. parameter initialization).
np.random.seed(seed)
mx.random.seed(seed)

In [7]:
# Fetch MNIST through Keras as separate train and test (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist_dataset.load_data()

# Divisor used to normalize the grayscale pixel values.
gray_scale = 255

# Cast to float32, normalize, then flatten every image into one row:
# ({examples}, 28, 28) -> ({examples}, 28 * 28).
x_train = (x_train.astype("float32") / gray_scale).reshape(-1, 28 * 28)
x_test = (x_test.astype("float32") / gray_scale).reshape(-1, 28 * 28)

# Report the shapes of the feature and target matrices.
print("Feature matrix (x_train):", x_train.shape)
print("Target matrix (y_train):", y_train.shape)
print("Feature matrix (x_test):", x_test.shape)
print("Target matrix (y_test):", y_test.shape)

# Fail fast if the dataset does not have the expected split sizes.
assert x_train.shape == (60000, 28 * 28), "Wrong training set size"
assert y_train.shape == (60000,), "Wrong training set size"
assert x_test.shape == (10000, 28 * 28), "Wrong test set size"
assert y_test.shape == (10000,), "Wrong test set size"

# Wrap each array in an MLX array for use by the model.
train_images = mx.array(x_train)
train_labels = mx.array(y_train)
test_images = mx.array(x_test)
test_labels = mx.array(y_test)

Feature matrix (x_train): (60000, 784)
Target matrix (y_train): (60000,)
Feature matrix (x_test): (10000, 784)
Target matrix (y_test): (10000,)


In [8]:
# Build the model; its input width is the size of one flattened image.
model = MLP(
    num_layers=num_layers,
    input_dim=train_images.shape[-1],
    hidden_dim=hidden_dim,
    output_dim=num_classes,
)
# Evaluate the newly created parameters up front.
mx.eval(model.parameters())

# Function returning both the loss and its gradients for the model.
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Plain Stochastic Gradient Descent (SGD) with a fixed learning rate.
optimizer = optim.SGD(learning_rate=learning_rate)


In [9]:
@partial(mx.compile, inputs=model.state, outputs=model.state)
def step(X, y):
    """Run one compiled training step on the batch ``(X, y)`` and return its loss.

    ``model.state`` is passed to ``mx.compile`` as both ``inputs`` and
    ``outputs`` because the optimizer update below modifies the model.
    """
    batch_loss, gradients = loss_and_grad_fn(model, X, y)
    optimizer.update(model, gradients)
    return batch_loss

@partial(mx.compile, inputs=model.state)
def eval_fn(X, y):
    """Return the fraction of rows in ``X`` whose argmax prediction equals ``y``."""
    predicted = mx.argmax(model(X), axis=1)
    is_correct = predicted == y
    return mx.mean(is_correct)

# Train for num_epochs epochs, reporting test accuracy and wall time per epoch.
for e in range(num_epochs):
    tic = time.perf_counter()
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        step(X, y)
        # Force the lazily built update to run before the next batch.
        mx.eval(model.state)
    accuracy = eval_fn(test_images, test_labels)
    # MLX is lazy: without this the accuracy would only be computed by
    # `.item()` below, after `toc`, and the reported time would leave out
    # the evaluation pass that the timer brackets.
    mx.eval(accuracy)
    toc = time.perf_counter()
    print(
        f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
        f" Time {toc - tic:.3f} (s)"
    )
    

Epoch 0: Test accuracy 0.834, Time 0.225 (s)
Epoch 1: Test accuracy 0.898, Time 0.193 (s)
Epoch 2: Test accuracy 0.911, Time 0.156 (s)
Epoch 3: Test accuracy 0.927, Time 0.153 (s)
Epoch 4: Test accuracy 0.934, Time 0.152 (s)
Epoch 5: Test accuracy 0.923, Time 0.158 (s)
Epoch 6: Test accuracy 0.942, Time 0.156 (s)
Epoch 7: Test accuracy 0.940, Time 0.154 (s)
Epoch 8: Test accuracy 0.950, Time 0.153 (s)
Epoch 9: Test accuracy 0.950, Time 0.152 (s)
