In [9]:
from iresnet.datasets import get_cifar10_data
from iresnet.models import resnet_cifar10
import jax
import equinox as eqx


In [10]:
# Hyperparameters

BATCH_SIZE = 2
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

In [11]:
trn, tst = get_cifar10_data(BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
dummy_x, dummy_y = next(iter(trn))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  
print(dummy_y.shape)  
print(dummy_y)

(2, 3, 32, 32)
(2,)
[2 7]


In [13]:
key, subkey = jax.random.split(key, 2)
model = resnet_cifar10(subkey)

In [14]:
from jaxtyping import Array, Int, Float
import jax.numpy as jnp

def loss(
    model: resnet_cifar10, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    print(x.shape)
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions

(2, 3, 32, 32)
Inside model shape:(3, 32, 32)
()
Inside model shape:(3, 32, 32)
(2, 10)


In [15]:
# This will work too!
value, grads = eqx.filter_value_and_grad(loss)(model, dummy_x, dummy_y)
print(value)

(2, 3, 32, 32)
Inside model shape:(3, 32, 32)
-0.3145994


In [16]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: resnet_cifar10, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)



In [17]:
import torch
def evaluate(model: resnet_cifar10, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [18]:
evaluate(model, tst)

(2, 3, 32, 32)
Inside model shape:(3, 32, 32)
Inside model shape:(3, 32, 32)


(Array(-0.05568978, dtype=float32), Array(0.1, dtype=float32))