In [10]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import torch
import torchvision
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree

from linax.architecture.base import AbstractModel
from linax.architecture.linoss import LinossModel, LinossModelConfig

In [11]:
# Hyperparameters

BATCH_SIZE = 10
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 1  # evaluate on the full test set every PRINT_EVERY training steps
SEED = 5678

# Seed both RNG sources used below: torch (DataLoader shuffling) and JAX (model
# init and per-step keys, all split from `key`). Without the torch seed the
# training batch order differs on every run.
torch.manual_seed(SEED)
key = jax.random.PRNGKey(SEED)

In [None]:
# TODO: replace this pixel-by-pixel MNIST with the MNIST sequence data from https://edwin-de-jong.github.io/blog/mnist-sequence-data/

In [12]:
MNIST_ROOT = "MNIST"  # download/cache directory for the dataset, relative to the CWD

# Turn each image into a sequence: ToTensor -> [0, 1], Normalize(0.5, 0.5) -> [-1, 1],
# then flatten to shape (784, 1), i.e. 784 time steps of one scalar pixel each.
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
        torchvision.transforms.Lambda(lambda x: x.view(-1, 1)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    MNIST_ROOT,
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    MNIST_ROOT,
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation visits every test example regardless of order, so shuffling only adds
# nondeterminism; iterate the test set in a fixed order.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
# Sanity-check one training batch: inputs are pixel sequences, labels are digit classes.
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # (BATCH_SIZE, 784, 1)
print(dummy_y.shape)  # (BATCH_SIZE,)
print(dummy_y)

# Every example is a length-784 sequence of 1-d pixels, with one label per example.
assert dummy_x.shape[1:] == (784, 1)
assert dummy_y.shape == dummy_x.shape[:1]

(10, 784, 1)
(10,)
[0 7 0 9 5 3 5 0 2 1]


In [14]:
# Build the LinOSS classifier: one scalar input per time step (a pixel),
# ten output classes, hidden width 16.
key, subkey = jax.random.split(key)
model_cfg = LinossModelConfig(hidden_dim=16)
model = LinossModel(
    in_features=1,
    out_features=10,
    cfg=model_cfg,
    key=subkey,
)
# Mutable layer state (e.g. the BatchNorm running statistics) is kept outside the model.
state = eqx.nn.State(model)

In [15]:
# Show the model's module tree with parameter shapes.
print(str(model))

LinossModel(
  linear_encoder=Linear(
    weight=f32[16,1], bias=None, in_features=1, out_features=16, use_bias=False
  ),
  linear_decoder=Linear(
    weight=f32[10,16],
    bias=None,
    in_features=16,
    out_features=10,
    use_bias=False
  ),
  blocks=[
    LinossEncoderBlock(
      norm=BatchNorm(
        weight=None,
        bias=None,
        ema_first_time_index=StateIndex(
          marker=<object object at 0x1357dfe70>, init=bool[]
        ),
        ema_state_index=StateIndex(
          marker=<object object at 0x1357dfeb0>, init=(f32[16], f32[16])
        ),
        batch_counter=None,
        batch_state_index=None,
        axis_name='batch',
        inference=False,
        input_size=16,
        eps=1e-05,
        channelwise_affine=False,
        momentum=0.99,
        mode='ema'
      ),
      sequence_mixer=LinOSSSequenceMixer(
        A_diag=f32[16],
        G_diag=f32[16],
        B=f32[16,16,2],
        C=f32[16,16,2],
        D=f32[16],
        steps=f32[16],


In [16]:
def loss(
    model: AbstractModel,
    x: Float[Array, "batch 784 1"],
    y: Int[Array, " batch"],
    state: eqx.nn.State,
    key: PRNGKeyArray,
) -> tuple[Float[Array, ""], eqx.nn.State]:
    """Mean cross-entropy of `model` on a batch, together with the updated layer state.

    The model is applied per example with `jax.vmap` under the axis name "batch",
    so layers that reduce over that axis (BatchNorm here) see the whole batch.

    Args:
        model: callable as `model(x_i, state, key_i) -> (pred_i, state)`.
        x: batch of input sequences.
        y: integer class labels, one per example.
        state: layer state shared by all examples.
        key: PRNG key; split into one key per example.

    Returns:
        `(loss, new_state)` -- a scalar loss and the updated state. The tuple form is
        what `eqx.filter_value_and_grad(..., has_aux=True)` expects.
    """
    keys = jax.random.split(key, x.shape[0])
    pred_y, model_state = jax.vmap(
        model,
        axis_name="batch",
        in_axes=(0, None, 0),  # batch inputs and keys; broadcast the single state
        out_axes=(0, None),  # state is identical across the batch, so return it unbatched
    )(x, state, keys)
    return cross_entropy(y, pred_y), model_state


def cross_entropy(y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]) -> Float[Array, ""]:
    """Negative mean of each example's score at its true label.

    Args:
        y: integer class labels, one per example.
        pred_y: per-class scores for each example. These are treated as
            log-probabilities (no softmax is applied here).
            NOTE(review): confirm the model's output is log-softmaxed.

    Returns:
        Scalar loss averaged over the batch.
    """
    label_columns = y[:, None]
    picked = jnp.take_along_axis(pred_y, label_columns, axis=1)
    return -picked.mean()


# Smoke-test the (not yet jitted) loss on the sample batch: it should be a scalar.
loss_value = loss(model, dummy_x, dummy_y, state, key)[0]
print(loss_value.shape)  # () -> scalar
print(loss_value)  # scalar loss value

()
2.494089


In [17]:
# Differentiate w.r.t. the model's array leaves only. `has_aux=True` passes the
# updated layer state through alongside the loss without differentiating it.
value_and_grad_fn = eqx.filter_value_and_grad(loss, has_aux=True)
(loss_value, new_state), grads = value_and_grad_fn(model, dummy_x, dummy_y, state, key)

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


In [18]:
# Should match the loss printed by the example cell above.
print(str(loss_value))

2.494089


In [19]:
# Compile the loss from earlier. `filter_jit` traces array leaves and treats
# everything else as static.
loss = eqx.filter_jit(loss)


@eqx.filter_jit
def compute_accuracy(
    model: AbstractModel,
    x: Float[Array, "batch 784 1"],
    y: Int[Array, " batch"],
    state: eqx.nn.State,
    key: PRNGKeyArray,
) -> Float[Array, ""]:
    """Fraction of examples whose highest-scoring class equals the label.

    The model is vmapped exactly as in `loss` (axis name "batch", shared state);
    the returned state is discarded.
    """
    batched_model = jax.vmap(
        model,
        axis_name="batch",
        in_axes=(0, None, 0),
        out_axes=(0, None),
    )
    per_example_keys = jax.random.split(key, x.shape[0])
    scores, _ = batched_model(x, state, per_example_keys)
    predicted = jnp.argmax(scores, axis=1)
    return jnp.mean(predicted == y)

In [20]:
def evaluate(
    model: AbstractModel,
    testloader: torch.utils.data.DataLoader,
    state: eqx.nn.State,
    key: PRNGKeyArray,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """Evaluates the model on the test dataset.

    The model is switched to inference mode (`eqx.tree_inference`) and run over every
    batch of `testloader` with the given `state` and `key`.

    Returns:
        `(mean_loss, mean_accuracy)` averaged over examples. Each batch is weighted by
        its size, so a smaller final batch does not skew the result.

    Raises:
        ValueError: if `testloader` yields no batches.
    """
    inference_model = eqx.tree_inference(model, value=True)
    total_loss = 0.0
    total_acc = 0.0
    num_examples = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        batch_size = x.shape[0]
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        total_loss += batch_size * loss(inference_model, x, y, state, key)[0]
        total_acc += batch_size * compute_accuracy(inference_model, x, y, state, key)
        num_examples += batch_size
    if num_examples == 0:
        raise ValueError("testloader yielded no batches")
    return total_loss / num_examples, total_acc / num_examples

In [21]:
# AdamW (Adam with decoupled weight decay); everything except the learning rate uses optax defaults.
optim = optax.adamw(learning_rate=LEARNING_RATE)

In [None]:
def train(
    model: AbstractModel,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
    state: eqx.nn.State,
    key: PRNGKeyArray,
    return_state: bool = False,
) -> "AbstractModel | tuple[AbstractModel, eqx.nn.State]":
    """Trains the model on the training dataset.

    Runs `steps` optimiser updates, cycling through `trainloader` as often as needed.
    Every `print_every` steps, and on the final step, it evaluates on `testloader` and
    prints the train loss, test loss and test accuracy.

    Args:
        model: model to train.
        trainloader: source of training batches (torch tensors).
        testloader: source of evaluation batches (torch tensors).
        optim: optax optimiser applied to the model's array leaves.
        steps: number of optimiser updates.
        print_every: evaluation/print interval in steps; must be >= 1.
        state: initial layer state (e.g. BatchNorm statistics). It is threaded through
            every step, so each update sees the state produced by the previous one.
        key: PRNG key. A fresh subkey is split off for every training step and every
            evaluation.
        return_state: if True, return `(model, state)` so the final layer state can be
            reused for later evaluation.

    Returns:
        The trained model, or `(model, state)` when `return_state` is True.

    Raises:
        ValueError: if `print_every < 1` or `trainloader` yields no batches.
    """
    if print_every < 1:
        raise ValueError(f"print_every must be >= 1, got {print_every}")

    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: AbstractModel,
        opt_state: PyTree,
        x: Float[Array, "batch 784 1"],
        y: Int[Array, " batch"],
        state: eqx.nn.State,
        key: PRNGKeyArray,
    ):
        (loss_value, new_state), grads = eqx.filter_value_and_grad(loss, has_aux=True)(
            model, x, y, state, key
        )
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value, new_state

    # Loop over our training dataset as many times as we need. Fail loudly on an
    # empty loader instead of spinning forever.
    def infinite_trainloader():
        while True:
            yielded_any = False
            for batch in trainloader:
                yielded_any = True
                yield batch
            if not yielded_any:
                raise ValueError("trainloader yielded no batches")

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        # Fresh randomness for every step; reusing one key would repeat the same draws.
        key, train_key = jax.random.split(key)
        # Carry the updated layer state forward so BatchNorm statistics accumulate.
        model, opt_state, train_loss, state = make_step(
            model, opt_state, x, y, state, train_key
        )
        if (step % print_every) == 0 or (step == steps - 1):
            key, eval_key = jax.random.split(key)
            test_loss, test_accuracy = evaluate(model, testloader, state, eval_key)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    if return_state:
        return model, state
    return model


# Run the training loop; it prints progress every PRINT_EVERY steps.
model = train(
    model,
    trainloader,
    testloader,
    optim,
    steps=STEPS,
    print_every=PRINT_EVERY,
    state=state,
    key=key,
)

step=0, train_loss=2.4805493354797363, test_loss=2.401088237762451, test_accuracy=0.09579956531524658
