In [1]:
from iresnet.datasets import get_cifar10_data
from iresnet.models import ResNet18
import jax
import equinox as eqx
import optax
from jaxtyping import Array, Float, Int, PyTree

In [2]:
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

In [3]:
trn, tst = get_cifar10_data(BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trn))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # batch_size x3x32x32
print(dummy_y.shape)  # batch_size
print(dummy_y)

(64, 3, 32, 32)
(64,)
[3 3 2 3 3 9 5 1 3 2 5 0 0 5 0 2 5 3 5 6 8 2 1 5 1 2 8 1 9 5 1 8 3 6 1 3 0
 1 8 4 5 7 0 5 6 4 1 5 5 1 4 6 8 7 6 1 2 3 1 0 7 2 2 6]


In [5]:
key, subkey = jax.random.split(key, 2)
model = ResNet18(subkey)

In [6]:
print(model)

ResNet18(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[64,3,3,3],
      bias=None,
      in_channels=3,
      out_channels=64,
      kernel_size=(3, 3),
      stride=(1, 1),
      padding=((1, 1), (1, 1)),
      dilation=(1, 1),
      groups=1,
      use_bias=False
    ),
    BatchNorm(
      weight=f32[64],
      bias=f32[64],
      first_time_index=StateIndex(inference=False),
      state_index=StateIndex(inference=False),
      axis_name='batch',
      inference=False,
      input_size=64,
      eps=1e-05,
      channelwise_affine=True,
      momentum=0.99
    ),
    <wrapped function relu>,
    BasicBlock(
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[64,64,3,3],
          bias=None,
          in_channels=64,
          out_channels=64,
          kernel_size=(3, 3),
          stride=(1, 1),
          padding=((1, 1), (1, 1)),
          dilation=(1, 1),
          groups=1,
          use_bias=False
        ),
        BatchNorm(


In [7]:
from jaxtyping import Array, Int, Float
import jax.numpy as jnp

def loss(
    model: ResNet18, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    pred_y = jax.vmap(model, axis_name="batch")(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


In [8]:
%load_ext autoreload
%autoreload 2

In [10]:
# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model,axis_name="batch")(dummy_x)
print(output.shape)  # batch of predictions

Output Shape of layer 0 is:  (64, 32, 32)
Output Shape of layer 1 is:  (64, 32, 32)
Output Shape of layer 2 is:  (64, 32, 32)
Output Shape of layer 3 is:  (64, 32, 32)
Output Shape of layer 4 is:  (64, 32, 32)
Output Shape of layer 5 is:  (128, 16, 16)
Output Shape of layer 6 is:  (128, 16, 16)
Output Shape of layer 7 is:  (256, 8, 8)
Output Shape of layer 8 is:  (256, 8, 8)
Output Shape of layer 9 is:  (512, 4, 4)
Output Shape of layer 10 is:  (512, 4, 4)
Output Shape of layer 11 is:  (512, 1, 1)
Output Shape of layer 12 is:  (512,)
Output Shape of layer 13 is:  (10,)
Output Shape of layer 14 is:  (10,)
()
Output Shape of layer 0 is:  (64, 32, 32)
Output Shape of layer 1 is:  (64, 32, 32)
Output Shape of layer 2 is:  (64, 32, 32)
Output Shape of layer 3 is:  (64, 32, 32)
Output Shape of layer 4 is:  (64, 32, 32)
Output Shape of layer 5 is:  (128, 16, 16)
Output Shape of layer 6 is:  (128, 16, 16)
Output Shape of layer 7 is:  (256, 8, 8)
Output Shape of layer 8 is:  (256, 8, 8)
Output 

In [11]:
output

Array([[-2.9125946, -2.2175357, -2.737707 , -1.9967215, -2.8336525,
        -2.352201 , -1.7477696, -2.2363877, -2.7961183, -1.9641101],
       [-3.1072884, -2.4489298, -2.8389645, -2.0237842, -2.628903 ,
        -2.5898547, -1.8109677, -1.8311937, -2.636527 , -1.9970496],
       [-3.2473755, -2.1117609, -2.8919764, -2.3016863, -2.9551249,
        -2.5209174, -1.6660374, -2.2128782, -2.4742818, -1.7750113],
       [-3.1889958, -2.0593648, -3.19628  , -2.2887795, -2.629171 ,
        -2.446034 , -1.7082007, -1.9799814, -2.5837007, -1.9994087],
       [-2.903036 , -1.9820261, -2.9480648, -2.1370368, -2.778913 ,
        -2.6690097, -2.0153184, -1.9490082, -2.4834828, -1.921846 ],
       [-2.8478565, -2.1506553, -2.743957 , -2.0174518, -3.0017042,
        -2.6968443, -1.7695409, -2.0603173, -2.6106029, -1.9669756],
       [-3.0763302, -2.34106  , -2.9499195, -2.191332 , -2.5909436,
        -2.5179925, -1.7239666, -1.842611 , -2.5679805, -2.0833452],
       [-3.3212318, -2.0852141, -2.788743

In [None]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: ResNet18, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model,axis_name="batch")(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)



In [None]:
import torch
def evaluate(model: ResNet18, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [None]:
evaluate(model, tst)

In [None]:
optim = optax.adamw(LEARNING_RATE)

In [None]:
def train(
    model: ResNet18,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> ResNet18:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: ResNet18,
        opt_state: PyTree,
        x: Float[Array, "batch 3 32 32"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model

In [None]:
model = train(model, trn, tst, optim, STEPS, PRINT_EVERY)

In [None]:
(20 - 2) // 6

3