# Tutorial #5: Decoder-only PCN

In this notebook we will see how to code a decoder-only PCN and train it to generate FashionMNIST images.
We will use PyTorch to handle the dataset and the dataloading, so make sure it is installed in the local environment. We only need the CPU version which, currently, can be installed via `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu`

We also use this tutorial to introduce the concept of a *ruleset*. In pcax, it is not always necessary to inherit from the base class to configure the behaviour of a Vode, as much of it can be configured via the class constructor. In particular, a ruleset allows us to specify the node behaviour for input and output values (i.e., which transformations to apply to incoming activations and outgoing state values). A detailed explanation is provided within the library, in a comment on the ruleset class, so please refer to that.
In summary, we have the following:
- **input rules** are rules that follow the pattern `t1, t2, ... <- v:f1:f2:...`. Every time the activation `v` is set, it is transformed by all `f_i` and then saved to all `t_j` in the cache. The default behaviour is `v <- v` (i.e., a received activation is cached as it is). By default, each node uses the following ruleset: `STATUS.INIT: ("h, u <- u",)`, which specifies that, if the status is set to `pxc.STATUS.INIT`, the incoming activation `u` is saved not only to `u` but also to `h` (setting `h = u`, which is forward initialisation).
- **output rules** are rules that follow the pattern `k -> t:f1:f2:...`. Similarly to input rules, every time the state value `k` is queried, the result of applying each `f_i` to the state value `t` is instead returned. Note that the result is cached, so subsequent calls to the same `k` will reuse the already computed value, unless the cache is cleared (this is why we specify `clear_params=pxc.VodeParam.Cache` in `pxu.step`). This is convenient because, for example, the energy is cached as `"E"` (which is thus a reserved value) and can be queried outside the loss function without performing new computations.

In this example, we use input rules to override the forward initialisation and initialise nodes to 0s instead (as we do not have a value to forward). Similarly, we define a new status `STATUS_FORWARD` that allows us to perform a forward pass by defining the output rule `h -> u` according to which every time we query `h` we get `u` instead (basically ignoring the Vode altogether).
Each set of rules is associated with a status (or several of them) via its key (such as `STATUS_FORWARD: ...`); the key can be a regex to match multiple statuses.

In [None]:
from typing import Callable

# Core dependencies
import jax
import jax.numpy as jnp

# pcax
import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.utils as pxu

# Custom vode status. The Decoder's hidden vodes register the output rule "h -> u" under it,
# so that querying their state 'h' returns the incoming activation 'u' instead; eval_on_batch
# uses it to run a plain feed-forward pass when generating images.
STATUS_FORWARD = "forward"


