In [13]:
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

In [14]:
import os, time, math, random
import numpy as np
import pandas as pd

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

# MLX
import mlx.core as mx
import mlx.nn as mnn
import mlx.optimizers as moptim
from functools import partial

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Benchmark knobs
EPOCHS = 20
BATCH_SIZE = 256
LR = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

PRINT_EVERY = 200  # batches

# MNIST normalization (common)
MNIST_MEAN = np.array([0.1307], dtype=np.float32)
MNIST_STD = np.array([0.3081], dtype=np.float32)

def now():
    return time.perf_counter()

def torch_sync(device: str):
    # For accurate timing; MPS is async.
    if device == "mps":
        torch.mps.synchronize()

def mlx_sync(*arrays):
    # MLX is lazy; force compute.
    mx.eval(*arrays)

In [15]:
# Download once using torchvision, then store as uint8 in RAM.
train_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
test_set  = torchvision.datasets.MNIST(root="./data", train=False, download=True)

def dataset_to_numpy(ds):
    # MNIST: ds.data is uint8 [N, H, W]
    x_nhw_u8 = ds.data.numpy()
    y = np.array(ds.targets, dtype=np.int64)
    return x_nhw_u8, y

x_train_nhw_u8, y_train = dataset_to_numpy(train_set)
x_test_nhw_u8,  y_test  = dataset_to_numpy(test_set)

# Torch wants NCHW float32
x_train_nchw = x_train_nhw_u8[:, None, :, :].astype(np.float32)  # N1HW
x_test_nchw  = x_test_nhw_u8[:, None, :, :].astype(np.float32)

# MLX wants NHWC float32
x_train_nhwc = x_train_nhw_u8[:, :, :, None].astype(np.float32)  # NHW1
x_test_nhwc  = x_test_nhw_u8[:, :, :, None].astype(np.float32)

print("Train:", x_train_nchw.shape, y_train.shape, "| Test:", x_test_nchw.shape, y_test.shape)

Train: (60000, 1, 28, 28) (60000,) | Test: (10000, 1, 28, 28) (10000,)


  y = np.array(ds.targets, dtype=np.int64)


