In [1]:
from functools import partial
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from tqdm import trange
from datasets import load_dataset

In [2]:
# One seed drives both NumPy's global generator (used for batch sampling) and
# the Flax NNX generator (used for parameter initialisation).
SEED = 0

np.random.seed(SEED)
rngs = nnx.Rngs(SEED)

In [3]:
def _split_to_arrays(split):
    """Convert one dataset split into NumPy arrays.

    Args:
        split: a mapping with an "image" column (iterable of array-convertible
            images) and a "label" column (iterable of integer class ids).

    Returns:
        A tuple ``(images, labels)``: ``images`` is float32 with a trailing
        axis of size 1 appended and values divided by 255.0 (assumes 8-bit
        pixel data); ``labels`` is int32.
    """
    images = np.array([np.array(image) for image in split["image"]], dtype=np.float32)
    images = np.expand_dims(images, -1) / 255.0
    labels = np.array(split["label"], dtype=np.int32)
    return images, labels


dataset = load_dataset("mnist")

X_train, Y_train = _split_to_arrays(dataset["train"])
X_test, Y_test = _split_to_arrays(dataset["test"])

# Cheap invariants: every image has a label, and the per-image shape is the
# 28x28x1 layout that the rest of the notebook feeds to the model.
assert len(X_train) == len(Y_train) and len(X_test) == len(Y_test)
assert X_train.shape[1:] == (28, 28, 1) and X_test.shape[1:] == (28, 28, 1)

In [4]:
class ConvNet(nnx.Module):
    """Small convolutional classifier producing 10 logits per input.

    The forward pass is conv -> ReLU -> 2x2 average pool, applied twice,
    followed by a flatten and two linear layers.
    """

    def __init__(self, *, rngs: nnx.Rngs):
        """Create the layers; ``rngs`` is passed to every parameterised layer."""
        # Convolutional feature extractor: 1 -> 32 -> 64 channels, 3x3 kernels.
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        # Parameter-free 2x2 pooling with stride 2, reused after each conv.
        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        # Classifier head. 3136 == 7 * 7 * 64, which presumably corresponds to
        # a 28x28 input halved twice by the pooling above.
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)

    def __call__(self, x):
        """Return the (unnormalised) logits for a batch ``x``."""
        features = self.avg_pool(nnx.relu(self.conv1(x)))
        features = self.avg_pool(nnx.relu(self.conv2(features)))
        # Collapse everything except the leading (batch) axis.
        flat = features.reshape(features.shape[0], -1)
        hidden = nnx.relu(self.linear1(flat))
        return self.linear2(hidden)

In [5]:
# Build the network and push a single all-ones 28x28x1 image through it as a
# smoke test of the layer shapes.
model = ConvNet(rngs=rngs)
dummy_batch = jnp.ones((1, 28, 28, 1))
y = model(dummy_batch)
nnx.display(y)

[[-0.06820839 -0.14743432  0.00265857 -0.2173656   0.16673787 -0.00923921
  -0.06636689  0.28341877  0.33754364 -0.20142877]]


In [6]:
learning_rate = 0.005
momentum = 0.9

# The second positional parameter of optax.adamw is `b1`; pass it by keyword
# so it is obvious which hyperparameter `momentum` controls.
tx = optax.adamw(learning_rate, b1=momentum)
optimizer = nnx.Optimizer(model, tx)

# Running accuracy and mean loss, accumulated across update() calls.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)

In [7]:
def loss_fn(model: ConvNet, images, labels):
    """Compute the mean softmax cross-entropy of ``model`` on one batch.

    Args:
        model: the network to evaluate.
        images: batch of inputs passed straight to ``model``.
        labels: integer class ids, one per input.

    Returns:
        ``(loss, logits)`` where ``loss`` is the batch mean and ``logits`` is
        the raw model output (returned as auxiliary data for metrics).
    """
    logits = model(images)
    per_example = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    )
    return per_example.mean(), logits