class Decoder(pxc.EnergyModule):
    """Decoder-only predictive-coding network.

    A stack of ``nm_layers`` linear layers maps a latent vode of size ``input_dim`` to an
    output vode of size ``output_dim``. In between there are ``nm_layers - 1`` hidden vodes
    of size ``hidden_dim``, one after each layer. ``act_fn`` is applied after every layer
    except the last.

    NOTE(review): ``nm_layers`` must be >= 2. With ``nm_layers == 1``, two layers are still
    built but only two vodes exist, so ``__call__`` would index past the end of ``self.vodes``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        act_fn: Callable[[jax.Array], jax.Array],
    ) -> None:
        """Build the layers and vodes.

        Args:
            input_dim: size of the latent (first) vode.
            hidden_dim: size of every hidden vode / hidden layer output.
            output_dim: size of the output vode (e.g. number of pixels).
            nm_layers: number of linear layers (must be >= 2, see class docstring).
            act_fn: activation applied after every layer except the last one.
        """
        super().__init__()

        self.act_fn = px.static(act_fn)

        # input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim
        self.layers = (
            [pxnn.Linear(input_dim, hidden_dim)]
            + [pxnn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers - 2)]
            + [pxnn.Linear(hidden_dim, output_dim)]
        )

        # We initialise the first node to zero: at INIT its rule "h, u <- u:to_zero" stores
        # zeros in both 'h' and 'u' instead of the incoming activation. The zeros are built
        # from the node's own shape because the incoming value is only the dummy '-1' passed
        # in __call__.
        # We use 'zero_energy' as we do not want any prior on the first layer.
        self.vodes = (
            [
                pxc.Vode(
                    (input_dim,),
                    energy_fn=pxc.zero_energy,
                    ruleset={pxc.STATUS.INIT: ("h, u <- u:to_zero",)},
                    tforms={"to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape)},
                )
            ]
            + [
                # The hidden nodes also override the default forward initialisation: at INIT
                # both 'h' and 'u' are set to zeros shaped like the incoming activation.
                # In addition, they enable a "forward mode" (STATUS_FORWARD) in which querying
                # 'h' returns the incoming activation 'u' instead of the node state; this is
                # used during evaluation to generate the decoded output.
                pxc.Vode(
                    (hidden_dim,),
                    ruleset={
                        pxc.STATUS.INIT: ("h, u <- u:to_zero",),
                        STATUS_FORWARD: ("h -> u",)
                    },
                    tforms={"to_zero": lambda n, k, v, rkg: jnp.zeros_like(v)},
                )
                for _ in range(nm_layers - 1)
            ]
            # The output node keeps the default ruleset (forward initialisation, h = u).
            + [pxc.Vode((output_dim,))]
        )
        # Freeze the output state: it holds the data (written in __call__). The inference
        # steps only differentiate vode params that are not frozen.
        self.vodes[-1].h.frozen = True

    def __call__(self, y: jax.Array | None):
        """Run the generative pass from the latent vode to the output vode.

        Args:
            y: target sample for the output vode, or None. When given, it is flattened and
               written to the output vode's 'h'. When None, this method does not write
               'h' directly.

        Returns:
            The output vode's 'u', i.e. the activation produced by the last layer.
        """
        # The defined ruleset for the first node is to set the hidden state to zero,
        # independent of the input, so we always pass a dummy '-1'.
        x = self.vodes[0](-1)
        for i, layer in enumerate(self.layers):
            # Identity (no activation) after the final layer.
            act_fn = self.act_fn if i != len(self.layers) - 1 else lambda x: x
            x = act_fn(layer(x))
            x = self.vodes[i + 1](x)

        if y is not None:
            self.vodes[-1].set("h", y.flatten())

        return self.vodes[-1].get("u")

In [None]:
import torch
import numpy as np


# Collate function that builds NumPy batches, so the PyTorch DataLoader can feed JAX
# directly without going through torch tensors. In the future we hope to provide custom
# dataloaders that are independent of PyTorch.
def numpy_collate(batch):
    """Combine a list of samples into a batch of NumPy arrays.

    - NumPy arrays are stacked along a new leading axis.
    - Tuples/lists (e.g. ``(image, label)`` pairs) are collated field by field,
      recursively, and returned as a list with one entry per field.
    - Anything else (e.g. integer labels) is converted with ``np.array``.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if not isinstance(head, (tuple, list)):
        return np.array(batch)
    return [numpy_collate(field) for field in zip(*batch)]


# A torch DataLoader that yields NumPy batches (via 'numpy_collate') so they can be fed
# straight into JAX. The defaults 'pin_memory = True', 'prefetch_factor = 2' and persistent
# workers are meant for a CUDA host with worker processes. Note that the batch size should
# be constant during training (the jitted steps and the optimiser state are built for a
# fixed batch shape), so 'drop_last' is enabled whenever the DataLoader builds the batches
# itself.
class TorchDataloader(torch.utils.data.DataLoader):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        batch_sampler=None,
        num_workers=1,
        pin_memory=True,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=True,
        prefetch_factor=2,
    ):
        """Create the loader; arguments have the same meaning as in ``torch.utils.data.DataLoader``.

        ``collate_fn`` is fixed to ``numpy_collate``. ``drop_last`` is True unless a
        ``batch_sampler`` is supplied, since torch forbids combining the two.
        ``persistent_workers`` and ``prefetch_factor`` only apply to worker processes. They
        are ignored when ``num_workers == 0`` instead of making torch raise a ValueError.
        """
        # Worker-only options: torch rejects persistent_workers=True (and, in recent
        # versions, any explicit prefetch_factor) when num_workers == 0, so only forward
        # them when worker processes are actually used.
        worker_kwargs = {}
        if num_workers > 0:
            worker_kwargs["persistent_workers"] = persistent_workers
            worker_kwargs["prefetch_factor"] = prefetch_factor

        # Zero-argument super(): 'super(self.__class__, self)' recurses forever as soon as
        # this class is subclassed.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            # A user-supplied batch_sampler is mutually exclusive with drop_last.
            drop_last=batch_sampler is None,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            **worker_kwargs,
        )

In [None]:
import torchvision
import torchvision.transforms as transforms


