A rough copy of https://blog.paperspace.com/writing-lenet5-from-scratch-in-python/

In [1]:
from functools import partial
from PIL import Image
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from datasets import load_dataset

In [2]:
rngs = nnx.Rngs(0)

In [3]:
def transform(x):
    x = [Image.fromarray(xx).resize((32, 32)) for xx in x]
    x = np.stack([np.asarray(xx) for xx in x], axis=0)
    x = np.expand_dims(x, axis=-1)
    return x

In [4]:
dataset = load_dataset("mnist")

X_train = transform(np.array([np.array(image) for image in dataset["train"]["image"]], dtype=np.float32))
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = transform(np.array([np.array(image) for image in dataset["test"]["image"]], dtype=np.float32))
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [8]:
class LeNet(nnx.Module):
    def __init__(self, *, rngs):
        self.conv1 = nnx.Conv(1, 6, kernel_size=(5, 5), padding="VALID", rngs=rngs)
        self.bn1 = nnx.BatchNorm(num_features=6, rngs=rngs)
        self.max_pool1 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), padding="VALID", rngs=rngs)
        self.bn2 = nnx.BatchNorm(num_features=16, rngs=rngs)
        self.max_pool2 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.l1 = nnx.Linear(400, 120, rngs=rngs)
        self.l2 = nnx.Linear(120, 84, rngs=rngs)
        self.l3 = nnx.Linear(84, 10, rngs=rngs)

    def __call__(self, x):
        x = self.bn1(self.conv1(x))
        x = nnx.relu(self.max_pool1(x))
        x = self.max_pool2(nnx.relu(self.bn2(self.conv2(x))))
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.l1(x))
        x = nnx.relu(self.l2(x))
        x = self.l3(x)
        return x

In [9]:
model = LeNet(rngs=rngs)
y = model(jnp.ones((1, 32, 32, 1)))
nnx.display(y)

[[ 9.4948254e-09 -2.1230468e-08 -2.3708544e-09 -9.1365280e-09
   3.1189669e-09 -1.0155607e-08  1.0608820e-09  4.7077546e-09
   1.5445300e-08 -5.4524523e-09]]


In [10]:
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"))

In [11]:
def loss_fn(model, images, labels):
    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    return loss, logits

@nnx.jit
def train_step(model, optimizer, metrics, images, labels):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(model, metrics, images, labels):
    loss, logits = loss_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)

In [12]:
batch_size = 32
eval_every = 200
train_steps = len(X_train) // batch_size + 1
test_steps = len(X_test) // batch_size + 1
metrics_history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

for step in range(train_steps):
    sample = np.random.randint(0, len(X_train), size=batch_size)
    images, labels = X_train[sample], Y_train[sample]
    train_step(model, optimizer, metrics, images, labels)

    if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
        for metric, value in metrics.compute().items():
            metrics_history[f"train_{metric}"].append(value)
        metrics.reset()

        for test_step in range(test_steps):
            images = X_test[batch_size*test_step:batch_size*(test_step+1)]
            labels = Y_test[batch_size*test_step:batch_size*(test_step+1)]
            eval_step(model, metrics, images, labels)

        for metric, value in metrics.compute().items():
            metrics_history[f"test_{metric}"].append(value)
        metrics.reset()

        print(
            f"[train] step: {step}, "
            f"loss: {metrics_history['train_loss'][-1]:.4f}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100:.2f}"
        )
        print(
            f"[test] step: {step}, "
            f"loss: {metrics_history['test_loss'][-1]:.4f}, "
            f"accuracy: {metrics_history['test_accuracy'][-1] * 100:.2f}"
        )

[train] step: 200, loss: 0.3921, accuracy: 87.39
[test] step: 200, loss: 0.1467, accuracy: 95.26
[train] step: 400, loss: 0.1486, accuracy: 95.50
[test] step: 400, loss: 0.1109, accuracy: 96.55
[train] step: 600, loss: 0.1071, accuracy: 96.78
[test] step: 600, loss: 0.0736, accuracy: 97.88
[train] step: 800, loss: 0.0970, accuracy: 97.23
[test] step: 800, loss: 0.0696, accuracy: 97.80
[train] step: 1000, loss: 0.0919, accuracy: 97.41
[test] step: 1000, loss: 0.0547, accuracy: 98.28
[train] step: 1200, loss: 0.0763, accuracy: 97.64
[test] step: 1200, loss: 0.0711, accuracy: 97.93
[train] step: 1400, loss: 0.0766, accuracy: 97.66
[test] step: 1400, loss: 0.0694, accuracy: 97.76
[train] step: 1600, loss: 0.0789, accuracy: 97.67
[test] step: 1600, loss: 0.0583, accuracy: 98.32
[train] step: 1800, loss: 0.0764, accuracy: 97.88
[test] step: 1800, loss: 0.0666, accuracy: 98.14
[train] step: 1875, loss: 0.0647, accuracy: 97.88
[test] step: 1875, loss: 0.0485, accuracy: 98.49
