# Tutorial 0b: **backpropagation with pcax**
In this notebook, you will learn how to use pcax to train a network equivalent to the one in *Tutorial 0*, but using backpropagation instead of predictive coding.
The idea is simply to remove everything related to predictive coding; nothing more.

Good luck!

In [1]:
!source switch-cuda 11.7

import os
# Stop JAX from preallocating most of the GPU memory; keep this line if you share the GPU with others.
# It must take effect before JAX initializes its GPU backend, so we set it before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Core dependencies
import jax
import optax

# pcax
import pcax as px # same as import pcax.pc as px
import pcax.nn as nn
from pcax.core import _ # _ is the filter object, more about it later!
import pcax.interface as pxi # note: this module will soon be deprecated

import numpy as np
from torchvision.datasets import MNIST
import timeit

Switched to CUDA 11.7.


First step: remove all pc layers from the model.

In [2]:
class Model(px.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, nm_layers=2) -> None:
        super().__init__()

        self.act_fn = jax.nn.tanh

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear_h = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers)]
        )
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        x = self.act_fn(self.linear1(x))

        for i in range(len(self.linear_h)):
            x = self.act_fn(self.linear_h[i](x))

        x = self.linear2(x)

        return x
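With the pc layers gone, `Model` is a plain multilayer perceptron: a tanh after the input layer and after each of the `nm_layers` hidden layers, followed by a linear output layer. To make that explicit, here is a minimal pure-JAX sketch of the same forward pass without pcax. The names `init_mlp` and `mlp` are hypothetical, and the uniform initialization is an assumption that need not match what `pcax.nn.Linear` does.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_dim, hidden_dim, output_dim, nm_layers=2):
    """Hypothetical init: a list of (W, b) pairs with the same layer sizes as Model."""
    # input -> hidden, nm_layers x (hidden -> hidden), hidden -> output
    dims = [input_dim] + [hidden_dim] * (nm_layers + 1) + [output_dim]
    keys = jax.random.split(key, len(dims) - 1)
    layers = []
    for k, d_in, d_out in zip(keys, dims[:-1], dims[1:]):
        # Assumed scheme: uniform in [-1/sqrt(d_in), 1/sqrt(d_in)]
        bound = d_in ** -0.5
        w_key, b_key = jax.random.split(k)
        W = jax.random.uniform(w_key, (d_in, d_out), minval=-bound, maxval=bound)
        b = jax.random.uniform(b_key, (d_out,), minval=-bound, maxval=bound)
        layers.append((W, b))
    return layers


def mlp(layers, x):
    """Forward pass: tanh after every layer except the last, as in Model.__call__."""
    for W, b in layers[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = layers[-1]
    return x @ W + b


demo_params = init_mlp(jax.random.PRNGKey(0), 28 * 28, 128, 10)
demo_out = mlp(demo_params, jnp.zeros((28 * 28,)))
print(demo_out.shape)  # (10,)
```

As in the pcax version, the function handles a single unbatched sample; batching is added separately.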

Second step: drop the pc-related hyperparameters (`x_learning_rate` and `T`, commented out below).

In [3]:
params = {
    "batch_size": 256,
    # "x_learning_rate": 0.01,
    "w_learning_rate": 1e-3,
    "num_epochs": 4,
    "hidden_dim": 128,
    "input_dim": 28 * 28,
    "output_dim": 10,
    "seed": 0,
    # "T": 4,
}

The data loading is unchanged:

In [4]:
def one_hot(t, k):
    return np.array(t[:, None] == np.arange(k), dtype=np.float32)


class FlattenAndCast:
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32) / 255.0)


train_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=True,
)
train_dataloader = pxi.data.Dataloader(
    train_dataset,
    batch_size=params["batch_size"],
    num_workers=16,
    shuffle=True,
    persistent_workers=True,
    pin_memory=True,
)


test_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=False,
)
test_dataloader = pxi.data.Dataloader(
    test_dataset,
    batch_size=params["batch_size"],
    num_workers=4,
    shuffle=False,
    persistent_workers=False,
    pin_memory=True,
)
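To see what the two transforms produce, here is a small NumPy check that inlines the same expressions used by `one_hot` and `FlattenAndCast` on toy inputs; the labels and the all-white image are made up for illustration.

```python
import numpy as np

labels = np.array([3, 0, 9])
# Same broadcasting trick as one_hot: compare each label against 0..9
encoded = np.array(labels[:, None] == np.arange(10), dtype=np.float32)
print(encoded.argmax(-1))  # [3 0 9]

# Same as FlattenAndCast: a 28x28 uint8 image becomes a 784-vector scaled to [0, 1]
pic = np.full((28, 28), 255, dtype=np.uint8)
flat = np.ravel(np.array(pic, dtype=np.float32) / 255.0)
print(flat.shape)  # (784,)
```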

The model is still a `px.Module`, so it needs to be instantiated in the global scope.

In [5]:
model = Model(params["input_dim"], params["hidden_dim"], params["output_dim"])

The predict function no longer needs the target: in *Tutorial 0* it was used to fix the last pc layer, which no longer exists.

In [6]:
@px.vectorize(_(px.NodeVar), in_axis=(0,)) # note the change to `in_axis`
@px.bind(model)
def predict(x):
    return model(x)
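`predict` is written for a single sample, and `px.vectorize(..., in_axis=(0,))` applies it across the batch axis of its one argument. As an analogy only (we assume, rather than assert, that this mirrors `jax.vmap`), here is the same idea in plain JAX; `predict_single` and its fixed weights are hypothetical stand-ins for `model(x)`. Note that, unlike pcax's `predict`, the vmapped function returns a bare array rather than a tuple.

```python
import jax
import jax.numpy as jnp


def predict_single(x):
    # Hypothetical stand-in for model(x) on ONE unbatched sample
    W = jnp.ones((4, 2))
    return x @ W


# in_axes=(0,) maps over the leading (batch) axis of the single argument
predict_batched = jax.vmap(predict_single, in_axes=(0,))

xb = jnp.arange(12.0).reshape(3, 4)  # batch of 3 samples, 4 features each
yb = predict_batched(xb)
print(yb.shape)  # (3, 2)
```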

The loss changes slightly: we now use the target to compute an actual loss, as is standard in backpropagation.

In [7]:
@px.vectorize(_(px.NodeVar), in_axis=(0, 0), out_axis=("sum", 0))
@px.bind(model)
def loss(x, t):
    y = model(x)
    # We also return y, even though we won't use it, to show how this is done. Note the change in out_axis.
    return ((y - t) ** 2).sum(), y # squared error summed rather than averaged (MSE without the mean), since pc uses sums
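We read `out_axis=("sum", 0)` as: sum the first output (the per-sample loss) over the batch, and stack the second output (the predictions) along axis 0. That reading is our assumption based on the code comments, not pcax's documented behavior. Here is a sketch of the same reduction in plain JAX, using a hypothetical linear model so the numbers are easy to check by hand.

```python
import jax
import jax.numpy as jnp


def loss_single(W, x, t):
    # Hypothetical linear model; returns (summed squared error, prediction) for ONE sample
    y = x @ W
    return ((y - t) ** 2).sum(), y


def batched_loss(W, xb, tb):
    # Map over the batch axis of x and t (W is shared), then reduce:
    # losses are summed over the batch, predictions stay stacked along axis 0
    losses, ys = jax.vmap(loss_single, in_axes=(None, 0, 0))(W, xb, tb)
    return losses.sum(), ys


W = jnp.zeros((3, 2))
xb = jnp.ones((4, 3))
tb = jnp.ones((4, 2))
total, ys = batched_loss(W, xb, tb)
# Each sample contributes (0 - 1)**2 * 2 = 2, so total is 8.0 for 4 samples
print(total, ys.shape)
```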

We need only one gradient function and a single optimizer, since we only optimize the weights.

In [8]:
train = px.gradvalues(
    _(px.TrainVar)
)(loss)

"""
There's actually no need to call predict here as nothing needs to be initialized,
however it is good practice to do so as it may be necessary in the general case.
"""
with px.eval(model):
    predict(np.zeros((params["batch_size"], 28 * 28)))
    optim = px.Optim(
        optax.adam(params["w_learning_rate"]),
        model.vars(_(px.TrainVar)),
    )

The training and evaluation functions are straightforward:

In [9]:
@px.jit()
@px.bind(model, optim)
def train_on_batch(x, y):
    """
    Again, we don't need to call predict, since there are no value nodes to initialize.
    If we need the predicted y_hat we can use the value returned by the 'train' function.
    """
    with px.train(model):
        g, (v, y_hat) = train(x, y)
        optim(g)


@px.jit()
@px.bind(model)
def evaluate(x, y):
    with px.eval(model):
        y_hat, = predict(x) # remember that 'predict' always returns a tuple

    return (y_hat.argmax(-1) == y.argmax(-1)).mean()


def epoch(dl):
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        train_on_batch(x, y)

    return 0


def test(dl):
    accuracies = []
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        accuracies.append(evaluate(x, y))

    return np.mean(accuracies)
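`test` averages per-batch accuracies with equal weight, but the last test batch is smaller: with `batch_size=256`, MNIST's 10,000 test images split into 39 full batches plus one batch of 16. The plain mean therefore gives those 16 images the same weight as a full batch. Here is a sketch of a size-weighted alternative; `batch_accuracy` and `weighted_mean` are hypothetical helpers, and the accuracies below are made up to make the difference visible.

```python
import numpy as np


def batch_accuracy(y_hat, y):
    # Same rule as `evaluate`: compare predicted and true class indices
    return (y_hat.argmax(-1) == y.argmax(-1)).mean()


def weighted_mean(accuracies, sizes):
    # Weight each batch's accuracy by the number of samples in that batch
    accuracies = np.asarray(accuracies, dtype=np.float64)
    sizes = np.asarray(sizes, dtype=np.float64)
    return float((accuracies * sizes).sum() / sizes.sum())


y_hat_demo = np.array([[0.1, 0.9], [0.8, 0.2]])  # predicts classes 1 and 0
y_demo = np.array([[0.0, 1.0], [0.0, 1.0]])      # true classes 1 and 1
print(batch_accuracy(y_hat_demo, y_demo))        # 0.5

sizes = [256] * 39 + [16]           # 10,000 test images in total
accs = [1.0] * 39 + [0.0]           # toy case: only the small last batch is wrong
print(np.mean(accs))                # unweighted: 0.975
print(weighted_mean(accs, sizes))   # weighted: 0.9984
```

To use it in `test`, one would also collect `len(x)` for each batch alongside the accuracy.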

In [10]:
!source switch-cuda 11.7

if __name__ == "__main__":
    t = timeit.timeit(lambda: epoch(train_dataloader), number=1)
    print("Compiling + Epoch 1 took", t, "seconds")

    # Average time per epoch, now that the functions are already compiled
    t = timeit.timeit(lambda: epoch(train_dataloader), number=params["num_epochs"]) / params["num_epochs"]
    print("An Epoch takes on average", t, "seconds")

    print("Final Accuracy:", test(test_dataloader))

    del train_dataloader
    del test_dataloader

Switched to CUDA 11.7.
Compiling + Epoch 1 took 4.027969808958005 seconds
An Epoch takes on average 0.691992550244322 seconds
Final Accuracy: 0.9625401


If everything is all right, you should get a final accuracy of about 96% and a training time of about 0.7 seconds per epoch. Training time is heavily bottlenecked by data transfer between the CPU and the GPU, which is why the training dataloader uses 16 workers; depending on your configuration, your speed may differ.
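The timing cell separates the first epoch, which includes compilation, from the following ones. Because JAX dispatches work asynchronously, wall-clock timings of jitted code are most reliable when the timed function waits for its result. Here is a minimal sketch of the same compile-then-measure pattern on a hypothetical toy function, using `block_until_ready()`. We do not claim here whether `px.jit` functions need an equivalent synchronization point.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for a training step
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((512, 512))

# First call: tracing + XLA compilation + execution; block_until_ready waits for the device
compile_time = timeit.timeit(lambda: step(x).block_until_ready(), number=1)

# Later calls reuse the compiled executable, so this measures execution only
run_time = timeit.timeit(lambda: step(x).block_until_ready(), number=10) / 10
print(f"compile + run: {compile_time:.4f}s, run: {run_time:.6f}s")
```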