def get_dataloaders(batch_size: int, root: str = "~/tmp/fashion-mnist/"):
    """Build the FashionMNIST train and test dataloaders.

    Images are converted with ``transforms.ToTensor`` and then to NumPy arrays, so the
    batches can be consumed by JAX.

    Args:
        batch_size: number of samples per batch.
        root: directory where FashionMNIST is stored (downloaded there if missing).

    Returns:
        ``(train_dataloader, test_dataloader)``; only the training loader is shuffled.
    """
    import operator

    # A methodcaller, unlike a lambda defined inside this function, can be pickled. This is
    # needed when DataLoader workers are started with the "spawn" method (macOS / Windows).
    t = transforms.Compose(
        [
            transforms.ToTensor(),
            operator.methodcaller("numpy"),
        ]
    )

    def _make_loader(train: bool) -> TorchDataloader:
        # Build the dataset for one split and wrap it in a TorchDataloader.
        dataset = torchvision.datasets.FashionMNIST(
            root,
            transform=t,
            download=True,
            train=train,
        )
        # Shuffle the training split only; the test order stays deterministic.
        return TorchDataloader(
            dataset,
            batch_size=batch_size,
            shuffle=train,
            num_workers=4,
        )

    train_dataloader = _make_loader(train=True)
    test_dataloader = _make_loader(train=False)

    return train_dataloader, test_dataloader

In [None]:
import pcax.functional as pxf


# Batched forward pass: vmaps `model(x)` over the leading axis of `x` (in_axes=0) and stacks
# the outputs along axis 0.
# NOTE(review): the Mask presumably maps vode params and their caches along axis 0 (one state
# per sample) while broadcasting everything else (e.g. layer weights); confirm against the
# pcax `Mask` / `vmap` documentation.
@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=0, out_axes=0)
def forward(x, *, model: Decoder):
    """Run the decoder on every sample of `x` (or with no target at all when `x` is None)."""
    return model(x)


# Batched energy used as the objective of the inference and learning steps.
# The model is run without a target (the output vode's 'h' is not rewritten) over the mapped
# axis named "batch". out_axes=(None, 0) returns the energy as a single value shared across
# the batch, and the predictions stacked along axis 0.
@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), out_axes=(None, 0), axis_name="batch")
def energy(*, model: Decoder):
    """Return `(mean_energy, predictions)`.

    Each sample's energy is summed, then averaged over the batch with `pmean`.
    """
    y_ = model(None)
    return jax.lax.pmean(model.energy().sum(), "batch"), y_

