In [1]:
from iresnet.datasets import get_cifar10_data
from iresnet.models import ResNet18
import jax
import equinox as eqx
import optax
from jaxtyping import Array, Float, Int, PyTree

In [2]:
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 1e-4
STEPS = 1200
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

In [3]:
trn, tst = get_cifar10_data(BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trn))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # batch_size x3x32x32
print(dummy_y.shape)  # batch_size
print(dummy_y)

(64, 3, 32, 32)
(64,)
[8 5 6 1 8 3 8 4 8 7 2 0 4 3 1 7 9 5 4 7 3 5 0 6 0 1 8 5 3 0 1 5 7 6 6 3 4
 2 2 9 7 3 8 6 0 3 7 8 2 6 7 1 0 9 1 6 3 0 7 6 4 1 3 5]


In [5]:
key, subkey = jax.random.split(key, 2)
model = ResNet18(subkey)

In [6]:
print(model)

ResNet18(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[64,3,3,3],
      bias=None,
      in_channels=3,
      out_channels=64,
      kernel_size=(3, 3),
      stride=(1, 1),
      padding=((1, 1), (1, 1)),
      dilation=(1, 1),
      groups=1,
      use_bias=False
    ),
    BatchNorm(
      weight=f32[64],
      bias=f32[64],
      first_time_index=StateIndex(inference=False),
      state_index=StateIndex(inference=False),
      axis_name='batch',
      inference=False,
      input_size=64,
      eps=1e-05,
      channelwise_affine=True,
      momentum=0.99
    ),
    <wrapped function relu>,
    BasicBlock(
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[64,64,3,3],
          bias=None,
          in_channels=64,
          out_channels=64,
          kernel_size=(3, 3),
          stride=(1, 1),
          padding=((1, 1), (1, 1)),
          dilation=(1, 1),
          groups=1,
          use_bias=False
        ),
        BatchNorm(


In [7]:
from jaxtyping import Array, Int, Float
import jax.numpy as jnp

def loss(
    model: ResNet18, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    pred_y = jax.vmap(model, axis_name="batch")(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


In [8]:
%load_ext autoreload
%autoreload 2

In [9]:
# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model,axis_name="batch")(dummy_x)
print(output.shape)  # batch of predictions

()
(64, 10)


In [10]:
output

Array([[-2.9789512, -2.3950043, -2.6886563, -1.8656045, -2.5833943,
        -2.6215875, -1.9210961, -1.9728549, -2.5671916, -2.0814779],
       [-3.008628 , -1.9291335, -3.0460534, -2.3002076, -2.5005283,
        -2.6957893, -1.9287262, -1.9577394, -2.4861603, -1.9782414],
       [-3.0528016, -2.137925 , -2.7659683, -2.175425 , -2.5703459,
        -2.5074096, -1.970276 , -1.9773171, -2.6465983, -1.8861097],
       [-3.0155644, -2.2839663, -2.5423312, -1.9006605, -2.805281 ,
        -2.6580024, -1.9196169, -1.9611559, -2.7647502, -1.9659599],
       [-2.8039777, -2.1381135, -3.162402 , -2.4216886, -2.884741 ,
        -2.4300232, -1.8191922, -1.98104  , -2.613413 , -1.7533762],
       [-4.1119623, -2.4638362, -3.9997034, -2.830008 , -3.0697064,
        -2.1560962, -1.780485 , -1.8142817, -1.8747239, -1.7483344],
       [-3.367105 , -2.0736187, -3.204104 , -2.4452202, -2.9308858,
        -2.5664139, -1.8493356, -1.9590516, -2.1562192, -1.7824687],
       [-3.200592 , -2.3197198, -2.788048

In [11]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: ResNet18, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model,axis_name="batch")(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)



In [12]:
import torch
def evaluate(model: ResNet18, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [13]:
evaluate(model, tst)

(Array(2.3932571, dtype=float32), Array(0.09136146, dtype=float32))

In [14]:
optim = optax.adamw(LEARNING_RATE)

In [15]:
def train(
    model: ResNet18,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> ResNet18:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: ResNet18,
        opt_state: PyTree,
        x: Float[Array, "batch 3 32 32"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model

In [16]:
model = train(model, trn, tst, optim, STEPS, PRINT_EVERY)

step=0, train_loss=2.3915891647338867, test_loss=2.5397748947143555, test_accuracy=0.0895700678229332
step=30, train_loss=1.985224962234497, test_loss=2.014582395553589, test_accuracy=0.2395501583814621
step=60, train_loss=2.0089426040649414, test_loss=1.833449125289917, test_accuracy=0.3064291477203369
step=90, train_loss=1.9649834632873535, test_loss=1.7725344896316528, test_accuracy=0.3427547812461853
step=120, train_loss=1.8716309070587158, test_loss=1.7114099264144897, test_accuracy=0.3677348792552948
step=150, train_loss=1.7450339794158936, test_loss=1.6816104650497437, test_accuracy=0.378085196018219
step=180, train_loss=1.627855658531189, test_loss=1.6423696279525757, test_accuracy=0.39649683237075806
step=210, train_loss=1.6442811489105225, test_loss=1.6072242259979248, test_accuracy=0.3958996832370758
step=240, train_loss=1.6049822568893433, test_loss=1.573119044303894, test_accuracy=0.4198845624923706
step=270, train_loss=1.5509659051895142, test_loss=1.5599132776260376, tes

KeyboardInterrupt: 