In [2]:
import os, time, math, random
import numpy as np
import pandas as pd

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

# MLX
import mlx.core as mx
import mlx.nn as mnn
import mlx.optimizers as moptim
from functools import partial

# Seed every random number generator the notebook imports so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
mx.random.seed(SEED)  # MLX keeps its own global PRNG, separate from NumPy/PyTorch

# Benchmark knobs
EPOCHS = 10
BATCH_SIZE = 256
LR = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

NUM_WORKERS = 2  # only used if you decide to use DataLoader (we won't, for fairness)
PRINT_EVERY = 200  # batches

# CIFAR-10 per-channel normalization statistics (common values), as float32
CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR_STD  = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)

def now() -> float:
    """Return the current ``time.perf_counter()`` reading (seconds, as a float)."""
    timestamp = time.perf_counter()
    return timestamp

def torch_sync(device: str):
    # For accurate timing; MPS is async.
    if device == "mps":
        torch.mps.synchronize()

def mlx_sync(*arrays):
    """Force evaluation of the given MLX arrays (or containers of arrays).

    MLX builds its computation lazily, so timing code calls this to make sure
    the work has actually been carried out.
    """
    pending = arrays
    mx.eval(*pending)

In [3]:
# Download (once) via torchvision; the images are then kept in RAM as uint8.
train_set, test_set = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, download=True)
    for is_train in (True, False)
)

def dataset_to_numpy(ds):
    """Return ``(images, labels)`` taken from a dataset object.

    ``ds.data`` is handed back as-is (no copy); for torchvision's CIFAR10 it is
    already a uint8 array laid out [N, H, W, C].  ``ds.targets`` is converted
    into a new int64 NumPy array.
    """
    labels = np.array(ds.targets, dtype=np.int64)
    return ds.data, labels

x_train_nhwc_u8, y_train = dataset_to_numpy(train_set)
x_test_nhwc_u8,  y_test  = dataset_to_numpy(test_set)

# Torch wants NCHW float32.
# np.transpose only returns a strided view, and astype()'s default order="K"
# would keep that NHWC memory layout in the copy.  Ask for order="C" so the
# arrays really are contiguous NCHW and every batch gathered from them is too.
x_train_nchw = np.transpose(x_train_nhwc_u8, (0, 3, 1, 2)).astype(np.float32, order="C")
x_test_nchw  = np.transpose(x_test_nhwc_u8,  (0, 3, 1, 2)).astype(np.float32, order="C")
assert x_train_nchw.flags.c_contiguous and x_test_nchw.flags.c_contiguous

# MLX wants NHWC float32
x_train_nhwc = x_train_nhwc_u8.astype(np.float32)
x_test_nhwc  = x_test_nhwc_u8.astype(np.float32)

print("Train:", x_train_nchw.shape, y_train.shape, "| Test:", x_test_nchw.shape, y_test.shape)

  entry = pickle.load(f, encoding="latin1")


Train: (50000, 3, 32, 32) (50000,) | Test: (10000, 3, 32, 32) (10000,)


In [4]:
def make_batches_indices(n, batch_size, shuffle=True, seed=SEED):
    """Build a 2-D array of sample indices, one row per full batch.

    The indices ``0 .. n-1`` are optionally shuffled with a NumPy ``Generator``
    seeded by ``seed``; trailing indices that do not fill a complete batch are
    dropped.

    Returns:
        Integer array of shape ``(n // batch_size, batch_size)``.
    """
    order = np.arange(n)
    if shuffle:
        np.random.default_rng(seed).shuffle(order)

    num_batches = n // batch_size
    # Drop the ragged tail so that every row holds exactly batch_size indices.
    return order[: num_batches * batch_size].reshape(num_batches, batch_size)

def iter_torch_batches(x_nchw_f32, y_i64, batch_size, device, shuffle=True, seed=SEED):
    """Yield ``(batch_index, images, labels)`` as torch tensors on ``device``.

    Batches are cut from NumPy arrays using the index rows produced by
    ``make_batches_indices`` (full batches only).  Images are divided by 255
    and standardised with ``CIFAR_MEAN`` / ``CIFAR_STD`` after being moved to
    the device.
    """
    index_rows = make_batches_indices(len(y_i64), batch_size, shuffle=shuffle, seed=seed)

    # Per-channel statistics reshaped to (1, 3, 1, 1) so they broadcast over a batch.
    stat_shape = (1, 3, 1, 1)
    channel_mean = torch.tensor(CIFAR_MEAN.reshape(stat_shape), device=device)
    channel_std = torch.tensor(CIFAR_STD.reshape(stat_shape), device=device)

    for batch_no in range(len(index_rows)):
        rows = index_rows[batch_no]
        images = torch.from_numpy(x_nchw_f32[rows]).to(device=device)
        labels = torch.from_numpy(y_i64[rows]).to(device=device)

        # Normalisation happens on the target device, after the transfer.
        yield batch_no, (images / 255.0 - channel_mean) / channel_std, labels

