A rough copy of https://blog.paperspace.com/writing-lenet5-from-scratch-in-python/

In [1]:
from functools import partial
from PIL import Image
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from datasets import load_dataset

In [2]:
rngs = nnx.Rngs(0)

In [3]:
def transform(x):
    x = [Image.fromarray(xx).resize((32, 32)) for xx in x]
    x = np.stack([np.asarray(xx) for xx in x], axis=0)
    x = np.expand_dims(x, axis=-1)
    return x

In [4]:
dataset = load_dataset("mnist")

X_train = transform([np.array(image) for image in dataset["train"]["image"]])
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = transform([np.array(image) for image in dataset["test"]["image"]])
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [5]:
class LeNet(nnx.Module):
    def __init__(self, *, rngs):
        self.conv1 = nnx.Conv(1, 6, kernel_size=(5, 5), padding="VALID", rngs=rngs)
        self.bn1 = nnx.BatchNorm(num_features=6, rngs=rngs)
        self.max_pool1 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), padding="VALID", rngs=rngs)
        self.bn2 = nnx.BatchNorm(num_features=16, rngs=rngs)
        self.max_pool2 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.l1 = nnx.Linear(400, 120, rngs=rngs)
        self.l2 = nnx.Linear(120, 84, rngs=rngs)
        self.l3 = nnx.Linear(84, 10, rngs=rngs)

    def __call__(self, x):
        x = self.bn1(self.conv1(x))
        x = nnx.relu(self.max_pool1(x))
        x = self.max_pool2(nnx.relu(self.bn2(self.conv2(x))))
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.l1(x))
        x = nnx.relu(self.l2(x))
        x = self.l3(x)
        return x

In [6]:
model = LeNet(rngs=rngs)
y = model(jnp.ones((1, 32, 32, 1)))
nnx.display(y)

[[ 1.06200893e-09 -1.02494360e-08 -2.24418906e-09 -1.07206874e-08
   1.28257112e-08  1.35244624e-08 -1.24616042e-08 -1.31719808e-08
   3.65592498e-08 -1.91581808e-08]]


In [7]:
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"))

In [8]:
def loss_fn(model, images, labels):
    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    return loss, logits

@nnx.jit
def train_step(model, optimizer, metrics, images, labels):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(model, metrics, images, labels):
    loss, logits = loss_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)

In [9]:
batch_size = 32
eval_every = 200
train_steps = len(X_train) // batch_size
test_steps = len(X_test) // batch_size
metrics_history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

for step in range(train_steps):
    sample = np.random.randint(0, len(X_train), size=batch_size)
    images, labels = X_train[sample], Y_train[sample]
    train_step(model, optimizer, metrics, images, labels)

    if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
        for metric, value in metrics.compute().items():
            metrics_history[f"train_{metric}"].append(value)
        metrics.reset()

        for test_step in range(test_steps):
            images = X_test[batch_size*test_step:batch_size*(test_step+1)]
            labels = Y_test[batch_size*test_step:batch_size*(test_step+1)]
            eval_step(model, metrics, images, labels)

        for metric, value in metrics.compute().items():
            metrics_history[f"test_{metric}"].append(value)
        metrics.reset()

        print(
            f"[train] step: {step}, "
            f"loss: {metrics_history['train_loss'][-1]:.4f}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100:.2f}"
        )
        print(
            f"[test] step: {step}, "
            f"loss: {metrics_history['test_loss'][-1]:.4f}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100:.2f}"
        )

[train] step: 200, loss: 0.3610, accuracy: 88.34
[test] step: 200, loss: 0.1361, accuracy: 95.88
[train] step: 400, loss: 0.1499, accuracy: 95.34
[test] step: 400, loss: 0.1141, accuracy: 96.54
[train] step: 600, loss: 0.1319, accuracy: 95.81
[test] step: 600, loss: 0.0962, accuracy: 97.14
[train] step: 800, loss: 0.0966, accuracy: 97.17
[test] step: 800, loss: 0.0675, accuracy: 97.92
[train] step: 1000, loss: 0.0859, accuracy: 97.45
[test] step: 1000, loss: 0.0755, accuracy: 97.71
[train] step: 1200, loss: 0.0866, accuracy: 97.50
[test] step: 1200, loss: 0.0584, accuracy: 98.25
[train] step: 1400, loss: 0.0776, accuracy: 97.67
[test] step: 1400, loss: 0.0744, accuracy: 97.67
[train] step: 1600, loss: 0.0736, accuracy: 97.77
[test] step: 1600, loss: 0.0871, accuracy: 97.31
[train] step: 1800, loss: 0.0689, accuracy: 98.05
[test] step: 1800, loss: 0.0660, accuracy: 98.05
[train] step: 1874, loss: 0.0621, accuracy: 97.80
[test] step: 1874, loss: 0.0562, accuracy: 98.32
