# Tutorial 0: **pcax**
In this notebook, you will learn how to use pcax to build arbitrary predictive coding networks.
Since the library is still in its early development, expect major syntax changes. You can keep coming back to this notebook to stay updated.

Good luck!

## Part 0: Importing dependencies
Pcax is based on jax and optax, so we will need those.
The library is divided into four modules:
- *core*: defines the basic building blocks of pcax, unrelated to predictive coding itself. You normally will not need to touch this unless you need some rather custom behaviour.
- *pc*: here lies all the predictive coding implementation, which will probably keep changing and being updated.
- *nn*: simply contains the typical layers you could expect from a deep learning library, which are currently built as wrappers around *equinox* layers.
- *utils*: various optional tools.

Furthermore, pcax requires cuda>=11.4 and cudnn >= 8.4. On pssr2 they can be activated with the terminal command `source switch-cuda 11.7`.
Unfortunately it seems to be necessary to repeat this for each jupyter cell (normally you just call it once when you create the terminal you'll work in).

In [1]:
!source switch-cuda 11.7

# Core dependencies
import jax
import optax

# pcax
import pcax as px # same as import pcax.pc as px
import pcax.nn as nn
from pcax.core import _ # _ is the filter object, more about it later!
from pcax.utils.data import TorchDataloader

Switched to CUDA 11.7.


In this example, we will also use the following dependencies:

In [2]:
!source switch-cuda 11.7

import numpy as np
from torchvision.datasets import MNIST
import timeit
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" # remember this line if you are sharing the GPU with others

Switched to CUDA 11.7.


## Part 1: Defining a Model
Defining a basic pcax model is very straightforward: in the forward call, simply place a `px.Layer` between any two `nn.Link`s (our name for standard deep learning layers). To do so, first define them in `__init__`; no arguments are required for basic usage!

In `__init__`, we also define the activation function we are going to use and specify that we do not want to train the `x` of the last pc layer, since it will contain the label we want to train our network with (this is probably how you will structure most of the PCNs you will use). Note that this model does not directly support inference on the inputs, as there is no pc layer that stores the input `x` of the forward call.

In [3]:
!source switch-cuda 11.7

"""
A model is defined in a way that is similar to PyTorch. In particular:
- we inherit from px.Module;
- we define the activation functions, layers (i.e., pc layers) and links (i.e., standard layers), in the __init__ method; 
- ModuleList is necessary to define a list of layers/links;
- we define the forward pass in the __call__ method.
"""
class Model(px.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, nm_layers=2) -> None:
        super().__init__()

        self.act_fn = jax.nn.tanh

        """
        This is quite standard. We define the layers and links (one layer for each link).
        """
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear_h = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers)]
        )
        self.linear2 = nn.Linear(hidden_dim, output_dim)

        self.pc1 = px.Layer()
        self.pc_h = nn.ModuleList([px.Layer() for _ in range(nm_layers)])
        self.pc2 = px.Layer()

        """
        We normally use the x of the last layer as the target, therefore we don't want to update it.
        """
        self.pc2.x.frozen = True

    """
    Here things are a bit different. __call__ accepts an optional target t (used during training),
    which is used to set the x of the last layer.
    """
    def __call__(self, x, t=None):
        """
        !!! IMPORTANT !!!
        Each (pc) layer contains a cache that stores the important intermediate values computed in the forward pass.
        By default, these are the incoming activation (u), the node values (x) and the energy (e).
        You can access them by using the [] operator, e.g., self.pc["x"].
        """
        x = self.pc1(self.act_fn(self.linear1(x)))["x"]

        for i in range(len(self.linear_h)):
            x = self.pc_h[i](self.act_fn(self.linear_h[i](x)))["x"]

        x = self.pc2(self.linear2(x))["x"]

        if t is not None:
            self.pc2["x"] = t

        """
        The output of the network is the activation received by the last layer (since its x is clamped to the label).
        """
        return self.pc2["u"]

Switched to CUDA 11.7.
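
To make the contents of that cache concrete, here is a small pcax-free sketch in plain JAX of what a single pc layer keeps track of. It assumes the squared-error energy `e = 0.5 * ||x - u||^2` and forward initialization (`x` starts equal to `u`); the 0.5 factor and the helper `toy_pc_layer` are assumptions made for illustration, not pcax code.

```python
import jax.numpy as jnp


def toy_pc_layer(u, x=None):
    """Hypothetical stand-in for a pc layer: returns a cache with u, x and e."""
    if x is None:
        # Forward initialization: the node value starts at the incoming activation.
        x = u
    # Squared-error energy between the node value and its prediction (assumed form).
    e = 0.5 * jnp.sum((x - u) ** 2)
    return {"u": u, "x": x, "e": e}


u = jnp.array([0.2, -0.1, 0.7])

# Free node: x is initialized to u, so the energy is zero.
cache = toy_pc_layer(u)

# Clamped node (like the last layer holding a one-hot label): the energy is non-zero.
clamped = toy_pc_layer(u, x=jnp.array([0.0, 0.0, 1.0]))

print(cache["e"], clamped["e"])
```

This is why, in the model above, the prediction is read from `u` of the last layer: once `x` is clamped to the label, `u` is the only place where the network's own output survives.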


Let's now define the training parameters we are going to use:

In [4]:
params = {
    "batch_size": 256,
    "x_learning_rate": 0.01,
    "w_learning_rate": 1e-3,
    "num_epochs": 4,
    "hidden_dim": 128,
    "input_dim": 28 * 28,
    "output_dim": 10,
    "seed": 0,
    "T": 4,
}

Let's now define the dataloaders we'll need to train and test our model.

In [5]:
!source switch-cuda 11.7


"""
This is all standard and uses PyTorch's datasets and dataloaders. We are assuming cuda is available to set the dataloaders' parameters.
"""


"""
We'll train with the standard pc energy function, that is, the sum of the squared differences between the node values and the target.
Therefore, we need to convert the targets to one-hot vectors.
"""
def one_hot(t, k):
    return np.array(t[:, None] == np.arange(k), dtype=np.float32)


"""
Used to convert the square 0-255 images to 0-1 float vectors.
"""
class FlattenAndCast:
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32) / 255.0)


train_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=True,
)
train_dataloader = TorchDataloader(
    train_dataset,
    batch_size=params["batch_size"],
    num_workers=16,
    shuffle=True,
    persistent_workers=True,
    pin_memory=True,
)


test_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=False,
)
test_dataloader = TorchDataloader(
    test_dataset,
    batch_size=params["batch_size"],
    num_workers=4,
    shuffle=False,
    persistent_workers=False,
    pin_memory=True,
)

Switched to CUDA 11.7.
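
If you want to double check the two helpers without downloading MNIST, the following NumPy-only sketch applies the same operations to made-up data. The labels and the 2x2 "image" are invented for the example.

```python
import numpy as np


def one_hot(t, k):
    # Same definition as in the cell above.
    return np.array(t[:, None] == np.arange(k), dtype=np.float32)


labels = np.array([3, 0, 9])
encoded = one_hot(labels, 10)

# A fake 2x2 "image" with 0-255 pixel values, flattened and rescaled
# with the same expression used inside FlattenAndCast.
pic = np.array([[0, 255], [51, 102]], dtype=np.uint8)
flat = np.ravel(np.array(pic, dtype=np.float32) / 255.0)

print(encoded.shape, flat.shape)
```

Each row of `encoded` contains a single 1 at the index of the label, and `flat` is a 1D float vector with values in [0, 1].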


## Part 2: Where the fun begins
As mentioned, pcax is based on jax, which is a functional framework. Consequently, in order to offer a simple object oriented interface to it, there are some compromises to be made and strict patterns to follow.

(Note: pcax.core is basically a heavily modified version of *objax*, so if you're lost you may find it helpful to also look at that documentation.)

First of all, we **have** to instantiate the model. This is because pcax relies on binding global variables to the local scope of each function in order to work. So we need to have a model to begin with, and its structure should (of course there are exceptions) be fixed for the rest of the program.

In [6]:
model = Model(28 * 28, params["hidden_dim"], 10)

### Part 2.1: Defining a function
Let's see how a function involving a `px.Module` is defined. We'll proceed step by step (so only the last version is the correct one).

Consider the function `predict(x, t)`, which simply calls the forward pass of the model (we need to pass the target `t` as well since it is needed during training; unfortunately, bound functions do not support default arguments for now, so you'll have to manually pass `None` if `t` is not necessary).
In normal settings, `predict` should look similar to the following:

In [7]:
def predict(x, t):
    return model(x, t)

**However**, since this function uses a `px.Module` (i.e., `model`), we need to *bind* it to the function (it cannot be passed as an argument and therefore needs to be in the global scope, which is why we defined it above). There's not much to add, it needs to be done :P
The syntax is the following:

In [8]:
@px.bind(model)
def predict(x, t):
    return model(x, t)

The next step is **batching** (aka **vectorization**). Jax (and therefore pcax) defines each computation on a tensor assuming it contains a single sample (that is, there is no batch dimension, unlike, for instance, in PyTorch). Therefore, in the function above, `x` is supposed to be an array with shape [28x28,] (remember we are flattening the MNIST images inside the dataloader) and `t` an array with shape [10,] (remember we are one-hot encoding it). If we want `predict` to work on batched input (i.e., `x` with shape [n, 28x28]), we need to *vectorize* the function so that the `predict` computation is repeated along the batch dimension.

We will use another decorator to achieve this:

In [9]:
@px.vectorize(_(px.NodeVar), in_axis=(0, 0))
@px.bind(model)
def predict(x, t):
    return model(x, t)

Let's ignore the first argument for a second and focus on `in_axis`: it tells the function that both its inputs are batched along the 0th dimension, which means that `predict` now expects two arrays with shape [n, *] as inputs (in this case [n, 784] and [n, 10]). As for the inputs, we have to specify whether the outputs of the function have an added "batch" dimension. In this case, `model` outputs a prediction for each `x`, so if `x` is batched, the output will be too. By default, `px.vectorize` assumes the decorated function has a single output to be batched, so we do not have to specify anything else for now.
If an input does not have a batch dimension but is, instead, a constant value to be used for each `x`, you can pass `None` instead of `0` for that particular parameter.
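
The same idea can be tried in plain JAX with `jax.vmap`, which is a rough analogue of what is described here (note that JAX spells the argument `in_axes`). This is only a sketch: `predict_single` and the toy weight matrix are made up, and the weights are passed explicitly instead of living in a bound module.

```python
import jax
import jax.numpy as jnp


def predict_single(x, w):
    # Written for ONE sample: x has shape (4,), w has shape (3, 4).
    return jnp.tanh(w @ x)


# Batch x along axis 0; share the same w across the whole batch (None).
predict_batched = jax.vmap(predict_single, in_axes=(0, None))

w = jnp.ones((3, 4)) * 0.1
xs = jnp.arange(8.0).reshape(2, 4)  # a batch of two samples

ys = predict_batched(xs, w)
print(ys.shape)  # (2, 3)
```

Each row of `ys` is exactly what `predict_single` returns for the corresponding row of `xs`; the function body never had to mention the batch dimension.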

The first argument of `px.vectorize` specifies which tensors of the modules bound to the function (in this case only `model`) need to have a batch dimension as well. In predictive coding, those tensors are the ones representing the value nodes, so in 99% of cases you will end up using the same first parameter for `px.vectorize`. Recall that `_` is the filter object: `_(px.NodeVar)` simply selects all the `px.NodeVar` parameters inside the bound modules. Conveniently, `px.NodeVar`s are exclusively the tensors used to represent the node values.

Try it yourself:



In [10]:
model.vars(_(px.NodeVar))

{'(Model).pc1(Layer).x': pcax.pc.variables.NodeVar(None, reduce=reduce_id),
 '(Model).pc_h(ModuleList)[0](Layer).x': pcax.pc.variables.NodeVar(None, reduce=reduce_id),
 '(Model).pc_h(ModuleList)[1](Layer).x': pcax.pc.variables.NodeVar(None, reduce=reduce_id),
 '(Model).pc2(Layer).x': pcax.pc.variables.NodeVar(None, reduce=reduce_id)}

If instead we want to select the remaining tensors of the model, which in this case are all the link weights, we can use

In [11]:
model.vars(_[px.TrainVar, -_(px.NodeVar)])

{'(Model).linear1(Linear).nn.weight': pcax.core.structure.TrainVar(DeviceArray([[ 0.01229068,  0.01379374,  0.00090728, ..., -0.01822209,
                0.03359856,  0.01654346],
              [ 0.02093922, -0.01357643,  0.03248229, ...,  0.02490543,
               -0.00358149, -0.02930153],
              [ 0.02113941, -0.01080453,  0.00963924, ..., -0.00074164,
               -0.00332546, -0.02536105],
              ...,
              [-0.00435486, -0.0348849 , -0.01562772, ...,  0.02979221,
               -0.0200628 , -0.01586745],
              [ 0.02002883, -0.02582436, -0.00198849, ...,  0.02906173,
               -0.02739396, -0.01241719],
              [-0.03521763,  0.0162226 ,  0.01106472, ..., -0.03282779,
                0.01568388,  0.00918417]], dtype=float32), reduce=reduce_none),
 '(Model).linear1(Linear).nn.bias': pcax.core.structure.TrainVar(DeviceArray([-1.05218217e-02, -3.13689187e-03, -1.56988334e-02,
              -9.60610621e-03,  1.40408874e-02, -2.94154044e-02,

In particular, all trainable tensors (so weights and value nodes) are `px.TrainVar`s, while value nodes are also instances of the subclass `px.NodeVar`.
The filter object `_` has the following syntax:
- `_(...)`: *or*
- `_[...]`: *and*
- `-_`: *not*
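
If the three operators feel abstract, the toy class below mimics their semantics in pure Python. It is emphatically not how pcax implements the filter: `ToyFilter`, `Var`, `TrainVar`, `NodeVar` and `select` are hypothetical stand-ins written only to show how *or*, *and*, *not* and the keyword filter compose.

```python
class Var:
    def __init__(self, frozen=False):
        self.frozen = frozen


class TrainVar(Var): ...
class NodeVar(TrainVar): ...


class ToyFilter:
    """Toy predicate combinator (illustration only, not pcax code)."""

    def __init__(self, pred=lambda v: True):
        self.pred = pred

    def __call__(self, *types, **props):
        # f(A, B): keep variables that are instances of A OR B.
        # f(...)(key=value): additionally require matching attributes.
        def pred(v):
            ok_type = isinstance(v, types) if types else True
            ok_props = all(getattr(v, k, None) == val for k, val in props.items())
            return self.pred(v) and ok_type and ok_props
        return ToyFilter(pred)

    def __getitem__(self, items):
        # f[a, b]: keep variables accepted by a AND b (types or filters).
        items = items if isinstance(items, tuple) else (items,)
        filters = [i if isinstance(i, ToyFilter) else self(i) for i in items]
        return ToyFilter(lambda v: self.pred(v) and all(g.pred(v) for g in filters))

    def __neg__(self):
        # -f: keep variables rejected by f.
        return ToyFilter(lambda v: not self.pred(v))

    def select(self, variables):
        return {name: v for name, v in variables.items() if self.pred(v)}


f = ToyFilter()
variables = {"w": TrainVar(), "x1": NodeVar(), "x2": NodeVar(frozen=True)}

nodes = f(NodeVar).select(variables)                     # all node values
weights = f[TrainVar, -f(NodeVar)].select(variables)     # trainable but not nodes
free_nodes = f(NodeVar)(frozen=False).select(variables)  # nodes that are not frozen

print(sorted(nodes), sorted(weights), sorted(free_nodes))
```

The three selections mirror the filter expressions used in this notebook: every node, every weight, and every node that is not frozen.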

Let's go to the next function we need: the loss function.  
Similarly to how in PyTorch you compute the gradients from the loss by calling `loss.backward()`, here we compute the gradients by transforming the loss function so that it also outputs the gradients.

In predictive coding, the standard loss function simply computes and returns the model's energy:

In [12]:
@px.vectorize(_(px.NodeVar), in_axis=(0, 0), out_axis=("sum",))
@px.bind(model)
def loss(x, t):
    y = model(x)
    return model.energy

Let's observe a few things:
- we pass `t` and compute `y`, as is standard practice; here, however, they are not strictly necessary: we don't need the model's output to compute the error (it is already used when computing the energy) and we assume that the target has already been set to the `x` of the last layer (so, again, it is already included in the energy computation).
- The first two parameters of `px.vectorize` are the same as for `predict`; here, however, we add a modifier for the output (notice that, even if we have a single output, `out_axis` is specified using a tuple). We use `"sum"` to specify that the output of the function should be summed over the batch dimension. We do this because, as in PyTorch, the loss must be a single floating point value (we don't use the mean since we want the total energy coming from each error node, not their average).
- `model.energy` automatically computes the energy value for each layer in `model`.
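
To see what summing over the batch means in practice, here is a plain JAX sketch that computes a per-sample energy with `jax.vmap` and then reduces it. The energy function `energy_single` and the toy arrays are assumptions for the example; it does not use pcax.

```python
import jax
import jax.numpy as jnp


def energy_single(x, u):
    # Energy of ONE sample: squared error between node value and prediction (assumed form).
    return 0.5 * jnp.sum((x - u) ** 2)


# Vectorize over the batch: one energy value per sample.
per_sample = jax.vmap(energy_single, in_axes=(0, 0))

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
u = jnp.zeros((3, 2))

energies = per_sample(x, u)   # shape (3,)
total = energies.sum()        # a single scalar, as a loss must be
average = energies.mean()     # the alternative we are NOT using

print(energies, total, average)
```

With the sum, the gradient of the total with respect to one sample's node values does not depend on how many samples are in the batch; with the mean it would be scaled by `1/n`.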

Now, we want to define two different "backward" functions, one for the *x step* (in which we update the value nodes) and one for the *w step* (in which we update the weights). They are identical except for the fact that they compute gradients with respect to different elements.  
`px.gradvalues` transforms a function so that it outputs the gradients with respect to the specified variables.

In [13]:
"""
We compute the gradients with respect to the node values that are not frozen.
"""
train_x = px.gradvalues(
    _(px.NodeVar)(frozen=False), # calling a filter with **kwargs selects all the variables with the specified properties
)(loss)

"""
We compute the gradients with respect to the weights (every px.TrainVar that is not a px.NodeVar).
"""
train_w = px.gradvalues(
    _[px.TrainVar, -_(px.NodeVar)],
)(loss)

This covers almost everything you need to know to define a function that operates on one (or more) `px.Module`s. Later we will see a couple of extra details.
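
For comparison, this is how "the same loss, differentiated with respect to different things" looks in plain JAX, where the choice is made with `argnums` instead of a filter. The toy `energy` function and its arguments are invented for the example. Note also that `jax.value_and_grad` returns `(value, gradient)`, whereas the transformed functions above are used as `g, (v,) = ...`, with the gradients first.

```python
import jax
import jax.numpy as jnp


def energy(x, w, inp, target):
    # x: free node value, w: weights, inp: input sample, target: clamped label.
    e_hidden = 0.5 * jnp.sum((x - jnp.tanh(w @ inp)) ** 2)
    e_out = 0.5 * jnp.sum((target - x) ** 2)
    return e_hidden + e_out


# Same function, two different "backward" versions.
grad_x = jax.value_and_grad(energy, argnums=0)  # gradients w.r.t. the node value
grad_w = jax.value_and_grad(energy, argnums=1)  # gradients w.r.t. the weights

inp = jnp.array([1.0, -1.0])
w = jnp.array([[0.5, 0.0], [0.0, 0.5]])
x = jnp.zeros(2)
target = jnp.array([1.0, 0.0])

e, gx = grad_x(x, w, inp, target)
e_again, gw = grad_w(x, w, inp, target)

print(gx.shape, gw.shape)  # (2,) (2, 2)
```

Both calls return the same energy value; only the object the gradient refers to changes.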

### Part 2.2: Optimizers
In pcax (as it derives from *objax*), optimizers are `px.Module`s as well; as a consequence, they need to be treated similarly to how we used `model`. In particular, they must be bound to any function that uses them, so we need to define them beforehand in the global scope.  
pcax offers a single `px.Optim` class that allows `px.Module`s to interact with most `optax` optimizers (the most common jax library for optimizers). As you can see from the following example, you simply have to specify which `optax` optimizer to use for which subset of the `px.Module` variables. In this case, similarly to how we defined the two loss functions, we defined an *x optimizer* and a *w optimizer*.

In [14]:
!source switch-cuda 11.7

# dummy run to init the optimizer parameters
with px.eval(model):
    predict(np.zeros((params["batch_size"], 28 * 28)), (None,) * params["batch_size"])
    optim_x = px.Optim(
        optax.sgd(params["x_learning_rate"]), model.vars(_(px.NodeVar)(frozen=False))
    )
    optim_w = px.Optim(
        optax.adam(params["w_learning_rate"]),
        model.vars(_[px.TrainVar, -_(px.NodeVar)]),
    )

Switched to CUDA 11.7.


You may ask yourself what the first two lines are for.

Your optimizers may define some parameters that are linked to the individual trainable parameters inside your model. For instance, you could choose to use `optax.adam` for `optim_x` (not recommended: for all we know, you should stick with stateless optimizers for the value nodes). In this case, the optimizer needs to instantiate the `adam` parameters and, thus, it needs to know the shape of all the `px.NodeVar`s inside your model. However, if we do not run the model at least once, all the node values will be empty, since they are lazily created. Furthermore, the shape of the value nodes depends on the batch size (since, remember, we want different value nodes for each different input), so we need to perform a dummy run on a dummy input with the same shape as the samples we are going to train on to correctly initialize all the parameters inside the model. Only then can we safely create the two optimizers. (Note that this requires `batch_size` to be constant throughout the program. To guarantee this, by default, the dataloaders have `drop_last=True`. Unless you know what you're doing, do not attempt to modify that.)

It actually doesn't really matter what you pass as `x` and `t`, but only that their shape matches the one of the true input batches (but in this way you can see how you can pass `None` as the target of the predict function: it needs to be batched as well).

This explains the `predict(...)` line. The `with ...` line comes with it: `px.Module` uses an internal caching system to store intermediate values for later computations, and `px.eval` takes care of clearing this cache. Therefore, every time you call `predict` (or, more precisely, every time you use a `px.Module` to compute some values given an input, as happens in the `__call__` function), you must enclose it in a `px.eval` context manager (which also sets the module in eval mode), so that when you change the input you do not reuse the previously cached values. There's a similar context manager for training: `px.train`. They do slightly different things, but don't worry about that for now.
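
The need for the dummy run can be illustrated without pcax or optax. The sketch below hand-writes the per-variable state a stateful optimizer such as Adam has to allocate (two moment buffers with the same shape as each variable) and shows that nothing can be allocated while a node value is still empty. `init_moments` and the dictionary keys are hypothetical names used only here.

```python
import jax
import jax.numpy as jnp


def init_moments(variables):
    """Allocate Adam-style first/second moment buffers, one pair per variable."""
    zeros = lambda tree: jax.tree_util.tree_map(jnp.zeros_like, tree)
    return {"m": zeros(variables), "v": zeros(variables)}


batch_size, hidden_dim = 4, 3

# Before any forward pass: the lazily created node value does not exist yet.
# JAX treats None as an empty subtree, so no buffer is created for it.
nodes_before = {"pc1.x": None}
state_before = init_moments(nodes_before)

# After a dummy forward pass on a batch: the node value has a batch dimension.
nodes_after = {"pc1.x": jnp.ones((batch_size, hidden_dim))}
state_after = init_moments(nodes_after)

print(state_before["m"]["pc1.x"])       # None
print(state_after["m"]["pc1.x"].shape)  # (4, 3)
```

The buffers inherit the batch dimension of the node values, which is another way to see why the batch size has to stay constant once the optimizers are created.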

### Part 2.3: Training the model
We are almost done. We have all the ingredients to train the model, we just need to assemble them in the training and evaluation functions. There's only one core concept to introduce: *jitting*. *Just In Time* compilation allows the program to get compiled and optimized (it's the reason we have so many constraints on how code should be written). We jit the computations executed on a single batch as we want to compile as much code as possible, and `train_on_batch` represents the single largest repeated operation occurring during training.
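
Here is a tiny plain JAX experiment that shows what jitting does, independently of pcax. The function `step` and the list `trace_log` are made up for the example: the Python body of a jitted function only runs while JAX traces it, so appending to a list reveals how many times compilation was triggered.

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def step(x):
    # Python side effect: executed only while tracing, not on cached calls.
    trace_log.append(x.shape)
    return jnp.tanh(x).sum()


step(jnp.ones((256, 784)))   # first call: trace + compile
step(jnp.zeros((256, 784)))  # same shape and dtype: reuses the compiled code
n_after_same = len(trace_log)

step(jnp.ones((100, 784)))   # different shape: traced and compiled again
n_after_new = len(trace_log)

print(n_after_same, n_after_new)  # 1 2
```

With plain `jax.jit`, a batch of a different size would trigger a new compilation, which is one more reason to keep the batch size fixed.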

In [15]:
!source switch-cuda 11.7

@px.jit()
@px.bind(model, optim_w=optim_w, optim_x=optim_x)
def train_on_batch(x, y):
    with px.eval(model):
        y_hat, = predict(x, y)

        for i in range(params["T"]):
            with px.train(model):
                g, (v,) = train_x(x, y)
                optim_x(g)

        with px.train(model):
            g, (v,) = train_w(x, y)
            optim_w(g)

Switched to CUDA 11.7.


A few remarks:
- in `px.bind` we specify the keys for the optim arguments, since `optim_w` and `optim_x` are of the same type and they need a name to distinguish them.
- this is the standard pc training procedure:
    - we initialize the node values using `predict` (which by default uses forward initialization)
    - we repeat the *x-update* step for `T` times and then we perform a single *w-update*.
- `optim_*(g)` updates the node values/weights of the model, therefore all the cached values must be cleared. That's why we enclose each training operation in a `px.train`. The difference with `px.eval` (which encloses the whole function block) is that the latter also clears the node values themselves, so that a new call of `train_on_batch` will populate them with the new forward values computed by calling `predict` on the new batch.
- you can see that the values returned by `predict` and `train_*` (the first value returned by `train_*` contains the computed gradients, followed by the actual return values of the function) are all tuples, even if a single value is actually returned by the original functions (before being transformed). Just something to keep in mind.
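
The procedure described in these remarks can be written out by hand for a tiny network in plain JAX. This is a sketch under stated assumptions, not the pcax implementation: a single sample, one hidden node layer `h`, a squared-error energy, and plain gradient descent for both steps (the notebook uses `optax.adam` for the weights). All names and numbers are invented for the example.

```python
import jax
import jax.numpy as jnp


def energy(h, w1, w2, inp, target):
    # Hidden node h is predicted by tanh(w1 @ inp);
    # the output node is clamped to the target and predicted by w2 @ h.
    e_hidden = 0.5 * jnp.sum((h - jnp.tanh(w1 @ inp)) ** 2)
    e_out = 0.5 * jnp.sum((target - w2 @ h) ** 2)
    return e_hidden + e_out


inp = jnp.array([1.0, -0.5])
target = jnp.array([1.0, 0.0])
w1 = jnp.array([[0.5, -0.3], [0.1, 0.8], [-0.4, 0.2]])
w2 = jnp.array([[0.3, -0.2, 0.5], [0.1, 0.4, -0.6]])
T, lr_x, lr_w = 4, 0.1, 0.01

# 1) Forward initialization: the node value starts at its prediction.
h = jnp.tanh(w1 @ inp)
e_init = energy(h, w1, w2, inp, target)

# 2) T x-steps: update only the node value, weights stay fixed.
for step in range(T):
    h = h - lr_x * jax.grad(energy, argnums=0)(h, w1, w2, inp, target)
e_after_x = energy(h, w1, w2, inp, target)

# 3) One w-step: update only the weights, node value stays fixed.
g1, g2 = jax.grad(energy, argnums=(1, 2))(h, w1, w2, inp, target)
w1, w2 = w1 - lr_w * g1, w2 - lr_w * g2
e_after_w = energy(h, w1, w2, inp, target)

print(e_init, e_after_x, e_after_w)
```

The energy goes down first because the node value moves towards the target, and then again because the weights move to better predict the updated node value.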

Let's see the remaining functions. Nothing new here:

In [16]:
!source switch-cuda 11.7

"""
Again we jit the operation executed on each batch.
"""
@px.jit()
@px.bind(model)
def evaluate(x, y):
    with px.eval(model):
        y_hat, = predict(x, y) # remember that 'predict' always returns a tuple

    return (y_hat.argmax(-1) == y.argmax(-1)).mean()


def epoch(dl):
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        train_on_batch(x, y)

    return 0


def test(dl):
    accuracies = []
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        accuracies.append(evaluate(x, y))

    return np.mean(accuracies)

Switched to CUDA 11.7.
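
The accuracy expression used in `evaluate` can be checked on a hand-made batch with NumPy alone. The three predictions and labels below are invented for the example.

```python
import numpy as np

# Network outputs for 3 samples and 3 classes (made-up numbers).
y_hat = np.array([[0.1, 2.0, -1.0],
                  [1.5, 0.2, 0.3],
                  [0.0, 0.1, 0.9]])

# One-hot targets: classes 1, 2 and 2.
y = np.array([[0, 1, 0],
              [0, 0, 1],
              [0, 0, 1]], dtype=np.float32)

# Predicted class = index of the largest output; true class = index of the 1.
accuracy = (y_hat.argmax(-1) == y.argmax(-1)).mean()

print(accuracy)
```

Two of the three samples are classified correctly, so the accuracy is 2/3. If, as stated earlier, the dataloaders drop the last incomplete batch, every batch has the same size, so averaging the per-batch accuracies in `test` gives the accuracy over the evaluated samples.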


And the main body:

In [17]:
!source switch-cuda 11.7

if __name__ == "__main__":
    t = timeit.timeit(lambda: epoch(train_dataloader), number=1)
    print("Compiling + Epoch 1 took", t, "seconds")

    # Average time of an epoch once everything is compiled (no jit compilation overhead)
    t = timeit.timeit(lambda: epoch(train_dataloader), number=params["num_epochs"]) / params["num_epochs"]
    print("An Epoch takes on average", t, "seconds")

    print("Final Accuracy:", test(test_dataloader))

    del train_dataloader
    del test_dataloader

Switched to CUDA 11.7.
Compiling + Epoch 1 took 7.22816416202113 seconds
An Epoch takes on average 0.9288925940054469 seconds
Final Accuracy: 0.96324116


If everything is all right, you should get a final accuracy of ~96% and a training time of ~0.7 seconds per epoch. The training time is actually heavily bottlenecked by the data transfer between CPU and GPU, which is why we are using 16 workers in the dataloader. So, depending on your configuration, the final speed may change.