In [16]:
def make_batches_indices(n, batch_size, shuffle=True, seed=SEED):
    idx = np.arange(n)
    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(idx)
    # drop_last
    n_full = (n // batch_size) * batch_size
    idx = idx[:n_full]
    return idx.reshape(-1, batch_size)

def iter_torch_batches(x_nchw_f32, y_i64, batch_size, device, shuffle=True, seed=SEED):
    batches = make_batches_indices(len(y_i64), batch_size, shuffle=shuffle, seed=seed)
    mean = torch.tensor(MNIST_MEAN.reshape(1, 1, 1, 1), device=device)
    std  = torch.tensor(MNIST_STD.reshape(1, 1, 1, 1), device=device)

    for bi, b in enumerate(batches):
        xb = torch.from_numpy(x_nchw_f32[b]).to(device=device)
        yb = torch.from_numpy(y_i64[b]).to(device=device)
        xb = (xb / 255.0 - mean) / std
        yield bi, xb, yb

def iter_mlx_batches(x_nhwc_f32, y_i64, batch_size, shuffle=True, seed=SEED):
    batches = make_batches_indices(len(y_i64), batch_size, shuffle=shuffle, seed=seed)
    mean = mx.array(MNIST_MEAN.reshape(1, 1, 1, 1))
    std  = mx.array(MNIST_STD.reshape(1, 1, 1, 1))

    for bi, b in enumerate(batches):
        xb = mx.array(x_nhwc_f32[b])
        yb = mx.array(y_i64[b])
        xb = (xb / 255.0 - mean) / std
        yield bi, xb, yb

## models

In [17]:
class TorchSimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.pool  = nn.MaxPool2d(kernel_size=2, stride=2)  # 28->14->7
        self.fc1   = nn.Linear(64 * 7 * 7, 128, bias=True)
        self.fc2   = nn.Linear(128, num_classes, bias=True)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def build_torch_simplecnn_mnist(num_classes=10):
    return TorchSimpleCNN(num_classes=num_classes)

In [18]:
class MlxSimpleCNN(mnn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = mnn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = mnn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.pool  = mnn.MaxPool2d(kernel_size=2, stride=2)  # 28->14->7
        self.fc1   = mnn.Linear(64 * 7 * 7, 128, bias=True)
        self.fc2   = mnn.Linear(128, num_classes, bias=True)

    def __call__(self, x):
        x = mnn.relu(self.conv1(x))
        x = self.pool(x)
        x = mnn.relu(self.conv2(x))
        x = self.pool(x)
        x = mx.reshape(x, (x.shape[0], -1))
        x = mnn.relu(self.fc1(x))
        x = self.fc2(x)
        return x


## benchmark helpers

In [19]:
def run_torch_train(name: str, device: str, use_compile: bool):
    assert device in ("cpu", "mps")
    if device == "mps" and not torch.backends.mps.is_available():
        return {"name": name, "status": "SKIP (no MPS)", "device": device}

    torch_device = torch.device(device)
    model = build_torch_simplecnn_mnist(num_classes=10).to(device=torch_device)
    opt = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

    # torch.compile (default backend is "inductor")  [oai_citation:4â€¡PyTorch Documentation](https://docs.pytorch.org/docs/stable/generated/torch.compile.html?utm_source=chatgpt.com)
    if use_compile:
        try:
            model = torch.compile(model, backend="aot_eager")
        except Exception as e:
            return {"name": name, "status": f"FAIL compile: {type(e).__name__}: {e}", "device": device}

    # Warm-up single step (important esp. for compile + async backends)
    model.train()
    bi, xb, yb = next(iter_torch_batches(x_train_nchw, y_train, BATCH_SIZE, device, shuffle=True, seed=SEED))
    opt.zero_grad(set_to_none=True)
    out = model(xb)
    loss = F.cross_entropy(out, yb)
    loss.backward()
    opt.step()
    torch_sync(device)

    epoch_times = []
    for epoch in range(EPOCHS):
        t0 = now()
        model.train()

        for bi, xb, yb in iter_torch_batches(x_train_nchw, y_train, BATCH_SIZE, device, shuffle=True, seed=SEED+epoch+1):
            opt.zero_grad(set_to_none=True)
            out = model(xb)
            loss = F.cross_entropy(out, yb)
            loss.backward()
            opt.step()

        torch_sync(device)
        t1 = now()
        epoch_times.append(t1 - t0)
        print(f"[{name}] epoch {epoch+1}/{EPOCHS} time: {epoch_times[-1]:.3f}s")

    return {
        "name": name,
        "status": "OK",
        "device": device,
        "epochs": EPOCHS,
        "epoch_time_mean_s": float(np.mean(epoch_times)),
        "epoch_time_std_s": float(np.std(epoch_times)),
        "total_time_s": float(np.sum(epoch_times)),
    }

In [28]:
def run_mlx_train(name: str, use_compile: bool, *, sync_every: int = 1):
    """
    MLX training loop optimized for correctness + performance.

    - Correctness: forces model/optimizer state to be evaluated so updates actually happen.
    - Performance: avoids recreating grad functions inside the compiled function, keeps batch shapes fixed,
      and lets you control evaluation frequency via sync_every.
    """
    assert sync_every >= 1, "sync_every must be >= 1 to avoid building an unbounded lazy graph"

    model = MlxSimpleCNN(num_classes=10)
    optimizer = moptim.SGD(
        learning_rate=LR,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
    )

    def loss_fn(m, x, y):
        logits = m(x)
        return mx.mean(mnn.losses.cross_entropy(logits, y))

    # Build the grad function ONCE (important for compile time + stability).
    loss_and_grad_fn = mnn.value_and_grad(model, loss_fn)

    # Track state that must be treated as inputs/outputs for compilation correctness.
    # Keep mx.random.state only if you use randomness inside the step (dropout, stochastic aug, etc.).
    state = [model.state, optimizer.state, mx.random.state]

    def step_eager(x, y):
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    @partial(mx.compile, inputs=state, outputs=state)
    def step_compiled(x, y):
        # No Python object creation in the hot path.
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    step = step_compiled if use_compile else step_eager

    # Warm-up (also triggers first-time compile if use_compile=True)
    bi, xb, yb = next(iter_mlx_batches(x_train_nhwc, y_train, BATCH_SIZE, shuffle=True, seed=SEED))
    loss = step(xb, yb)
    # Force loss + state so the update is not optimized away.
    mx.eval(loss, *state)

    epoch_times = []
    for epoch in range(EPOCHS):
        t0 = now()

        for i, (bi, xb, yb) in enumerate(
            iter_mlx_batches(x_train_nhwc, y_train, BATCH_SIZE, shuffle=True, seed=SEED + epoch + 1)
        ):
            loss = step(xb, yb)

            # Correctness: force the update to execute.
            # Performance: you can reduce how often you sync, but don't set it too high.
            if (i + 1) % sync_every == 0:
                mx.eval(loss, *state)

        # Ensure the tail of the epoch is executed.
        mx.eval(loss, *state)

        t1 = now()
        epoch_times.append(t1 - t0)
        print(f"[{name}] epoch {epoch+1}/{EPOCHS} time: {epoch_times[-1]:.3f}s")

    epoch_times = np.array(epoch_times, dtype=np.float64)
    return {
        "name": name,
        "status": "OK",
        "device": "mlx",
        "epochs": EPOCHS,
        "use_compile": bool(use_compile),
        "sync_every": int(sync_every),
        # Helpful to separate warmup/compile cost from steady state
        "epoch1_time_s": float(epoch_times[0]) if len(epoch_times) else None,
        "epoch_time_mean_s": float(epoch_times.mean()) if len(epoch_times) else None,
        "epoch_time_std_s": float(epoch_times.std()) if len(epoch_times) else None,
        "epoch_time_mean_steady_s": float(epoch_times[1:].mean()) if len(epoch_times) > 1 else float(epoch_times.mean()),
        "total_time_s": float(epoch_times.sum()) if len(epoch_times) else None,
    }

## runs

In [21]:
results = []
results.append(run_torch_train("torch_cpu_baseline", device="cpu", use_compile=False))

[torch_cpu_baseline] epoch 1/20 time: 26.782s
[torch_cpu_baseline] epoch 2/20 time: 47.867s
[torch_cpu_baseline] epoch 3/20 time: 60.175s
[torch_cpu_baseline] epoch 4/20 time: 54.318s
[torch_cpu_baseline] epoch 5/20 time: 21.071s
[torch_cpu_baseline] epoch 6/20 time: 21.543s
[torch_cpu_baseline] epoch 7/20 time: 19.541s
[torch_cpu_baseline] epoch 8/20 time: 19.579s
[torch_cpu_baseline] epoch 9/20 time: 19.582s
[torch_cpu_baseline] epoch 10/20 time: 19.521s
[torch_cpu_baseline] epoch 11/20 time: 19.599s
[torch_cpu_baseline] epoch 12/20 time: 19.887s
[torch_cpu_baseline] epoch 13/20 time: 19.669s
[torch_cpu_baseline] epoch 14/20 time: 19.512s
[torch_cpu_baseline] epoch 15/20 time: 19.584s
[torch_cpu_baseline] epoch 16/20 time: 19.418s
[torch_cpu_baseline] epoch 17/20 time: 19.683s
[torch_cpu_baseline] epoch 18/20 time: 19.432s
[torch_cpu_baseline] epoch 19/20 time: 19.699s
[torch_cpu_baseline] epoch 20/20 time: 19.584s


In [22]:
results.append(run_torch_train("torch_cpu_compile", device="cpu", use_compile=True))

[torch_cpu_compile] epoch 1/20 time: 19.365s
[torch_cpu_compile] epoch 2/20 time: 19.286s
[torch_cpu_compile] epoch 3/20 time: 19.367s
[torch_cpu_compile] epoch 4/20 time: 19.361s
[torch_cpu_compile] epoch 5/20 time: 19.624s
[torch_cpu_compile] epoch 6/20 time: 19.291s
[torch_cpu_compile] epoch 7/20 time: 19.455s
[torch_cpu_compile] epoch 8/20 time: 19.069s
[torch_cpu_compile] epoch 9/20 time: 19.235s
[torch_cpu_compile] epoch 10/20 time: 19.838s
[torch_cpu_compile] epoch 11/20 time: 19.314s
[torch_cpu_compile] epoch 12/20 time: 19.235s
[torch_cpu_compile] epoch 13/20 time: 19.529s
[torch_cpu_compile] epoch 14/20 time: 19.025s
[torch_cpu_compile] epoch 15/20 time: 18.865s
[torch_cpu_compile] epoch 16/20 time: 19.262s
[torch_cpu_compile] epoch 17/20 time: 19.598s
[torch_cpu_compile] epoch 18/20 time: 19.251s
[torch_cpu_compile] epoch 19/20 time: 19.304s
[torch_cpu_compile] epoch 20/20 time: 19.259s


In [32]:
results.append(run_torch_train("torch_mps", device="mps", use_compile=False))

[torch_mps] epoch 1/20 time: 3.863s
[torch_mps] epoch 2/20 time: 3.841s
[torch_mps] epoch 3/20 time: 3.883s
[torch_mps] epoch 4/20 time: 3.865s
[torch_mps] epoch 5/20 time: 4.359s
[torch_mps] epoch 6/20 time: 3.824s
[torch_mps] epoch 7/20 time: 3.870s
[torch_mps] epoch 8/20 time: 3.833s
[torch_mps] epoch 9/20 time: 3.823s
[torch_mps] epoch 10/20 time: 3.821s
[torch_mps] epoch 11/20 time: 3.833s
[torch_mps] epoch 12/20 time: 3.890s
[torch_mps] epoch 13/20 time: 3.826s
[torch_mps] epoch 14/20 time: 3.835s
[torch_mps] epoch 15/20 time: 3.858s
[torch_mps] epoch 16/20 time: 3.844s
[torch_mps] epoch 17/20 time: 3.873s
[torch_mps] epoch 18/20 time: 4.449s
[torch_mps] epoch 19/20 time: 4.113s
[torch_mps] epoch 20/20 time: 4.074s


In [24]:
results.append(run_torch_train("torch_mps_compile", device="mps", use_compile=True))

[torch_mps_compile] epoch 1/20 time: 4.087s
[torch_mps_compile] epoch 2/20 time: 3.954s
[torch_mps_compile] epoch 3/20 time: 3.986s
[torch_mps_compile] epoch 4/20 time: 4.135s
[torch_mps_compile] epoch 5/20 time: 4.014s
[torch_mps_compile] epoch 6/20 time: 4.040s
[torch_mps_compile] epoch 7/20 time: 4.016s
[torch_mps_compile] epoch 8/20 time: 4.058s
[torch_mps_compile] epoch 9/20 time: 4.062s
[torch_mps_compile] epoch 10/20 time: 3.988s
[torch_mps_compile] epoch 11/20 time: 3.964s
[torch_mps_compile] epoch 12/20 time: 3.983s
[torch_mps_compile] epoch 13/20 time: 4.053s
[torch_mps_compile] epoch 14/20 time: 4.043s
[torch_mps_compile] epoch 15/20 time: 3.961s
[torch_mps_compile] epoch 16/20 time: 3.955s
[torch_mps_compile] epoch 17/20 time: 3.944s
[torch_mps_compile] epoch 18/20 time: 3.969s
[torch_mps_compile] epoch 19/20 time: 3.979s
[torch_mps_compile] epoch 20/20 time: 4.099s


In [29]:
results.append(run_mlx_train("mlx", use_compile=False))

[mlx] epoch 1/20 time: 6.890s
[mlx] epoch 2/20 time: 6.920s
[mlx] epoch 3/20 time: 6.884s
[mlx] epoch 4/20 time: 6.893s
[mlx] epoch 5/20 time: 6.876s
[mlx] epoch 6/20 time: 7.326s
[mlx] epoch 7/20 time: 7.042s
[mlx] epoch 8/20 time: 6.869s
[mlx] epoch 9/20 time: 6.853s
[mlx] epoch 10/20 time: 6.892s
[mlx] epoch 11/20 time: 6.880s
[mlx] epoch 12/20 time: 6.868s
[mlx] epoch 13/20 time: 6.867s
[mlx] epoch 14/20 time: 6.861s
[mlx] epoch 15/20 time: 6.871s
[mlx] epoch 16/20 time: 6.913s
[mlx] epoch 17/20 time: 6.931s
[mlx] epoch 18/20 time: 6.918s
[mlx] epoch 19/20 time: 6.948s
[mlx] epoch 20/20 time: 6.912s


In [30]:
results.append(run_mlx_train("mlx_compile", use_compile=True))

[mlx_compile] epoch 1/20 time: 5.809s
[mlx_compile] epoch 2/20 time: 5.801s
[mlx_compile] epoch 3/20 time: 5.825s
[mlx_compile] epoch 4/20 time: 5.788s
[mlx_compile] epoch 5/20 time: 5.779s
[mlx_compile] epoch 6/20 time: 5.784s
[mlx_compile] epoch 7/20 time: 5.774s
[mlx_compile] epoch 8/20 time: 5.798s
[mlx_compile] epoch 9/20 time: 5.784s
[mlx_compile] epoch 10/20 time: 5.786s
[mlx_compile] epoch 11/20 time: 5.789s
[mlx_compile] epoch 12/20 time: 5.786s
[mlx_compile] epoch 13/20 time: 5.821s
[mlx_compile] epoch 14/20 time: 5.801s
[mlx_compile] epoch 15/20 time: 5.799s
[mlx_compile] epoch 16/20 time: 5.785s
[mlx_compile] epoch 17/20 time: 5.785s
[mlx_compile] epoch 18/20 time: 5.840s
[mlx_compile] epoch 19/20 time: 5.824s
[mlx_compile] epoch 20/20 time: 5.826s


In [31]:
df = pd.DataFrame(results)

# Keep rows that have timing
timed = df[df["status"].eq("OK")].copy()
timed = timed.sort_values("epoch_time_mean_s", ascending=True)

print("=== Raw results ===")
display(df)

print("\n=== Ranking (fastest mean epoch time first) ===")
display(timed[["name","device","epoch_time_mean_s","epoch_time_std_s","total_time_s"]])

=== Raw results ===


Unnamed: 0,name,status,device,epochs,epoch_time_mean_s,epoch_time_std_s,total_time_s,use_compile,sync_every,epoch1_time_s,epoch_time_mean_steady_s
0,torch_cpu_baseline,OK,cpu,20,25.302321,12.365466,506.046421,,,,
1,torch_cpu_compile,OK,cpu,20,19.326674,0.211065,386.533488,,,,
2,torch_mps,OK,mps,20,3.864911,0.015598,77.298214,,,,
3,torch_mps_compile,OK,mps,20,4.014458,0.053598,80.289153,,,,
4,mlx,OK,mlx,20,6.92076,0.101687,138.41521,False,1.0,6.890366,6.92236
5,mlx_compile,OK,mlx,20,5.799211,0.018315,115.984212,True,1.0,5.809232,5.798683



=== Ranking (fastest mean epoch time first) ===


Unnamed: 0,name,device,epoch_time_mean_s,epoch_time_std_s,total_time_s
2,torch_mps,mps,3.864911,0.015598,77.298214
3,torch_mps_compile,mps,4.014458,0.053598,80.289153
5,mlx_compile,mlx,5.799211,0.018315,115.984212
4,mlx,mlx,6.92076,0.101687,138.41521
1,torch_cpu_compile,cpu,19.326674,0.211065,386.533488
0,torch_cpu_baseline,cpu,25.302321,12.365466,506.046421
