# Tutorial 0b: **backpropagation with pcax**
In this notebook, you will learn how to use pcax to a network equivalent to the one in *Tutorial 0*, but trained using backpropagation.
The general idea is to simply remove all the bits related to predictive coding, nothing more than that.

Good luck!

In [1]:
# Core dependencies
import jax
import optax

# pcax

# Importing pcax is equivalent to importing pcax.pc
# which contains the functionalities to build a predictive coding network
import pcax as px

# A filter is core object of pcax. It is used to filter which parameters in a network
# should undergo a specific JAX transformation (e.g. jax.grad, jax.jit, etc.).
# Consequently, despite being a core object, pcax offers a shortuct to use it: pcax.f.

# pcax.nn contains the neural network modules (e.g. Conv2d, Linear, etc.)
# at the moment only Linear is implemented, but more coming soon!
import pcax.nn as nn

# pcax.utils contains some useful utilities to train and use the network
import pcax.utils as pxu

# finally we can import the library's core which is simply a wrapper around JAX,
# unrelated to predictive coding. useful for some advanced configurations.
# We will not use it in this tutorial. 
# import pcax.core as pxc

In [2]:
from typing import Callable, Optional
import timeit
import os
import numpy as np
from torchvision.datasets import MNIST

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

First step: removing all pc layers from the model:

In [3]:
# Note we use px.Module instead of px.EnergyModule as we don't need to compute the energy
class Model(px.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 output_dim: int,
                 nm_layers: int,
                 act_fn: Callable[[jax.Array], jax.Array]) -> None:
        super().__init__()

        self.act_fn = act_fn
        
        self.layers = [nn.Linear(input_dim, hidden_dim)] + [
            nn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers - 1)
        ] + [nn.Linear(hidden_dim, output_dim)]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = self.act_fn(layer(x))

        x = self.layers[-1](x)

        return x

Second step: we don't need pc related params:

In [4]:
params = {
    "batch_size": 256,
    # "x_learning_rate": 0.01,
    "w_learning_rate": 1e-3,
    "num_epochs": 4,
    "hidden_dim": 128,
    "input_dim": 28 * 28,
    "output_dim": 10,
    "seed": 0,
    # "T": 4,
}

No changes in the data loading:

In [5]:
def one_hot(t, k):
    return np.array(t[:, None] == np.arange(k), dtype=np.float32)


class FlattenAndCast:
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32) / 255.0)


train_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=True,
)
train_dataloader = pxu.data.TorchDataloader(
    train_dataset,
    batch_size=params["batch_size"],
    num_workers=8,
    shuffle=True,
    persistent_workers=True,
    pin_memory=True,
)


test_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=False,
)
test_dataloader = pxu.data.TorchDataloader(
    test_dataset,
    batch_size=params["batch_size"],
    num_workers=8,
    shuffle=False,
    persistent_workers=True,
    pin_memory=True,
)

In the predict function we don't need the target (it could be used to fix the last pc node which does not exit anymore).

In [6]:
# Note the change to `in_axis` and the absence of a filter
@pxu.vectorize(in_axis=(0,))
def predict(x: jax.Array, *, model = None):
    return model(x)

The loss is slightly different, as now we use the target to compute the actual loss (as we normally do in backpropagation). We also only need one gradient function and a single optimizer (since we only optimize the weights) so we can define all together by merging transformations.

In [7]:
@pxu.grad_and_values(px.f(px.LayerParam))
@pxu.vectorize(in_axis=(0, 0), out_axis=("sum", 0))
def loss(x: jax.Array, t: jax.Array, *, model = None):
    y = model(x)
    # we also return y even if we will not use it, to show how it can be done. Note the change in out_axis.
    return (jax.numpy.square(y - t)).sum(), y # it's MSE loss without the Mean, since in pc we use the sum

In [8]:
train = loss # change of name to be consistent with the other tutorials
model = Model(28 * 28, params["hidden_dim"], 10, 2, jax.nn.tanh)

"""
There's actually no need to initialize the model before defining the optimizer,
however it is good practice to do so as it may be necessary in the general case.
"""
with pxu.train(model): # we don't pass any arguments here, since we don't need to initialize anything
    optim = pxu.Optim(
        optax.adam(params["w_learning_rate"]),
        model.parameters().filter(px.f(px.LayerParam)),
    )

The training and evaluation functions are very intuitive:

In [9]:
@pxu.jit()
def train_on_batch(x, y, *, model, optim):
    """
    Again, we don't need to initialize any node so we don't need to pass any argument to 'pxu.train'.
    If we need the predicted y_hat we can use the value returned by the 'train' function.
    """
    with pxu.train(model):
        g, (v, y_hat) = train(x, y, model=model)
        optim(g)


@pxu.jit()
def evaluate(x, y, *, model):
    # There are no nodes to initialize, but we can use px.eval to compute y_hat.
    # An alternative way would be to use step with the batched predict function.
    with pxu.eval(model, x) as (y_hat,):
        return (y_hat.argmax(-1) == y.argmax(-1)).mean()

    # Alternative version:
    
    # with px.step(model):
    #     y_hat, = predict(x, model=model)
    #     return (y_hat.argmax(-1) == y.argmax(-1)).mean()


def epoch(dl, train_fn):
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        train_fn(x, y)

    return 0


def test(dl, test_fn):
    accuracies = []
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        accuracies.append(test_fn(x, y))

    return np.mean(accuracies)

In [10]:
if __name__ == "__main__":
    train_fn = train_on_batch.snapshot(model=model, optim=optim)
    test_fn = evaluate.snapshot(model=model)

    t = timeit.timeit(lambda: epoch(train_dataloader, train_fn), number=1)
    print("Compiling + Epoch 1 took", t, "seconds")

    # Time of an epoch (without jitting)
    t = timeit.timeit(lambda: epoch(train_dataloader, train_fn), number=params["num_epochs"]) / params["num_epochs"]
    print("An Epoch takes on average", t, "seconds")

    print("Final Accuracy:", test(test_dataloader, test_fn))

    del train_dataloader
    del test_dataloader

Compiling + Epoch 1 took 1.718640837003477 seconds
An Epoch takes on average 0.7030591029906645 seconds
Final Accuracy: 0.95863384


If everything is all right, you should get a final accuracy of ~96% and a training time of ~0.7 second per epoch. The training time is actually heavily bottlenecked by the data transfer between CPU and GPU, that's why we are using 8 workers in the dataloader. So depending on your configuration the final speed my change.