def iter_mlx_batches(x_nhwc_f32, y_i64, batch_size, shuffle=True, seed=SEED):
    """Yield ``(batch_index, images, labels)`` as MLX arrays.

    Batches are cut from NumPy arrays using the index rows produced by
    ``make_batches_indices`` (full batches only).  Images are divided by 255
    and standardised with ``CIFAR_MEAN`` / ``CIFAR_STD``.
    """
    index_rows = make_batches_indices(len(y_i64), batch_size, shuffle=shuffle, seed=seed)

    # Per-channel statistics reshaped to (1, 1, 1, 3) so they broadcast over a batch.
    stat_shape = (1, 1, 1, 3)
    channel_mean = mx.array(CIFAR_MEAN.reshape(stat_shape))
    channel_std = mx.array(CIFAR_STD.reshape(stat_shape))

    for batch_no in range(len(index_rows)):
        rows = index_rows[batch_no]
        images = mx.array(x_nhwc_f32[rows])
        labels = mx.array(y_i64[rows])
        yield batch_no, (images / 255.0 - channel_mean) / channel_std, labels

## models

In [5]:
from torchvision.models import resnet18

def build_torch_resnet18_cifar(num_classes=10):
    """Create a torchvision ResNet-18 with a CIFAR-style stem.

    The stock first convolution is replaced by a 3x3, stride-1, bias-free
    convolution (3 -> 64 channels) and the max-pool is replaced by an identity.

    Args:
        num_classes: Passed to ``resnet18`` as the size of the classifier output.
    """
    net = resnet18(num_classes=num_classes)
    cifar_stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.conv1, net.maxpool = cifar_stem, nn.Identity()
    return net

