In [1]:
from functools import partial
import numpy as np
from PIL import Image
import jax.numpy as jnp
import optax
from flax import nnx
from datasets import load_dataset

In [2]:
rngs = nnx.Rngs(0)



Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



I0000 00:00:1737258191.872177 24777688 service.cc:145] XLA service 0x131a4ecb0 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737258191.872191 24777688 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737258191.873634 24777688 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737258191.873651 24777688 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.


In [3]:
def transform(x):
    x = x.reshape(-1, 3, 32, 32)
    x = [[Image.fromarray(z).resize((224, 224)) for z in y] for y in x]
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    x = x.reshape(-1, 224, 224, 3)
    return x

In [4]:
dataset = load_dataset("cifar10")

X_train = np.array([np.array(image) for image in dataset["train"]["img"]]) / 255.0 - 0.5
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = np.array([np.array(image) for image in dataset["test"]["img"]]) / 255.0 - 0.5
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [5]:
class BasicBlock:
    def __init__(self, inplanes, planes, strides=(1, 1), downsample=None):
        self.conv1 = nnx.Conv(inplanes, planes, kernel_size=(3, 3), strides=strides, padding=(1, 1), use_bias=False, rngs=rngs)
        self.bn1 = nnx.BatchNorm(num_features=planes, rngs=rngs)
        self.conv2 = nnx.Conv(planes, planes, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), use_bias=False, rngs=rngs)
        self.bn2 = nnx.BatchNorm(num_features=planes, rngs=rngs)
        self.downsample = downsample

    def __call__(self, x):
        out = nnx.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            x = self.downsample(x)
        out = nnx.relu(out + x)
        return out

In [6]:
class ResNet(nnx.Module):
    def __init__(self):
        self.inplanes = 64
        self.conv1 = nnx.Conv(3, 64, kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), use_bias=False, rngs=rngs)
        self.bn1 = nnx.BatchNorm(num_features=64, rngs=rngs)
        self.max_pool = partial(nnx.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
        self.layer1 = self._make_layer(64, 3, strides=(1, 1))
        self.layer2 = self._make_layer(128, 4, strides=(2, 2))
        self.layer3 = self._make_layer(256, 6, strides=(2, 2))
        self.layer4 = self._make_layer(512, 3, strides=(2, 2))
        self.avg_pool = partial(nnx.avg_pool, window_shape=(7, 7), strides=(1, 1))
        self.linear = nnx.Linear(512, 10, rngs=rngs)

    def _make_layer(self, planes, blocks, strides=(1, 1)):
        downsample = None
        if strides != (1, 1) or self.inplanes != planes:
            downsample = nnx.Sequential(
                nnx.Conv(self.inplanes, planes, kernel_size=(1, 1), strides=strides, use_bias=False, rngs=rngs),
                nnx.BatchNorm(num_features=planes, rngs=rngs),
            )
        layers = [BasicBlock(self.inplanes, planes, strides=strides, downsample=downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nnx.Sequential(*layers)

    def __call__(self, x):
        x = nnx.relu(self.bn1(self.conv1(x)))
        x = self.max_pool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avg_pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.linear(x)
        return x

In [7]:
model = ResNet()
y = model(jnp.ones((1, 224, 224, 3)))
nnx.display(y)

[[-0.19926558  0.10369247 -0.41314387 -0.70792234 -0.1057477   1.4667993
  -0.0979929  -0.6685898  -1.5630112   0.17691338]]


In [8]:
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.sgd(learning_rate, momentum))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"))

In [9]:
def loss_fn(model, images, labels):
    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    return loss, logits

@nnx.jit
def train_step(model, optimizer, metrics, images, labels):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(model, metrics, images, labels):
    loss, logits = loss_fn(model, images, labels)
    metrics.update(loss=loss, logits=logits, labels=labels)

In [10]:
epochs = 10
batch_size = 32
eval_every = len(X_train)
train_steps = len(X_train) // batch_size
test_steps = len(X_test) // batch_size
metrics_history = {"train_loss": [], "train_accuracy": [], "test_loss": [], "test_accuracy": []}

for epoch in range(1, epochs + 1):
    for step in range(train_steps):
        sample = np.random.randint(0, len(X_train), size=batch_size)
        images, labels = transform(X_train[sample]), Y_train[sample]
        train_step(model, optimizer, metrics, images, labels)

        if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
            for metric, value in metrics.compute().items():
                metrics_history[f"train_{metric}"].append(value)
            metrics.reset()

            for test_step in range(test_steps):
                images = transform(X_test[batch_size*test_step:batch_size*(test_step+1)])
                labels = Y_test[batch_size*test_step:batch_size*(test_step+1)]
                eval_step(model, metrics, images, labels)

            for metric, value in metrics.compute().items():
                metrics_history[f"test_{metric}"].append(value)
            metrics.reset()

            print(
                f"[train] epoch: {epoch}, step: {step}, "
                f"loss: {metrics_history['train_loss'][-1]:.4f}, "
                f"accuracy: {metrics_history['train_accuracy'][-1] * 100:.2f}"
            )
            print(
                f"[test] epoch: {epoch}, step: {step}, "
                f"loss: {metrics_history['test_loss'][-1]:.4f}, "
                f"accuracy: {metrics_history['test_accuracy'][-1] * 100:.2f}"
            )
    learning_rate *= 0.8
    optimizer = nnx.Optimizer(model, optax.sgd(learning_rate, momentum))

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[64] wrapped in a JVPTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError