In [163]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping

In [164]:
# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

In [165]:
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    "CIFAR10",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.CIFAR10(
    "CIFAR10",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

Files already downloaded and verified
Files already downloaded and verified


In [166]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)

(64, 3, 32, 32)
(64,)
[9 9 0 9 8 1 7 1 2 5 0 4 7 4 9 7 9 8 7 6 0 7 7 5 9 7 3 9 6 5 9 7 2 6 9 2 0
 9 9 8 6 6 6 5 3 0 8 9 7 9 6 9 5 1 0 4 4 5 3 8 7 2 9 6]


In [167]:
class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            jnp.ravel,
            eqx.nn.Linear(3072, 1500, key=key1),
            jax.nn.relu,
            eqx.nn.Linear(1500, 200, key=key2),
            jax.nn.relu,
            eqx.nn.Linear(200, 10, key=key3),
            jax.nn.softmax
        ]

    def __call__(self, x: Float[Array, "3 32 32"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


key, subkey = jax.random.split(key, 2)
model = MLP(subkey)

In [168]:
print(model)

MLP(
  layers=[
    <wrapped function ravel>,
    Linear(
      weight=f32[1500,3072],
      bias=f32[1500],
      in_features=3072,
      out_features=1500,
      use_bias=True
    ),
    <wrapped function relu>,
    Linear(
      weight=f32[200,1500],
      bias=f32[200],
      in_features=1500,
      out_features=200,
      use_bias=True
    ),
    <wrapped function relu>,
    Linear(
      weight=f32[10,200],
      bias=f32[10],
      in_features=200,
      out_features=10,
      use_bias=True
    ),
    <function softmax>
  ]
)


In [169]:
@eqx.filter_jit
def loss(
    model: MLP, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"], lam: Int[Array, ""]) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    reg = 0
    for layer in model.layers:
        if isinstance(layer, eqx.nn.Linear):
            reg += lam * (jnp.sum(layer.weight ** 2) + jnp.sum(layer.bias ** 2))
    return cross_entropy(y, pred_y) + reg

@eqx.filter_jit
def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)

In [170]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: MLP, x: Float[Array, "batch 3 32 32"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)

In [171]:
def evaluate(model: MLP, testloader: torch.utils.data.DataLoader, lam: int):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y, lam)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

In [172]:
optim = optax.adamw(LEARNING_RATE)

In [173]:
def train(
    model: CNN,
    lam: int,
    trainloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
) -> MLP:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: MLP,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y, lam)
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
    return model

In [174]:
results = {}
for lam in [10**(-k) for k in range(6, 10)]:
    results[lam] = []
    for trial in range(5):
        # create new model
        key, subkey = jax.random.split(key, 2)
        model = MLP(subkey)
        
        # train & evaluate model
        model = train(model, lam, trainloader, optim, STEPS)
        test_loss, test_acc = evaluate(model, testloader, lam)
        results[lam].append(test_acc)
        
        print(f'lambda: {lam}, test_acc: {test_acc} [trial [{trial}] of 5]')

lambda: 1e-06, test_acc: 0.35250794887542725 [trial [0] of 5]
lambda: 1e-06, test_acc: 0.36634156107902527 [trial [1] of 5]
lambda: 1e-06, test_acc: 0.35798168182373047 [trial [2] of 5]
lambda: 1e-06, test_acc: 0.3364848792552948 [trial [3] of 5]
lambda: 1e-06, test_acc: 0.35280653834342957 [trial [4] of 5]
lambda: 1e-07, test_acc: 0.36514729261398315 [trial [0] of 5]
lambda: 1e-07, test_acc: 0.3498208522796631 [trial [1] of 5]
lambda: 1e-07, test_acc: 0.3543988764286041 [trial [2] of 5]
lambda: 1e-07, test_acc: 0.34016719460487366 [trial [3] of 5]
lambda: 1e-07, test_acc: 0.35061705112457275 [trial [4] of 5]
lambda: 1e-08, test_acc: 0.33399680256843567 [trial [0] of 5]
lambda: 1e-08, test_acc: 0.3418590724468231 [trial [1] of 5]
lambda: 1e-08, test_acc: 0.36813294887542725 [trial [2] of 5]
lambda: 1e-08, test_acc: 0.35867834091186523 [trial [3] of 5]
lambda: 1e-08, test_acc: 0.3280254900455475 [trial [4] of 5]
lambda: 1e-09, test_acc: 0.3460390269756317 [trial [0] of 5]
lambda: 1e-09,

In [175]:
model = train(model, trainloader, testloader, optim, STEPS, PRINT_EVERY)

TypeError: train() takes 5 positional arguments but 6 were given