In [6]:
class MlxBasicBlock(mnn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus a shortcut.

    When the stride is not 1 or the channel count changes, the shortcut goes
    through a 1x1 convolution followed by batch norm (``self.down``); otherwise
    the input is added unchanged.
    """

    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = mnn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = mnn.BatchNorm(out_ch)
        self.conv2 = mnn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = mnn.BatchNorm(out_ch)

        # A projection is needed whenever the main path changes the tensor's shape.
        needs_projection = (stride != 1) or (in_ch != out_ch)
        self.down = (
            mnn.Sequential(
                mnn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False),
                mnn.BatchNorm(out_ch),
            )
            if needs_projection
            else None
        )

    def __call__(self, x):
        residual = mnn.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.down is None else self.down(x)
        return mnn.relu(residual + shortcut)

class MlxResNet(mnn.Module):
    """ResNet built from ``MlxBasicBlock`` with a CIFAR-style stem.

    Four stages of two blocks each (64, 128, 256, 512 channels), followed by a
    mean over axes 1 and 2 (H and W in NHWC) and a linear classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # CIFAR stem (NHWC): 3x3 conv, stride 1
        self.conv1 = mnn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = mnn.BatchNorm(64)

        # Channel count fed to the next block; advanced by _make_layer.
        self.in_ch = 64
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        self.fc = mnn.Linear(512, num_classes)

    def _make_layer(self, out_ch, blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1] * (blocks - 1):
            stage.append(MlxBasicBlock(self.in_ch, out_ch, stride=block_stride))
            self.in_ch = out_ch
        return mnn.Sequential(*stage)

    def __call__(self, x):
        x = mnn.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Global average pool over H,W (NHWC) -> (N, C)
        pooled = mx.mean(x, axis=(1, 2))
        return self.fc(pooled)

## benchmark helpers

In [7]:
def _torch_train_step(model, opt, xb, yb):
    """Run one SGD step: zero grads, forward, cross-entropy, backward, update."""
    opt.zero_grad(set_to_none=True)
    loss = F.cross_entropy(model(xb), yb)
    loss.backward()
    opt.step()
    return loss


def run_torch_train(name: str, device: str, use_compile: bool):
    """Time ``EPOCHS`` epochs of ResNet-18 training on CIFAR-10 with PyTorch.

    Args:
        name: Label stored in the result and used in progress lines.
        device: "cpu" or "mps".
        use_compile: Wrap the model in ``torch.compile`` before training.

    Returns:
        A dict with "name", "status" and "device".  When the status is "OK" it
        also carries "epochs", "epoch_time_mean_s", "epoch_time_std_s" and
        "total_time_s".  The status is "SKIP (no MPS)" when MPS was requested
        but is unavailable, and "FAIL compile: ..." when compilation fails.

    One untimed warm-up step is run first.  Each timed epoch includes building
    the batches and moving them to the device.
    """
    assert device in ("cpu", "mps")
    if device == "mps" and not torch.backends.mps.is_available():
        return {"name": name, "status": "SKIP (no MPS)", "device": device}

    torch_device = torch.device(device)
    model = build_torch_resnet18_cifar().to(torch_device)
    opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

    # torch.compile (default backend is "inductor")
    if use_compile:
        try:
            model = torch.compile(model)
        except Exception as e:
            return {"name": name, "status": f"FAIL compile: {type(e).__name__}: {e}", "device": device}

    # Warm-up single step (important esp. for compile + async backends)
    model.train()
    _, xb, yb = next(iter_torch_batches(x_train_nchw, y_train, BATCH_SIZE, device, shuffle=True, seed=SEED))
    try:
        _torch_train_step(model, opt, xb, yb)
        torch_sync(device)
    except Exception as e:
        # torch.compile only wraps the model; the real compilation runs on the
        # first forward/backward, so backend failures surface here rather than
        # in the torch.compile() call above.  Errors from an uncompiled model
        # are genuine bugs and must propagate.
        if not use_compile:
            raise
        return {"name": name, "status": f"FAIL compile: {type(e).__name__}: {e}", "device": device}

    epoch_times = []
    for epoch in range(EPOCHS):
        t0 = now()
        model.train()

        # A different shuffle seed per epoch; the warm-up used SEED itself.
        for _, xb, yb in iter_torch_batches(x_train_nchw, y_train, BATCH_SIZE, device, shuffle=True, seed=SEED+epoch+1):
            _torch_train_step(model, opt, xb, yb)

        # Wait for queued device work before reading the clock.
        torch_sync(device)
        t1 = now()
        epoch_times.append(t1 - t0)
        print(f"[{name}] epoch {epoch+1}/{EPOCHS} time: {epoch_times[-1]:.3f}s")

    return {
        "name": name,
        "status": "OK",
        "device": device,
        "epochs": EPOCHS,
        "epoch_time_mean_s": float(np.mean(epoch_times)),
        "epoch_time_std_s": float(np.std(epoch_times)),
        "total_time_s": float(np.sum(epoch_times)),
    }

In [8]:
def run_mlx_train(name: str, use_compile: bool):
    """Time ``EPOCHS`` epochs of ResNet training on CIFAR-10 with MLX.

    Args:
        name: Label stored in the result and used in progress lines.
        use_compile: Use the ``mx.compile``-wrapped training step instead of
            the eager one.

    Returns:
        A dict with "name", "status" ("OK"), "device" ("mlx"), "epochs",
        "epoch_time_mean_s", "epoch_time_std_s" and "total_time_s".

    One untimed warm-up step is run first.  Each timed epoch includes building
    the batches, and the state is evaluated after every step.
    """
    model = MlxResNet(num_classes=10)
    optimizer = moptim.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

    def loss_fn(model, x, y):
        logits = model(x)
        # MLX's cross_entropy defaults to reduction="none" (one loss per sample),
        # but value_and_grad needs a scalar.  "mean" also matches the default of
        # torch's F.cross_entropy, keeping the two benchmarks comparable.
        return mnn.losses.cross_entropy(logits, y, reduction="mean")

    # Uncompiled step
    loss_and_grad = mnn.value_and_grad(model, loss_fn)

    def step_eager(x, y):
        loss, grads = loss_and_grad(model, x, y)
        optimizer.update(model, grads)
        return loss

    # Compiled step: the model/optimizer/random state is declared as both
    # inputs and outputs of mx.compile, since the step reads and updates it.
    state = [model.state, optimizer.state, mx.random.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step_compiled(x, y):
        loss_and_grad_fn = mnn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    step = step_compiled if use_compile else step_eager

    # Warm-up one step
    bi, xb, yb = next(iter_mlx_batches(x_train_nhwc, y_train, BATCH_SIZE, shuffle=True, seed=SEED))
    loss = step(xb, yb)
    mlx_sync(state)

    epoch_times = []
    for epoch in range(EPOCHS):
        t0 = now()
        # A different shuffle seed per epoch; the warm-up used SEED itself.
        for bi, xb, yb in iter_mlx_batches(x_train_nhwc, y_train, BATCH_SIZE, shuffle=True, seed=SEED+epoch+1):
            loss = step(xb, yb)
            # Force compute so timing is real
            mlx_sync(state)

        t1 = now()
        epoch_times.append(t1 - t0)
        print(f"[{name}] epoch {epoch+1}/{EPOCHS} time: {epoch_times[-1]:.3f}s")

    return {
        "name": name,
        "status": "OK",
        "device": "mlx",
        "epochs": EPOCHS,
        "epoch_time_mean_s": float(np.mean(epoch_times)),
        "epoch_time_std_s": float(np.std(epoch_times)),
        "total_time_s": float(np.sum(epoch_times)),
    }

## runs

In [9]:
# One result dict is appended per benchmark run. Re-running this cell resets
# the list, discarding results gathered by the later run cells.
results = []
# Eager PyTorch on CPU (no torch.compile).
results.append(run_torch_train("torch_cpu_baseline", device="cpu", use_compile=False))

[torch_cpu_baseline] epoch 1/10 time: 974.931s
[torch_cpu_baseline] epoch 2/10 time: 1006.732s
[torch_cpu_baseline] epoch 3/10 time: 1073.400s
[torch_cpu_baseline] epoch 4/10 time: 1052.026s
[torch_cpu_baseline] epoch 5/10 time: 1069.065s
[torch_cpu_baseline] epoch 6/10 time: 1050.840s
[torch_cpu_baseline] epoch 7/10 time: 1063.588s
[torch_cpu_baseline] epoch 8/10 time: 977.699s
[torch_cpu_baseline] epoch 9/10 time: 1040.370s
[torch_cpu_baseline] epoch 10/10 time: 1065.361s


In [None]:
# PyTorch on CPU with the model wrapped in torch.compile.
results.append(run_torch_train("torch_cpu_compile", device="cpu", use_compile=True))

In [12]:
# Eager PyTorch on MPS; yields a "SKIP (no MPS)" entry when MPS is unavailable.
results.append(run_torch_train("torch_mps", device="mps", use_compile=False))

[torch_mps] epoch 1/10 time: 142.176s
[torch_mps] epoch 2/10 time: 144.068s
[torch_mps] epoch 3/10 time: 145.051s
[torch_mps] epoch 4/10 time: 150.447s
[torch_mps] epoch 5/10 time: 163.132s
[torch_mps] epoch 6/10 time: 183.788s
[torch_mps] epoch 7/10 time: 207.818s
[torch_mps] epoch 8/10 time: 204.126s
[torch_mps] epoch 9/10 time: 199.734s
[torch_mps] epoch 10/10 time: 157.943s


In [None]:
# PyTorch on MPS with torch.compile; yields a "SKIP (no MPS)" entry when MPS is unavailable.
results.append(run_torch_train("torch_mps_compile", device="mps", use_compile=True))

In [None]:
# MLX with the eager (uncompiled) training step.
results.append(run_mlx_train("mlx", use_compile=False))

In [None]:
# MLX with the mx.compile-wrapped training step.
results.append(run_mlx_train("mlx_compile", use_compile=True))

In [None]:
df = pd.DataFrame(results)

# Skipped or failed runs carry no timing fields, and if no run finished with
# status "OK" (or nothing has been run yet) those columns are missing
# altogether.  Add any missing column as NaN so the filter and sort below
# cannot raise KeyError.
ranking_cols = ["name", "device", "epoch_time_mean_s", "epoch_time_std_s", "total_time_s"]
missing_cols = [col for col in ["status", *ranking_cols] if col not in df.columns]
df = df.reindex(columns=[*df.columns, *missing_cols])

# Keep rows that have timing
timed = df[df["status"].eq("OK")].sort_values("epoch_time_mean_s", ascending=True)

print("=== Raw results ===")
display(df)

print("\n=== Ranking (fastest mean epoch time first) ===")
display(timed[ranking_cols])