In [19]:
from iresnet.datasets import get_cifar10_data
from iresnet.models import resnet_cifar10
import jax
import equinox as eqx
import optax
from jaxtyping import Array, Float, Int, PyTree

In [14]:
# Hyperparameters

BATCH_SIZE = 2
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

In [3]:
trn, tst = get_cifar10_data(BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
dummy_x, dummy_y = next(iter(trn))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  
print(dummy_y.shape)  
print(dummy_y)

(2, 3, 32, 32)
(2,)
[6 3]


In [5]:
key, subkey = jax.random.split(key, 2)
model = resnet_cifar10(subkey)

In [11]:
from jaxtyping import Array, Int, Float
import jax.numpy as jnp

def loss(
    model: resnet_cifar10, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions

()
(2, 10)


In [12]:
# This will work too!
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

0.091591276


In [8]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: resnet_cifar10, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)



In [9]:
import torch
def evaluate(model: resnet_cifar10, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [15]:
evaluate(model, tst)

(Array(-0.06289963, dtype=float32), Array(0.1, dtype=float32))

In [16]:
optim = optax.adamw(LEARNING_RATE)

In [17]:
def train(
    model: resnet_cifar10,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> resnet_cifar10:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: resnet_cifar10,
        opt_state: PyTree,
        x: Float[Array, "batch 3 32 32"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model

In [20]:
model = train(model, trn, tst, optim, STEPS, PRINT_EVERY)

step=0, train_loss=0.15214696526527405, test_loss=-0.07269386947154999, test_accuracy=0.10019999742507935