@nnx.jit
def train_step(model: ConvNet, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, images, labels):
    """Run one optimisation step on a batch.

    Computes the loss and its gradients, folds the batch into ``metrics``,
    and hands the gradients to ``optimizer``. Returns nothing; all effects are
    updates to the passed-in objects.
    """
    # has_aux=True: loss_fn returns (loss, logits) and only the loss is differentiated.
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(model: ConvNet, metrics: nnx.MultiMetric, images, labels):
    """Evaluate one batch and fold its loss and predictions into ``metrics``.

    No gradients are computed and the model is not updated.
    """
    batch_loss, batch_logits = loss_fn(model, images, labels)
    metrics.update(loss=batch_loss, logits=batch_logits, labels=labels)

In [8]:
batch_size = 32
eval_every = 200
# Training batches are sampled with replacement below, so this is a step
# budget (roughly one pass over the data), not an exact epoch length.
train_steps = len(X_train) // batch_size + 1
# Ceiling division. `len(X_test) // batch_size + 1` produces an extra, empty
# batch whenever the test set size is an exact multiple of batch_size, and the
# mean loss of an empty batch is NaN, which would poison the averaged test loss.
test_steps = -(-len(X_test) // batch_size)

metrics_history = {
    "train_loss": [],
    "train_accuracy": [],
    "test_loss": [],
    "test_accuracy": [],
}


def _record_metrics(split):
    """Append the currently accumulated metrics to the `split` history, then reset them."""
    for metric, value in metrics.compute().items():
        metrics_history[f"{split}_{metric}"].append(value)
    metrics.reset()


for step in range(train_steps):
    # Draw a random mini-batch (indices may repeat within and across batches).
    sample = np.random.randint(0, len(X_train), size=batch_size)
    images, labels = X_train[sample], Y_train[sample]
    train_step(model, optimizer, metrics, images, labels)

    # Evaluate every `eval_every` steps and once more on the final step.
    if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
        # Training metrics accumulated since the previous evaluation.
        _record_metrics("train")

        # Full sequential pass over the test set; the last batch may be short.
        for test_step in range(test_steps):
            batch = slice(batch_size * test_step, batch_size * (test_step + 1))
            eval_step(model, metrics, X_test[batch], Y_test[batch])
        _record_metrics("test")

        for split in ("train", "test"):
            last_loss = metrics_history[f"{split}_loss"][-1]
            last_accuracy = metrics_history[f"{split}_accuracy"][-1]
            print(
                f"[{split}] step: {step}, "
                f"loss: {last_loss:.4f}, "
                f"accuracy: {last_accuracy * 100:.2f}"
            )

[train] step: 200, loss: 0.3459, accuracy: 89.43
[test] step: 200, loss: 0.1246, accuracy: 95.92
[train] step: 400, loss: 0.1196, accuracy: 96.38
[test] step: 400, loss: 0.1015, accuracy: 96.58
[train] step: 600, loss: 0.0824, accuracy: 97.22
[test] step: 600, loss: 0.0968, accuracy: 96.68
[train] step: 800, loss: 0.0797, accuracy: 97.69
[test] step: 800, loss: 0.0577, accuracy: 98.08
[train] step: 1000, loss: 0.0673, accuracy: 98.23
[test] step: 1000, loss: 0.0665, accuracy: 97.84
[train] step: 1200, loss: 0.0664, accuracy: 97.98
[test] step: 1200, loss: 0.0492, accuracy: 98.31
[train] step: 1400, loss: 0.0591, accuracy: 98.17
[test] step: 1400, loss: 0.0457, accuracy: 98.50
[train] step: 1600, loss: 0.0510, accuracy: 98.45
[test] step: 1600, loss: 0.0886, accuracy: 97.06
[train] step: 1800, loss: 0.0504, accuracy: 98.47
[test] step: 1800, loss: 0.0458, accuracy: 98.56
[train] step: 1875, loss: 0.0577, accuracy: 98.54
[test] step: 1875, loss: 0.0453, accuracy: 98.68
