# Tutorial 1: **Node initialization**
In this notebook, you will learn how to use pcax to perform custom node initialization for your complex models.
Since the library is still in early development, expect major syntax changes. You can keep coming back to this notebook to stay updated.

Good luck!

## Part 0: Importing dependencies

In [1]:
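# Environment variables: set them before importing JAX so that they are
# already in place when the GPU backend is initialized
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # don't preallocate most of the GPU memory up front
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  # selects GPU 3; change it to match your machine

# Core dependencies
import numpy as np
import jax
import jax.numpy as jnp
import optax

# pcax
import pcax as px
import pcax.core as pxc
import pcax.utils as pxu
import pcax.nn as nn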

In [2]:
from typing import Callable, Optional

## Part 1: Defining a Model
Building on the previous tutorials, we define a model that accepts a custom initialization function, which is used to initialize its hidden `px.Node`s.

In tutorial #0, we saw that the syntax `node(x)["x"]` is, by default, a shortcut for the following:
```python
node["u"] = x
if node.is_init:
    node["x"] = node["u"]
return node["x"]
```
while it is actually:
```python
node["u"] = x
if node.is_init:
    # rkg is a random key generator that is always passed to __call__
    init_fn(node, rkg)
return node["x"]
```
so we can customize the initialization behaviour by passing our own `init_fn`.
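
To make this dispatch concrete, here is a minimal pure-Python sketch. `ToyNode` and `ToyRKG` are hypothetical stand-ins written only for this illustration; they are not the pcax classes, which do much more (energies, parameters, caching). The sketch only mirrors the control flow shown above: store `u`, call `init_fn(node, rkg)` when in init mode, and return `x`.

```python
from typing import Callable, Dict, Optional

import jax
import jax.numpy as jnp


class ToyRKG:
    """Hypothetical stand-in for a random key generator: each call returns a fresh key."""

    def __init__(self, seed: int = 0) -> None:
        self._key = jax.random.PRNGKey(seed)

    def __call__(self) -> jax.Array:
        self._key, subkey = jax.random.split(self._key)
        return subkey


class ToyNode:
    """Hypothetical stand-in for px.Node that only mimics the init_fn dispatch."""

    def __init__(self, rkg: ToyRKG,
                 init_fn: Optional[Callable[["ToyNode", ToyRKG], None]] = None) -> None:
        self.rkg = rkg
        # Fall back to forward initialization when no init_fn is given
        self.init_fn = init_fn if init_fn is not None else toy_forward_init
        self.is_init = True
        self._values: Dict[str, jax.Array] = {}

    def __getitem__(self, key: str) -> jax.Array:
        return self._values[key]

    def __setitem__(self, key: str, value: jax.Array) -> None:
        self._values[key] = value

    def __call__(self, u: jax.Array) -> "ToyNode":
        self["u"] = u
        if self.is_init:
            # Delegate the initialization of "x" to the user-provided function
            self.init_fn(self, self.rkg)
        return self


def toy_forward_init(node: ToyNode, rkg: ToyRKG) -> None:
    node["x"] = node["u"]


def toy_zero_init(node: ToyNode, rkg: ToyRKG) -> None:
    node["x"] = jnp.zeros_like(node["u"])


rkg = ToyRKG()
u = jnp.ones((2, 3))
x_forward = ToyNode(rkg)(u)["x"]        # forward init: x equals u
zero_node = ToyNode(rkg, toy_zero_init)
x_zero = zero_node(u)["x"]              # zero init: x is all zeros

# Outside of the init phase, init_fn is not called and x keeps its value
zero_node.is_init = False
x_kept = zero_node(2 * u)["x"]          # still all zeros, although u changed
```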

Here are some examples:

In [3]:
# This is the default function used to initialize the nodes.
def forward_init(node: px.Node, rkg: pxc.RandomKeyGenerator):
    node["x"] = node["u"]

# This initializes the nodes with zeros. Given the presence of the bias term, this may be
# equivalent to initializing the nodes with any constant value (or even with the node average,
# i.e. forcing the average to be zero). The last remark doesn't apply to nodes observed during
# training, as their average may not be zero.
def zero_init(node: px.Node, rkg: pxc.RandomKeyGenerator):
    node["x"] = jnp.zeros_like(node["u"])

# This initializes the nodes with a random normal distribution, which is the original initialization
# method used in predictive coding (of course a std of 1 may be too high, but this is just an example)
def random_init(node: px.Node, rkg: pxc.RandomKeyGenerator, std: float = 1.):
    node["x"] = jax.random.normal(rkg(), node["u"].shape) * std

# It may also be the case that we don't want to initialize the nodes at all. This is useful when
# we want to reuse the node values in the next training batch. This is an area yet to be explored;
# we will take some first steps in this notebook.
def no_init(node: px.Node, rkg: pxc.RandomKeyGenerator):
    pass
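
Because `init_fn` is only ever called as `init_fn(node, rkg)`, extra hyperparameters such as the `std` of `random_init` have to be bound in advance, for instance with `functools.partial`. The sketch below uses a plain `dict` and a small key-generator closure (`make_rkg`, hypothetical) as stand-ins for `px.Node` and `pxc.RandomKeyGenerator`, so it runs without pcax.

```python
import functools

import jax
import jax.numpy as jnp


def make_rkg(seed: int = 0):
    """Hypothetical stand-in for a random key generator: returns a fresh key on each call."""
    state = {"key": jax.random.PRNGKey(seed)}

    def rkg() -> jax.Array:
        state["key"], subkey = jax.random.split(state["key"])
        return subkey

    return rkg


def random_init(node, rkg, std: float = 1.0):
    node["x"] = jax.random.normal(rkg(), node["u"].shape) * std


# Bind std so that the result has the (node, rkg) signature expected for init_fn
small_random_init = functools.partial(random_init, std=0.1)

node = {"u": jnp.zeros((1000, 128))}  # a dict used as a stand-in for px.Node
small_random_init(node, make_rkg())
print(float(node["x"].std()))  # close to 0.1
```

With pcax, the bound function would then be passed as the model's `init_fn`, e.g. `Model(..., init_fn=functools.partial(random_init, std=0.1))`.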

We also want to use an energy function other than the (M)SE, since cross-entropy (CE) is normally the go-to choice in classification tasks. (Strictly speaking, we do not take the mean: we want the error to be propagated individually to each batched sample in the value nodes.) Thus, we create a different energy function to be used by the last `px.Node`.

In [4]:
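# Energy of the output node: KL(x || softmax(u)) per sample. For one-hot targets x this
# reduces to the standard cross-entropy -log softmax(u)[target].
def ce_energy(node: px.Node, rkg: pxc.RandomKeyGenerator):
    u = jax.nn.softmax(node["u"])
    # The 1e-8 terms guard against division by zero and log(0)
    return (node["x"] * jnp.log(node["x"] / (u + 1e-8) + 1e-8)).sum(axis=-1)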

And now we can define the model as before, but with an extra argument: `init_fn`.

In [5]:
class Model(px.EnergyModule):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 output_dim: int,
                 nm_layers: int,
                 act_fn: Callable[[jax.Array], jax.Array],
                 init_fn: Optional[Callable[[px.Node, pxc.RandomKeyGenerator], None]] = None
                ) -> None:
        super().__init__()

        self.act_fn = act_fn
        self.layers = [nn.Linear(input_dim, hidden_dim)] + [
            nn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers - 1)
        ] + [nn.Linear(hidden_dim, output_dim)]

        self.nodes = [px.Node(init_fn=init_fn) for _ in range(nm_layers)] + [px.Node(energy_fn=ce_energy)]

        self.nodes[-1].x.frozen = True

    def __call__(self,
                 x: jax.Array,
                 t: Optional[jax.Array] = None):
        
        for node, layer in zip(self.nodes[:-1], self.layers[:-1]):
            x = node(self.act_fn(layer(x)))["x"]

        x = self.nodes[-1](self.layers[-1](x))["x"]

        if t is not None:
            self.nodes[-1]["x"] = t
        
        return self.nodes[-1]["u"]

We now define all the functions needed to train a model. Note that, compared to tutorial #0, we have changed the order of the definitions: normally, the model is the last entity defined, while these training functions may live in separate files so that they can be reused.

Some data-loading utilities:

In [6]:
# Used to convert the square 0-255 images to 0-1 float vectors.
class FlattenAndCast:
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32) / 255.0)

Loss function (nothing new):

In [7]:
@pxu.vectorize(px.f(px.NodeParam, with_cache=True), in_axis=(0, 0), out_axis=("sum",))
def loss(x: jax.Array, t: Optional[jax.Array] = None, *, model):
    model(x)
    return model.energy()

train_x = pxu.grad_and_values(
    px.f(px.NodeParam)(frozen=False),
)(loss)

train_w = pxu.grad_and_values(
    px.f(px.LayerParam)
)(loss)

Train function:

In [8]:
@pxu.jit()
def train_batch(x, y, *, T, model, optim_w, optim_x):
    y = jax.nn.one_hot(y, 10)
    with pxu.train(model, x, y) as (y_hat,):
        for i in range(T):
            with pxu.step(model):
                g, (v,) = train_x(x, y, model=model)
                optim_x(g)

        with pxu.step(model):
            g, (v,) = train_w(x, y, model=model)
            optim_w(g)

Evaluate function:

In [9]:
@pxu.jit()
def eval_batch(x, y, *, model, optim_x):
    y = jax.nn.one_hot(y, 10)
    with pxu.eval(model, x, y) as (y_hat,):
        return (y_hat.argmax(-1) == y.argmax(-1)).mean()

Epoch and testing functions:

In [10]:
def epoch(dl, train_fn, p):
    for i, batch in enumerate(dl):
        x, y = batch

        train_fn(x, y)
        
        if i / len(dl) >= p:
            break

def test(dl, eval_fn):
    accuracies = []
    for batch in dl:
        x, y = batch

        accuracies.append(eval_fn(x, y))

    return np.mean(accuracies)

And, finally, the `__main__` block:

In [11]:
from torchvision.datasets import FashionMNIST
from tqdm import tqdm

# Define parameters
params = {
    "batch_size": 256,
    "x_learning_rate": 0.05,
    "w_learning_rate": 1e-4,
    "num_epochs": 8,
    "num_layers": 4,
    "hidden_dim": 128,
    "input_dim": 28 * 28,
    "output_dim": 10,
    "T": 4,
    "p": 1.0,
}

# Create dataloaders
train_dataset = FashionMNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=True,
)
train_dataloader = pxu.data.TorchDataloader(
    train_dataset,
    batch_size=params["batch_size"],
    num_workers=8,
    shuffle=True,
    persistent_workers=True,
    pin_memory=True,
)

test_dataset = FashionMNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=False,
)
test_dataloader = pxu.data.TorchDataloader(
    test_dataset,
    batch_size=params["batch_size"],
    num_workers=8,
    shuffle=False,
    persistent_workers=True,
    pin_memory=True,
)

if __name__ == "__main__":
    # Create model
    model = Model(28 * 28, params["hidden_dim"], 10, params["num_layers"], jax.nn.gelu, forward_init)

    # Create optimizers
    with pxu.train(model, jnp.zeros((params["batch_size"], 28 * 28)), None):
        optim_x = pxu.Optim(
            optax.sgd(params["x_learning_rate"]),
            model.parameters().filter(px.f(px.NodeParam)(frozen=False))
        )
        optim_w = pxu.Optim(
            optax.adam(params["w_learning_rate"]),
            model.parameters().filter(px.f(px.LayerParam)),
        )

    # Create snapshots
    train_fn = train_batch.snapshot(T=params["T"], model=model, optim_x=optim_x, optim_w=optim_w)
    evaluate_fn = eval_batch.snapshot(model=model, optim_x=optim_w)

    # Train:
    for e in tqdm(range(params["num_epochs"])):
        epoch(train_dataloader, train_fn, params["p"])

    # Evaluation:
    accuracy = test(test_dataloader, evaluate_fn)
    print(f"Accuracy: {accuracy * 100:.2f}%")

100%|██████████| 8/8 [00:08<00:00,  1.09s/it]


Accuracy: 83.88%


## Part 2: Different initialization techniques
Let's now try a different initialization: `zero_init`. Note, however, that we still want to perform a forward initialization during evaluation, since it yields the plain feedforward prediction with zero error in every hidden node, i.e. the lowest energy for those nodes. Thus, we need a custom initialization function that behaves differently during training and evaluation:

In [None]:
def zero_init(node: px.Node, rkg: pxc.RandomKeyGenerator):
    if node.is_train:
        node["x"] = jnp.zeros_like(node["u"])
    else:
        node["x"] = node["u"]
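
The same pattern generalizes: any training-time initialization can be combined with forward initialization at evaluation time. A minimal sketch of such a wrapper follows. `train_only` is a hypothetical helper, and `StubNode` is only a stand-in exposing the `is_train` flag and item access used above, not the real `px.Node`.

```python
from typing import Any, Callable

import jax.numpy as jnp


def forward_init(node, rkg) -> None:
    node["x"] = node["u"]


def train_only(init_fn: Callable[[Any, Any], None]) -> Callable[[Any, Any], None]:
    """Use init_fn while training and fall back to forward initialization otherwise."""
    def wrapped(node, rkg) -> None:
        if node.is_train:
            init_fn(node, rkg)
        else:
            forward_init(node, rkg)
    return wrapped


def zeros(node, rkg) -> None:
    node["x"] = jnp.zeros_like(node["u"])


class StubNode(dict):
    """Stand-in for px.Node: a dict with an is_train flag."""
    def __init__(self, u, is_train: bool) -> None:
        super().__init__(u=u)
        self.is_train = is_train


zero_then_forward_init = train_only(zeros)
train_node = StubNode(jnp.ones(3), is_train=True)
eval_node = StubNode(jnp.ones(3), is_train=False)
zero_then_forward_init(train_node, None)  # zeros during training
zero_then_forward_init(eval_node, None)   # forward init (x = u) during evaluation
```

With pcax, `train_only(random_init)` could then be passed as `init_fn` to `Model`, assuming `px.Node` exposes `is_train` as in the cell above.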

In [None]:
if __name__ == "__main__":
    model = Model(28 * 28, params["hidden_dim"], 10, params["num_layers"], jax.nn.tanh, zero_init)
    
    # Create optimizers
    with pxu.train(model, jnp.zeros((params["batch_size"], 28 * 28)), None):
        optim_x = pxu.Optim(
            optax.sgd(params["x_learning_rate"]),
            model.parameters().filter(px.f(px.NodeParam)(frozen=False))
        )
        optim_w = pxu.Optim(
            optax.adam(params["w_learning_rate"]),
            model.parameters().filter(px.f(px.LayerParam)),
        )

    # Create snapshots
    train_fn = train_batch.snapshot(T=params["T"], model=model, optim_x=optim_x, optim_w=optim_w)
    eval_fn = eval_batch.snapshot(model=model, optim_x=optim_w)

    # Train:
    for e in tqdm(range(params["num_epochs"])):
        epoch(train_dataloader, train_fn, params["p"])

    # Evaluation:
    accuracy = test(test_dataloader, eval_fn)
    print(f"Accuracy: {accuracy * 100:.2f}%")

100%|██████████| 16/16 [00:14<00:00,  1.07it/s]


Accuracy: 47.88%


As you can see, `zero_init` performs much worse than forward initialization (47.88% vs 83.88% accuracy; note that these runs also use `tanh` instead of `gelu`). This is because the model performs only `T=4` x steps with a low learning rate of `0.05`. These settings only work when the required correction of x (the gap between the initial and the target node values) is small. We can try to solve this problem by performing more x steps and using a larger learning rate.
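
A rough way to see why: consider a hypothetical single value node with the quadratic energy `0.5 * (x - target)**2`. Plain gradient descent with learning rate `lr` gives `x_{t+1} - target = (1 - lr) * (x_t - target)`, so after `T` steps the remaining gap is `(1 - lr)**T` times the initial one. This toy model ignores the coupling between layers, so treat the numbers as intuition only.

```python
import jax


def energy(x, target):
    """Toy quadratic energy of a single value node."""
    return 0.5 * (x - target) ** 2


def run_inference(x0: float, target: float, lr: float, T: int) -> float:
    """T steps of plain gradient descent on x, as an SGD x-optimizer would perform."""
    x = x0
    for _ in range(T):
        x = x - lr * jax.grad(energy)(x, target)
    return float(x)


def remaining_gap(lr: float, T: int) -> float:
    """Closed form: fraction of the initial gap |x0 - target| left after T steps."""
    return (1.0 - lr) ** T


# Zero initialization (x0 = 0) with a target value of 1
print(run_inference(0.0, 1.0, lr=0.05, T=4), remaining_gap(0.05, 4))  # about 81% of the gap remains
print(run_inference(0.0, 1.0, lr=0.5, T=32), remaining_gap(0.5, 32))  # about 2.3e-10 remains
```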

In [None]:
if __name__ == "__main__":
    model = Model(28 * 28, params["hidden_dim"], 10, params["num_layers"], jax.nn.tanh, zero_init)
    
    # Create optimizers
    with pxu.train(model, jnp.zeros((params["batch_size"], 28 * 28)), None):
        optim_x = pxu.Optim(
            optax.sgd(0.5),
            model.parameters().filter(px.f(px.NodeParam)(frozen=False))
        )
        optim_w = pxu.Optim(
            optax.adam(params["w_learning_rate"]),
            model.parameters().filter(px.f(px.LayerParam)),
        )

    # Create snapshots
    train_fn = train_batch.snapshot(T=32, model=model, optim_x=optim_x, optim_w=optim_w)
    eval_fn = eval_batch.snapshot(model=model, optim_x=optim_w)

    # Train:
    for e in tqdm(range(params["num_epochs"])):
        epoch(train_dataloader, train_fn, params["p"])

    # Evaluation:
    accuracy = test(test_dataloader, eval_fn)
    print(f"Accuracy: {accuracy * 100:.2f}%")

100%|██████████| 16/16 [00:23<00:00,  1.49s/it]


Accuracy: 79.09%


## Part 3: Custom initialization
Until now, we have relied on `pxu.train` and `pxu.eval` to initialize the node values automatically. However, the initialization can also be performed manually for more flexibility. Here, we initialize the hidden nodes (all except the first one, see `model.nodes[1:-1]`) with the per-class average of their values computed on the previous batch. Note that evaluation is still based on forward initialization.

In [12]:
@pxu.jit()
def train_batch_avg_init(x, y, nodes_avg_by_class, *, T, model, optim_w, optim_x):
    with pxu.train(model, x):
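        # Initialize nodes: each sample's value is set to the average of its class computed on
        # the previous batch (select_n picks row y of avg_by_class, i.e. avg_by_class[y])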
        if nodes_avg_by_class is not None:
            for node, avg_by_class in zip(model.nodes[1:-1], nodes_avg_by_class):
                node["x"] = jax.vmap(lambda y, avg: jax.lax.select_n(y, *jax.numpy.vsplit(avg, 10)).squeeze(), in_axes=(0, None), out_axes=0)(y, avg_by_class)

        # Convert y to one_hot
        y = jax.nn.one_hot(y, 10)

        # Set last node to y
        model.nodes[-1]["x"] = y

        for i in range(T):
            with pxu.step(model):
                g, (v,) = train_x(x, y, model=model)
                optim_x(g)

        with pxu.step(model):
            g, (v,) = train_w(x, y, model=model)
            optim_w(g)

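    # Per-class average of each node's values over the batch: for class c, sum the rows of
    # node["x"] whose one-hot target is c and divide by the number of such samples.
    # Note: a class absent from the batch gives 0/0 = NaN for that row.
    nodes_avg_by_class = [
        jnp.sum(jnp.reshape(jnp.tile(node["x"], (10, 1)), (10, x.shape[0], -1)) * jnp.expand_dims(y.T, axis=-1), axis=1)
        / jnp.expand_dims(jnp.sum(y, axis=0), axis=-1) for node in model.nodes[1:-1]
    ]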

    return nodes_avg_by_class


def epoch_avg_init(dl, train_fn, p):
    nodes_avg_by_class = None

    for i, batch in enumerate(dl):
        x, y = batch
        nodes_avg_by_class = train_fn(x, y, nodes_avg_by_class)

        if i / len(dl) > p:
            break
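
The two array expressions in `train_batch_avg_init` are compact but hard to read. The sketch below checks, on random data, that they match simpler formulations: the per-class average equals `y.T @ x` divided by the class counts, and the `select_n`/`vsplit` lookup is the same as indexing the averages with the integer labels. The shapes (a batch of 32 samples with 5 features) are arbitrary, and the simpler forms are offered as a readable alternative rather than something pcax requires.

```python
import jax
import jax.numpy as jnp

num_classes, batch, dim = 10, 32, 5
x = jax.random.normal(jax.random.PRNGKey(0), (batch, dim))  # node values for one batch
labels = jnp.arange(batch) % num_classes                    # every class is present
y = jax.nn.one_hot(labels, num_classes)

# Expression used in train_batch_avg_init
avg_tile = jnp.sum(
    jnp.reshape(jnp.tile(x, (num_classes, 1)), (num_classes, batch, -1)) * jnp.expand_dims(y.T, axis=-1),
    axis=1,
) / jnp.expand_dims(jnp.sum(y, axis=0), axis=-1)

# Simpler equivalent: sum the rows of each class with a matrix product, then divide by the counts
avg_matmul = (y.T @ x) / y.sum(axis=0)[:, None]

# Lookup used in train_batch_avg_init versus plain integer indexing
init_select = jax.vmap(
    lambda c, avg: jax.lax.select_n(c, *jnp.vsplit(avg, num_classes)).squeeze(),
    in_axes=(0, None),
)(labels, avg_matmul)
init_index = avg_matmul[labels]

print(jnp.allclose(avg_tile, avg_matmul, atol=1e-5), jnp.allclose(init_select, init_index))
```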

In [14]:

if __name__ == "__main__":
    model = Model(28 * 28, params["hidden_dim"], 10, params["num_layers"], jax.nn.gelu)
    
    # Create optimizers
    with pxu.train(model, jax.random.normal(pxc.RKG(), (params["batch_size"], 28 * 28)), None):
        optim_x = pxu.Optim(
            optax.sgd(params["x_learning_rate"]),
            model.parameters().filter(px.f(px.NodeParam)(frozen=False))
        )
        optim_w = pxu.Optim(
            optax.adam(params["w_learning_rate"]),
            model.parameters().filter(px.f(px.LayerParam)),
        )
    

    # Create snapshots
    train_fn_avg_init = train_batch_avg_init.snapshot(T=params["T"], model=model, optim_x=optim_x, optim_w=optim_w)
    eval_fn = eval_batch.snapshot(model=model, optim_x=optim_w)

    # Train:
    for e in tqdm(range(params["num_epochs"])):
        epoch_avg_init(train_dataloader, train_fn_avg_init, params["p"])

    # Evaluation:
    accuracy = test(test_dataloader, eval_fn)
    print(f"Accuracy: {accuracy * 100:.2f}%")

100%|██████████| 8/8 [00:09<00:00,  1.17s/it]


Accuracy: 84.08%
