In [1]:
from functools import partial
from PIL import Image
import numpy as np
import jax.numpy as jnp
import optax
from flax import nnx
from datasets import load_dataset

In [2]:
rngs = nnx.Rngs(0)

In [3]:
def transform(x):
    x = x.reshape(-1, 3, 32, 32)
    x = [[Image.fromarray(z).resize((224, 224)) for z in y] for y in x]
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    x = x.reshape(-1, 224, 224, 3)
    return x

In [4]:
dataset = load_dataset("cifar10")

X_train = (np.array([np.array(image) for image in dataset["train"]["img"]]) / 255.0 - 0.5) / 0.25
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = (np.array([np.array(image) for image in dataset["test"]["img"]]) / 255.0 - 0.5) / 0.25
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [5]:
class AlexNet(nnx.Module):
    def __init__(self, *, rngs):
        self.conv1 = nnx.Conv(3, 64, kernel_size=(11, 11), strides=(4, 4), padding=(2, 2), rngs=rngs)
        self.max_pool1 = partial(nnx.max_pool, window_shape=(3, 3), strides=(2, 2))
        self.conv2 = nnx.Conv(64, 192, kernel_size=(5, 5), padding=(2, 2), rngs=rngs)
        self.max_pool2 = partial(nnx.max_pool, window_shape=(2, 2), strides=(2, 2))
        self.conv3 = nnx.Conv(192, 384, kernel_size=(3, 3), padding=(1, 1), rngs=rngs)
        self.conv4 = nnx.Conv(384, 256, kernel_size=(3, 3), padding=(1, 1), rngs=rngs)
        self.conv5 = nnx.Conv(256, 256, kernel_size=(3, 3), padding=(1, 1), rngs=rngs)
        self.max_pool3 = partial(nnx.max_pool, window_shape=(3, 3), strides=(2, 2))
        self.avg_pool = partial(nnx.avg_pool, window_shape=(1, 1), strides=(1, 1))
        self.dropout1 = nnx.Dropout(0.5, rngs=rngs)
        self.l1 = nnx.Linear(9216, 4096, rngs=rngs)
        self.dropout2 = nnx.Dropout(0.5, rngs=rngs)
        self.l2 = nnx.Linear(4096, 4096, rngs=rngs)
        self.l3 = nnx.Linear(4096, 10, rngs=rngs)

    def __call__(self, x):
        x = self.max_pool1(nnx.relu(self.conv1(x)))
        x = self.max_pool2(nnx.relu(self.conv2(x)))
        x = nnx.relu(self.conv3(x))
        x = nnx.relu(self.conv4(x))
        x = self.max_pool3(nnx.relu(self.conv5(x)))
        x = self.avg_pool(x)
        x = x.reshape(x.shape[0], -1)
        x = nnx.relu(self.l1(self.dropout1(x)))
        x = nnx.relu(self.l2(self.dropout2(x)))
        x = self.l3(x)
        return x

In [6]:
model = AlexNet(rngs=rngs)
y = model(jnp.ones((1, 224, 224, 3)))
nnx.display(y)

[[-0.03867229 -0.3541497   0.01953721  0.11465882  0.13008535 -0.06420223
   0.05067527  0.01634647  0.17827238  0.13162051]]


In [7]:
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.sgd(learning_rate, momentum))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"))

In [8]:
def loss_fn(model, images, labels):
    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    return loss, logits

@nnx.jit
def train_step(model, optimizer, metrics, images, labels):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(model, metrics, images, labels):
    loss, logits = loss_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)

In [9]:
epochs = 10
batch_size = 128
eval_every = len(X_train)
train_steps = len(X_train) // batch_size
test_steps = len(X_test) // batch_size
metrics_history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

for epoch in range(1, epochs + 1):
    for step in range(train_steps):
        sample = np.random.randint(0, len(X_train), size=batch_size)
        images, labels = transform(X_train[sample]), Y_train[sample]
        train_step(model, optimizer, metrics, images, labels)

        if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
            for metric, value in metrics.compute().items():
                metrics_history[f"train_{metric}"].append(value)
            metrics.reset()

            for test_step in range(test_steps):
                images = transform(X_test[batch_size*test_step:batch_size*(test_step+1)])
                labels = Y_test[batch_size*test_step:batch_size*(test_step+1)]
                eval_step(model, metrics, images, labels)

            for metric, value in metrics.compute().items():
                metrics_history[f"test_{metric}"].append(value)
            metrics.reset()

            print(
                f"[train] epoch: {epoch}, step: {step}, "
                f"loss: {metrics_history['train_loss'][-1]:.4f}, "
                f"accuracy: {metrics_history['train_accuracy'][-1] * 100:.2f}"
            )
            print(
                f"[test] epoch: {epoch}, step: {step}, "
                f"loss: {metrics_history['test_loss'][-1]:.4f}, "
                f"accuracy: {metrics_history['test_accuracy'][-1] * 100:.2f}"
            )
    learning_rate *= 0.75
    optimizer = nnx.Optimizer(model, optax.sgd(learning_rate, momentum))

[train] epoch: 1, step: 389, loss: 1.9145, accuracy: 29.80
[test] epoch: 1, step: 389, loss: 1.6382, accuracy: 39.56
[train] epoch: 2, step: 389, loss: 1.5316, accuracy: 44.16
[test] epoch: 2, step: 389, loss: 1.4437, accuracy: 47.78
[train] epoch: 3, step: 389, loss: 1.3592, accuracy: 51.30
[test] epoch: 3, step: 389, loss: 1.2978, accuracy: 53.48
[train] epoch: 4, step: 389, loss: 1.2669, accuracy: 54.59
[test] epoch: 4, step: 389, loss: 1.2917, accuracy: 53.56
[train] epoch: 5, step: 389, loss: 1.1853, accuracy: 57.37
[test] epoch: 5, step: 389, loss: 1.2657, accuracy: 54.58
[train] epoch: 6, step: 389, loss: 1.1194, accuracy: 60.11
[test] epoch: 6, step: 389, loss: 1.1502, accuracy: 58.44
[train] epoch: 7, step: 389, loss: 1.0787, accuracy: 61.75
[test] epoch: 7, step: 389, loss: 1.1314, accuracy: 59.54
[train] epoch: 8, step: 389, loss: 1.0294, accuracy: 63.55
[test] epoch: 8, step: 389, loss: 1.1108, accuracy: 60.46
[train] epoch: 9, step: 389, loss: 1.0123, accuracy: 64.00
[test