In [None]:
@pxf.jit(static_argnums=0)
def train_on_batch(T: int, x: jax.Array, *, model: Decoder, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Perform one predictive-coding training step on a batch of images.

    1. INIT forward pass: the latent and hidden vodes are reset to zeros and the output
       vode is set to ``x``.
    2. ``T`` inference steps: gradient updates of the non-frozen vode states with ``optim_h``.
    3. One learning step: gradient update of the layer weights with ``optim_w``.

    ``T`` is a static argument (``static_argnums=0``), so the Python inference loop is
    unrolled at trace time and each distinct ``T`` triggers a new compilation.

    Args:
        T: number of inference steps.
        x: batch of images (each sample is flattened inside the model).
        model: the Decoder, updated in place.
        optim_w: optimiser over the layer parameters.
        optim_h: optimiser over the vode parameters; re-initialised on every call.
    """
    model.train()

    # Gradients of `energy` w.r.t. the vode params that are not frozen (i.e. not the output
    # vode's 'h'). has_aux=True because `energy` also returns the predictions.
    inference_step = pxf.value_and_grad(pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]), has_aux=True)(
        energy
    )

    # Gradients of `energy` w.r.t. the layer weights only.
    learning_step = pxf.value_and_grad(pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True)(energy)

    # Init step: reset the vode states for this batch and write `x` into the output vode.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, model=model)

    # Re-initialise the inference optimiser for the freshly initialised vode states.
    optim_h.init(pxu.Mask(pxc.VodeParam)(model))

    # Inference steps (the cache is cleared every step so cached values such as the energy
    # are recomputed from the updated states).
    for _ in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            _, g = inference_step(model=model)

        # NOTE(review): the meaning of the positional `True` is defined by pcax's
        # `Optim.step`; confirm there.
        optim_h.step(model, g["model"], True)

    # Learning step
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        _, g = learning_step(model=model)
    optim_w.step(model, g["model"])


@pxf.jit(static_argnums=0)
def eval_on_batch(T: int, x: jax.Array, *, model: Decoder, optim_h: pxu.Optim):
    """Reconstruct a batch of images by inference and score the reconstruction.

    1. INIT forward pass: the latent and hidden vodes are reset to zeros and the output
       vode is set to ``x``.
    2. ``T`` inference steps update the non-frozen vode states; the weights are not touched.
    3. A STATUS_FORWARD pass: the hidden vodes return their incoming activations, so the
       output is the feed-forward decoding of the inferred latent state.

    ``T`` is static (``static_argnums=0``): the loop is unrolled and each distinct ``T``
    recompiles.

    Returns:
        ``(loss, x_hat)``. ``loss`` is the mean squared error between ``x_hat`` clipped to
        [0, 1] and ``x``. ``x_hat`` is the unclipped reconstruction, batched along axis 0.
    """
    model.eval()

    # Gradients w.r.t. the non-frozen vode params only (no learning during evaluation).
    inference_step = pxf.value_and_grad(pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]), has_aux=True)(
        energy
    )

    # Init step
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, model=model)

    # Re-initialise the inference optimiser for the freshly initialised vode states.
    optim_h.init(pxu.Mask(pxc.VodeParam)(model))

    # Inference steps
    for _ in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            _, g = inference_step(model=model)

        optim_h.step(model, g["model"], True)

    # Generate the reconstruction by bypassing the hidden vode states (see STATUS_FORWARD).
    with pxu.step(model, STATUS_FORWARD, clear_params=pxc.VodeParam.Cache):
        x_hat = forward(None, model=model)

    # The clip assumes the pixel targets lie in [0, 1].
    l = jnp.square(jnp.clip(x_hat.flatten(), 0.0, 1.0) - x.flatten()).mean()

    return l, x_hat

In [None]:
def train(dl, T, *, model: Decoder, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Run one training epoch: a ``train_on_batch`` call for every batch of ``dl``.

    Each batch must be an ``(images, labels)`` pair; the labels are not used.
    """
    for images, _labels in dl:
        train_on_batch(T, images, model=model, optim_w=optim_w, optim_h=optim_h)


def eval(dl, T, *, model: Decoder, optim_h: pxu.Optim):
    """Return the mean reconstruction loss over every batch of ``dl``.

    The name shadows the builtin ``eval``; it is kept for compatibility with existing callers.

    Args:
        dl: iterable of ``(images, labels)`` batches; the labels are not used.
        T: number of inference steps per batch.
        model: the Decoder (its weights are not updated here).
        optim_h: optimiser used for the per-batch inference.

    Returns:
        The mean of the per-batch losses returned by ``eval_on_batch``. When all batches
        have the same size (e.g. with ``drop_last``), this equals the mean over all
        evaluated samples.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    losses = []

    for x, _ in dl:
        loss, _x_hat = eval_on_batch(T, x, model=model, optim_h=optim_h)
        losses.append(loss)

    if not losses:
        raise ValueError("eval: the dataloader yielded no batches")

    # Average over all batches (not just the last one).
    return np.mean(losses)


In [None]:
import optax

batch_size = 128
nm_epochs = 24

# 64-dim latent decoded through 4 linear layers (tanh after all but the last; three
# 256-unit hidden vodes) into 28 * 28 pixel predictions.
model = Decoder(input_dim=64, hidden_dim=256, output_dim=28 * 28, nm_layers=4, act_fn=jax.nn.tanh)

# Run one INIT forward pass on a dummy batch so that the vode states exist (with their
# batched shapes, presumably) before the optimisers are created over them.
with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(jnp.zeros((batch_size, 28 * 28)), model=model)

    # SGD with momentum on the vode states (inference), AdamW on the layer weights (learning).
    optim_h = pxu.Optim(optax.sgd(5e-2, momentum=0.1), pxu.Mask(pxc.VodeParam)(model))
    optim_w = pxu.Optim(optax.adamw(1e-4), pxu.Mask(pxnn.LayerParam)(model))


In [None]:
train_dataloader, test_dataloader = get_dataloaders(batch_size)

# Each epoch: one training pass over the training set, then the test reconstruction loss.
# Both use T=20 inference steps per batch.
for e in range(nm_epochs):
    train(train_dataloader, T=20, model=model, optim_w=optim_w, optim_h=optim_h)
    l = eval(test_dataloader, T=20, model=model, optim_h=optim_h)
    print(f"Epoch {e + 1}/{nm_epochs} - Test Loss: {l:.4f}")