A rough copy of the LeNet-5 tutorial at https://blog.paperspace.com/writing-lenet5-from-scratch-in-python/, reimplemented with JAX and Flax NNX on the MNIST dataset.

In [1]:
from functools import partial
from PIL import Image
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from datasets import load_dataset

In [2]:
# Shared RNG container created from the fixed seed 0; it is handed to LeNet
# (and from there to every layer constructor) further down the notebook.
rngs = nnx.Rngs(0)

In [34]:
def transform(x, size=(32, 32)):
    """Resize a batch of 2-D images and append a trailing channel axis.

    Args:
        x: Iterable of 2-D arrays (one per image) that ``Image.fromarray``
            accepts, e.g. an array of shape ``(n, height, width)``.
        size: Target size as ``(width, height)``, the order PIL's
            ``Image.resize`` expects. Defaults to ``(32, 32)``.

    Returns:
        ``np.ndarray`` of shape ``(n, size[1], size[0], 1)``.

    Raises:
        ValueError: If ``x`` is empty (``np.stack`` needs at least one array).
    """
    # Round-trip each frame through PIL to resize it, then straight back to numpy.
    resized = [np.asarray(Image.fromarray(frame).resize(size)) for frame in x]
    batch = np.stack(resized, axis=0)
    # Add a single channel as the last axis (channels-last layout).
    return np.expand_dims(batch, axis=-1)

In [35]:
dataset = load_dataset("mnist")


def _split_to_arrays(split):
    """Return ``(images, labels)`` arrays for one split of ``dataset``.

    Images are decoded to a float32 array and passed through ``transform``;
    labels become an int32 vector.
    """
    records = dataset[split]
    pixels = np.array([np.array(picture) for picture in records["image"]], dtype=np.float32)
    images = transform(pixels)
    labels = np.array(records["label"], dtype=np.int32)
    return images, labels


X_train, Y_train = _split_to_arrays("train")
X_test, Y_test = _split_to_arrays("test")

In [55]:
class LeNet(nnx.Module):
    """LeNet-style CNN: two conv stages followed by three dense layers.

    Each conv stage is conv -> batch-norm -> ReLU -> 2x2 max-pool. The class
    subclasses ``nnx.Module`` so its parameters and BatchNorm statistics are
    part of the NNX graph (usable with ``nnx.split``, ``nnx.Optimizer``,
    ``nnx.jit``) and so ``model.train()`` / ``model.eval()`` can switch the
    BatchNorm layers between batch and running statistics.

    Args:
        rngs: ``nnx.Rngs`` used to initialise every layer.
    """

    def __init__(self, *, rngs):
        # Stage 1: 1 -> 6 channels.
        self.conv1 = nnx.Conv(1, 6, kernel_size=(5, 5), rngs=rngs)
        self.bn1 = nnx.BatchNorm(num_features=6, rngs=rngs)
        self.max_pool1 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        # Stage 2: 6 -> 16 channels.
        self.conv2 = nnx.Conv(6, 16, kernel_size=(5, 5), rngs=rngs)
        self.bn2 = nnx.BatchNorm(num_features=16, rngs=rngs)
        self.max_pool2 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        # nnx.Conv's default 'SAME' padding keeps the spatial size, so only the
        # two 2x2 pools shrink a 32x32 input: 32 -> 16 -> 8, and
        # 8 * 8 * 16 channels = 1024 flattened features.
        self.l1 = nnx.Linear(1024, 256, rngs=rngs)
        self.l2 = nnx.Linear(256, 64, rngs=rngs)
        self.l3 = nnx.Linear(64, 10, rngs=rngs)

    def __call__(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Channels-last batch, expected shape ``(batch, 32, 32, 1)``;
                other spatial sizes will not match the 1024-feature input
                of the first dense layer.

        Returns:
            Array of shape ``(batch, 10)`` with unnormalised logits.
        """
        # Both stages use the same order: conv -> BN -> ReLU -> pool.
        # (ReLU and max-pool commute, so stage 1 yields the same values as
        # the previous pool-then-ReLU ordering.)
        x = self.max_pool1(nnx.relu(self.bn1(self.conv1(x))))
        x = self.max_pool2(nnx.relu(self.bn2(self.conv2(x))))
        # Flatten everything except the batch axis.
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.l1(x))
        x = nnx.relu(self.l2(x))
        # No activation on the last layer: raw logits are returned.
        return self.l3(x)

In [56]:
# Smoke test: push a constant all-ones batch through a freshly initialised
# network and inspect the resulting logits.
smoke_input_shape = (9001, 32, 32, 1)  # (batch, height, width, channels)
model = LeNet(rngs=rngs)
y = model(jnp.ones(smoke_input_shape))
nnx.display(y)

(9001, 32, 32, 1)
(9001, 32, 32, 6)
(9001, 16, 16, 6)
[[-0.5924146   0.36494392 -0.32361963 ... -0.20976512  0.6094093
  -0.06287368]
 [-0.5924146   0.36494392 -0.32361963 ... -0.20976512  0.6094093
  -0.06287368]
 [-0.5924146   0.36494392 -0.32361963 ... -0.20976512  0.6094093
  -0.06287368]
 ...
 [-0.5924146   0.36494392 -0.32361963 ... -0.20976512  0.6094093
  -0.06287368]
 [-0.5924146   0.36494392 -0.32361963 ... -0.20976512  0.6094093
  -0.06287368]
 [-0.5924146   0.36494392 -0.32361963 ... -0.20976512  0.6094093
  -0.06287368]]
