# Tutorial #4: CIFAR10 via AlexNet

In this notebook we will see how to code a predictive coding network (PCN) based on AlexNet and train it on CIFAR10.
We will use PyTorch (and torchvision) to handle the dataset and data loading, so make sure they are installed in the local environment. We only need the CPU version, which can currently be installed via `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu`

In [None]:
from typing import Callable

# Core dependencies
import jax
import jax.numpy as jnp

# pcax
import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.utils as pxu


class AlexNet(pxc.EnergyModule):
    """AlexNet-style predictive coding network for 32x32 RGB images (e.g. CIFAR10).

    The layers are grouped into blocks and the output of every block is fed into a
    ``pxc.Vode``, so each block contributes an independent term to the energy. The
    last vode uses ``pxc.ce_energy`` (cross-entropy) and is marked ``frozen`` so the
    training code can exclude it from the state updates: its value is set to the
    target in ``__call__``.
    """

    def __init__(
        self,
        nm_classes: int,
        act_fn: Callable[[jax.Array], jax.Array]
    ) -> None:
        """
        Args:
            nm_classes: number of output classes (size of the output vode).
            act_fn: activation applied after every hidden layer.
        """
        super().__init__()

        self.nm_classes = nm_classes

        # Note we use a custom activation function and not exclusively ReLU, since
        # ReLU does not seem to perform as well here as it does with backpropagation.
        self.act_fn = px.static(act_fn)

        # We define the convolutional layers, organised in blocks just for clarity.
        # Ideally, pcax will soon support a "pxnn.Sequential" module to ease the definition
        # of such blocks. Layers are based on equinox.nn, so check their documentation for
        # more information.
        self.feature_layers = [
            (
                pxnn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                self.act_fn,
                pxnn.MaxPool2d(kernel_size=2, stride=2)
            ),
            (
                pxnn.Conv2d(64, 192, kernel_size=(3, 3), padding=(1, 1)),
                self.act_fn,
                pxnn.MaxPool2d(kernel_size=2, stride=2)
            ),
            (
                pxnn.Conv2d(192, 384, kernel_size=(3, 3), padding=(1, 1)),
                self.act_fn
            ),
            (
                pxnn.Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1)),
                self.act_fn
            ),
            (
                pxnn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1)),
                self.act_fn,
                pxnn.MaxPool2d(kernel_size=2, stride=2)
            )
        ]
        # We define the classifier layers, organised in blocks just for clarity.
        self.classifier_layers = [
            (
                pxnn.Linear(256 * 2 * 2, 4096),
                self.act_fn
            ),
            (
                pxnn.Linear(4096, 4096),
                self.act_fn
            ),
            (
                pxnn.Linear(4096, self.nm_classes),
            )
        ]

        # We define the Vode modules. Note that currently each vode requires its shape
        # to be manually specified. This will be improved in the near future as lazy
        # initialisation should be possible.
        # (C, H, W) output of each feature block for a (3, 32, 32) input, one per block.
        feature_shapes = [
            (64, 8, 8),
            (192, 4, 4),
            (384, 4, 4),
            (256, 4, 4),
            (256, 2, 2),
        ]
        self.vodes = (
            [pxc.Vode(shape) for shape in feature_shapes]
            # One vode per hidden classifier block...
            + [pxc.Vode((4096,)) for _ in range(len(self.classifier_layers) - 1)]
            # ...and the output vode, scored with cross-entropy.
            + [pxc.Vode((self.nm_classes,), energy_fn=pxc.ce_energy)]
        )

        # Remember 'frozen' is a user specified attribute used later in the gradient function
        self.vodes[-1].h.frozen = True

    def __call__(self, x: jax.Array, y: jax.Array):
        """Forward pass for a single (un-batched) sample.

        Args:
            x: input image; the vode shapes above assume it is (3, 32, 32).
            y: target for the output vode (e.g. a one-hot label), or ``None`` to
                leave the output vode's value untouched.

        Returns:
            The prediction ``u`` of the output vode.
        """
        # Each block is followed by a vode, which splits the computation into
        # independent chunks.
        n_feature_blocks = len(self.feature_layers)
        for block, node in zip(self.feature_layers, self.vodes[:n_feature_blocks]):
            for layer in block:
                x = layer(x)
            x = node(x)

        x = x.flatten()
        for block, node in zip(self.classifier_layers, self.vodes[n_feature_blocks:]):
            for layer in block:
                x = layer(x)
            x = node(x)

        if y is not None:
            # Clamp the output vode to the target.
            self.vodes[-1].set("h", y)

        return self.vodes[-1].get("u")


In [None]:
import torch
import numpy as np

# This is a simple collate function that stacks numpy arrays used to interface
# the PyTorch dataloader with JAX. In the future we hope to provide custom dataloaders
# that are independent of PyTorch.

def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (used instead of torch tensors).

    - A batch of ``np.ndarray`` samples is stacked along a new leading axis.
    - A batch of tuples/lists is regrouped field-wise and each field is collated
      recursively; the result is returned as a list.
    - Anything else (e.g. Python ints) is turned into an array with ``np.array``.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        # [(x0, y0), (x1, y1), ...] -> [collate((x0, x1, ...)), collate((y0, y1, ...))]
        return list(map(numpy_collate, zip(*batch)))
    return np.array(batch)


# The dataloader defaults assume CUDA is being used, so it sets 'pin_memory = True'; it
# also defaults to 'prefetch_factor = 2' and 'persistent_workers = True' (the latter
# requires 'num_workers > 0'). Note that the batch size should be constant during
# training, so we set 'drop_last = True' to avoid having to deal with variable batch sizes.
class TorchDataloader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that yields NumPy batches for JAX.

    Batches are assembled with ``numpy_collate``. Unless a ``batch_sampler`` is
    given, the last incomplete batch is dropped so the batch size stays constant
    during training.

    ``persistent_workers=True`` (the default) requires ``num_workers > 0``; pass
    ``persistent_workers=False`` when loading in the main process.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        batch_sampler=None,
        num_workers=1,
        pin_memory=True,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=True,
        prefetch_factor=2,
    ):
        # Zero-argument super() resolves against this class statically;
        # `super(self.__class__, self)` would recurse forever in any subclass.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            # DataLoader rejects drop_last=True together with a batch_sampler (which
            # already decides the batches), so only drop when batching ourselves.
            drop_last=batch_sampler is None,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )

In [None]:
import torchvision
import torchvision.transforms as transforms

def get_dataloaders(batch_size: int, root: str = "~/tmp/cifar10/", num_workers: int = 4):
    """Create the CIFAR10 train and test dataloaders.

    Images are converted to tensors, normalised per channel and turned into NumPy
    arrays so the batches can be fed directly to JAX. The dataset is downloaded
    into ``root`` if needed.

    Args:
        batch_size: number of samples per batch. ``TorchDataloader`` drops the last
            incomplete batch, so every batch has exactly this size.
        root: directory where CIFAR10 is stored / downloaded.
        num_workers: number of worker processes used by each dataloader.

    Returns:
        ``(train_dataloader, test_dataloader)``; only the training loader shuffles.
    """
    t = transforms.Compose([
        transforms.ToTensor(),
        # Commonly used per-channel (mean, std) normalisation values for CIFAR10.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        lambda x: x.numpy()
    ])

    def make_dataloader(train: bool):
        # Both splits share the same transform (no data augmentation).
        dataset = torchvision.datasets.CIFAR10(
            root,
            transform=t,
            download=True,
            train=train,
        )
        return TorchDataloader(
            dataset,
            batch_size=batch_size,
            shuffle=train,  # shuffle the training split only
            num_workers=num_workers,
        )

    return make_dataloader(train=True), make_dataloader(train=False)

In [None]:
import pcax.functional as pxf

# Training functions are identical to tutorial #0, we only change the model definition.

@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0, 0), out_axes=0)
def forward(x, y, *, model: AlexNet):
    """Batched forward pass: ``model(x, y)`` vmapped over the leading axis of ``x`` and ``y``.

    ``y`` may be ``None`` to run the model without clamping the output vode.
    Returns the output-vode predictions, batched along axis 0.
    """
    prediction = model(x, y)
    return prediction

@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0,), out_axes=(None, 0), axis_name="batch")
def energy(x, *, model: AlexNet):
    """Batched energy of the model, for use with ``value_and_grad(..., has_aux=True)``.

    Returns ``(mean_energy, predictions)``: the total energy of each sample
    averaged across the vmapped "batch" axis (a single unbatched scalar), and the
    per-sample output-vode predictions as auxiliary output.
    """
    prediction = model(x, None)
    sample_energy = model.energy().sum()
    # pmean over the vmapped axis makes the value identical for all samples,
    # which is what out_axes=None requires.
    return jax.lax.pmean(sample_energy, "batch"), prediction

In [None]:
@pxf.jit(static_argnums=0)
def train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    *,
    model: AlexNet,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim
):
    """Run one inference-learning update on a batch.

    The vodes are first initialised by a forward pass with the output vode set to
    ``y``. Then ``T`` inference steps update the non-frozen vode states with
    ``optim_h`` while the weights are held fixed. Finally, a single weight update
    is applied with ``optim_w``.

    Args:
        T: number of inference steps. Static under ``jit`` (the Python loop below is
            unrolled), so every distinct value triggers a new trace.
        x: batch of inputs.
        y: batch of targets (one-hot in this notebook).
        model: the network; its states/weights are changed through the optimisers'
            ``step`` calls.
        optim_w: optimiser over the layer parameters.
        optim_h: optimiser over the vode states.
    """
    model.train()

    # Init step: forward pass in INIT mode; AlexNet.__call__ clamps the output vode to y.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)
    
    # Inference steps: gradients only w.r.t. non-frozen VodeParams (the output vode is
    # frozen), so only the hidden vode states move.
    # NOTE(review): the [False, True] mask presumably selects which arguments (x, model)
    # are differentiated — confirm against pxf.value_and_grad.
    for _ in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            _, g = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True
            )(energy)(x, model=model)
        
        # NOTE(review): the positional `True` is interpreted by pxu.Optim.step — confirm its meaning.
        optim_h.step(model, g["model"], True)

    # Learning step: gradient w.r.t. the layer parameters only, evaluated at the
    # vode states reached after inference.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        _, g = pxf.value_and_grad(pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True)(energy)(x, model=model)
    optim_w.step(model, g["model"])


@pxf.jit()
def eval_on_batch(x: jax.Array, y: jax.Array, *, model: AlexNet):
    """Compute the accuracy of ``model`` on a single batch.

    Runs a forward pass in INIT mode without a target and takes the argmax of the
    output-vode predictions.

    Returns:
        ``(accuracy, predicted_labels)`` where ``accuracy`` is the fraction of
        ``predicted_labels`` equal to the integer labels ``y``.
    """
    model.eval()

    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        outputs = forward(x, None, model=model)
        predicted = outputs.argmax(axis=-1)

    accuracy = (predicted == y).mean()
    return accuracy, predicted

In [None]:
def train(dl, T, *, model: AlexNet, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Train ``model`` for one epoch over ``dl`` with ``T`` inference steps per batch.

    Integer labels are one-hot encoded to ``model.nm_classes`` before each call to
    ``train_on_batch``.
    """
    for inputs, labels in dl:
        targets = jax.nn.one_hot(labels, model.nm_classes)
        train_on_batch(
            T,
            inputs,
            targets,
            model=model,
            optim_w=optim_w,
            optim_h=optim_h,
        )


def eval(dl, *, model: AlexNet):
    """Evaluate ``model`` on every batch of ``dl``.

    The name shadows the builtin ``eval`` in this notebook; it is kept for
    compatibility with existing callers.

    Returns:
        ``(accuracy, predictions)``. ``accuracy`` is the fraction of correctly
        classified samples over everything ``dl`` yields: per-batch accuracies are
        weighted by batch size, so a smaller final batch is not over-weighted.
        ``predictions`` holds the concatenated predicted labels.
    """
    batch_accs = []
    batch_sizes = []
    ys_ = []

    for x, y in dl:
        a, y_ = eval_on_batch(x, y, model=model)
        batch_accs.append(a)
        batch_sizes.append(len(y))
        ys_.append(y_)

    # Weighted mean: identical to np.mean when all batches have the same size.
    return np.average(batch_accs, weights=batch_sizes), np.concatenate(ys_)

In [None]:
import optax

# Hyperparameters. The batch size should stay fixed: the dummy forward pass below
# uses it, and the dataloaders drop the last incomplete batch.
batch_size = 128
nm_epochs = 24

model = AlexNet(
    nm_classes=10,
    act_fn=jax.nn.gelu
)

# Run a dummy forward pass in INIT mode before building the optimisers over the
# model's parameters.
# NOTE(review): presumably this lets the vodes create their state so that the
# VodeParam mask below has something to select — confirm with pcax.
with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(jnp.zeros((batch_size, 3, 32, 32)), None, model=model)
    
    # SGD on the vode states (inference), AdamW on the layer weights (learning).
    optim_h = pxu.Optim(optax.sgd(5e-2), pxu.Mask(pxc.VodeParam)(model))
    optim_w = pxu.Optim(optax.adamw(1e-4), pxu.Mask(pxnn.LayerParam)(model))

In [None]:
train_dataloader, test_dataloader = get_dataloaders(batch_size)

# Inference learning: per batch, 13 inference steps on the vodes followed by a
# single weight update; test accuracy is reported after every epoch.
for e in range(nm_epochs):
    train(
        train_dataloader,
        T=13,
        model=model,
        optim_w=optim_w,
        optim_h=optim_h,
    )
    a, y = eval(test_dataloader, model=model)
    print(f"Epoch # {e + 1}/{nm_epochs} - Test Accuracy: {a * 100:.2f}%")

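If we want to switch the training method to iPC, we can simply perform weight updates at every step, instead of only after the T inference steps. In general, hybrid approaches are possible and can be described in terms of 3 values:
- **T**: number of total iterations;
- **T_0**: number of warmup steps, during which only the vode states are updated;
- **T_i**: interval (in steps) between weight updates after the warmup; 0 and 1 both mean a weight update at every step.

A final weight update is always performed after the T steps. In terms of these variables, *inference learning* would be (T, T, any): the warmup covers every step, so only the final weight update happens. *iPC* would be (T, 0, 0).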

In [None]:
@pxf.jit(static_argnums=(0, 1, 2))
def train_on_batch_hybrid(
    T: int,
    T_0: int,
    T_i: int,
    x: jax.Array,
    y: jax.Array,
    *,
    model: AlexNet,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim
):
    """Train on one batch with a hybrid of inference learning and iPC.

    The vodes are initialised by a forward pass with the output vode set to ``y``.
    At each of the ``T`` steps, the non-frozen vode states are updated with
    ``optim_h``. A step ``i`` also updates the layer weights with ``optim_w``,
    using gradients from the same energy evaluation, when both of these hold:
    ``i >= T_0``, and either ``T_i == 0`` or ``(i - T_0) % T_i == 0``.
    A final weight-only update is always applied after the loop.

    Special cases:
        - ``T_0 >= T``: no weight update inside the loop, i.e. the same update
          schedule as ``train_on_batch``.
        - ``T_0 == 0`` and ``T_i == 0``: weights updated at every step (iPC).
        - ``T_i == 0`` and ``T_i == 1`` behave identically.

    Args:
        T: total number of steps.
        T_0: number of warmup (state-only) steps.
        T_i: interval between weight updates after the warmup (0 = every step).
            T, T_0 and T_i are static under ``jit``; each distinct combination
            triggers a new trace.
        x: batch of inputs.
        y: batch of targets (one-hot in this notebook).
        model: the network; its states/weights are changed through the
            optimisers' ``step`` calls.
        optim_w: optimiser over the layer parameters.
        optim_h: optimiser over the vode states.
    """
    model.train()

    # Init step: forward pass in INIT mode; AlexNet.__call__ clamps the output vode to y.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)
    
    # Hybrid steps
    for i in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            if i >= T_0 and (T_i == 0 or ((i - T_0) % T_i == 0)):
                # gradient with respect to the (non-frozen) states h and the weights
                _, g = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True) | pxu.m(pxnn.LayerParam), [False, True]),
                    has_aux=True
                )(energy)(x, model=model)
                
                optim_w.step(model, g["model"])
            else:
                # gradient only with respect to the (non-frozen) states h
                _, g = pxf.value_and_grad(
                    pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                    has_aux=True
                )(energy)(x, model=model)
        
            # NOTE(review): the positional `True` is interpreted by pxu.Optim.step — confirm its meaning.
            optim_h.step(model, g["model"], True)

    # We always do a final weight update so the last state (h) updates are not wasted.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        _, g = pxf.value_and_grad(pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True)(energy)(x, model=model)
    optim_w.step(model, g["model"])


def train_hybrid(dl, T, T_0, T_i, *, model: AlexNet, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Train ``model`` for one epoch over ``dl`` using the hybrid (T, T_0, T_i) schedule.

    Integer labels are one-hot encoded to ``model.nm_classes`` before each call to
    ``train_on_batch_hybrid``.
    """
    for inputs, labels in dl:
        targets = jax.nn.one_hot(labels, model.nm_classes)
        train_on_batch_hybrid(
            T, T_0, T_i,
            inputs, targets,
            model=model,
            optim_w=optim_w,
            optim_h=optim_h,
        )

Note: the hyperparameters below do not train reliably, because training becomes unstable. However, accuracy improves very quickly in the first epochs!

In [None]:
# Fresh model for the hybrid schedule, this time with leaky ReLU activations.
model = AlexNet(
    nm_classes=10,
    act_fn=jax.nn.leaky_relu
)

# Dummy forward pass in INIT mode before building the optimisers.
# NOTE(review): presumably needed so the vodes have state for the VodeParam mask — confirm.
with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(jnp.zeros((batch_size, 3, 32, 32)), None, model=model)
    
    # Larger state learning rate and smaller weight learning rate than the
    # inference-learning run above.
    optim_h = pxu.Optim(optax.sgd(3e-1), pxu.Mask(pxc.VodeParam)(model))
    optim_w = pxu.Optim(optax.adamw(5e-5), pxu.Mask(pxnn.LayerParam)(model))

train_dataloader, test_dataloader = get_dataloaders(batch_size)

# (T, T_0, T_i) = (11, 1, 0): one state-only warmup step, then a joint state and
# weight update at each of the remaining 10 steps, plus the final weight update.
for e in range(nm_epochs):
    train_hybrid(train_dataloader, T=11, T_0=1, T_i=0, model=model, optim_w=optim_w, optim_h=optim_h)
    a, y = eval(test_dataloader, model=model)
    print(f"Epoch {e + 1}/{nm_epochs} - Test Accuracy: {a * 100:.2f}%")