# Tutorial #7: Zero-divergence Inference Learning (ZIL)

In [1]:
from typing import Callable

import jax
import optax
import numpy as np

import pcx as px
import pcx.predictive_coding as pxc
import pcx.nn as pxnn
import pcx.functional as pxf
import pcx.utils as pxu

In [2]:
class Model(pxc.EnergyModule):
    """Fully-connected predictive-coding network with a cross-entropy output.

    The network has `nm_layers` linear layers. The output of every layer except
    the last goes through `act_fn` and then through a `pxc.Vode`. The last
    layer's output goes through a vode with `pxc.ce_energy`. The state `h` of
    that last vode is frozen and, when a target is given to `__call__`, set to it.

    Args:
        input_dim: number of input features.
        hidden_dim: width of every hidden layer.
        output_dim: size of the last layer's output.
        nm_layers: total number of linear layers; must be at least 2.
        act_fn: activation applied after every layer except the last.

    Raises:
        ValueError: if `nm_layers < 2`. The first and last layers are always
            created separately, so smaller values would give 2 layers but only
            1 vode, and the forward pass would fail with a shape error.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        act_fn: Callable[[jax.Array], jax.Array],
    ) -> None:
        super().__init__()

        if nm_layers < 2:
            raise ValueError(f"nm_layers must be at least 2, got {nm_layers}")

        self.act_fn = px.static(act_fn)

        # input -> hidden, (nm_layers - 2) x hidden -> hidden, hidden -> output.
        self.layers = (
            [pxnn.Linear(input_dim, hidden_dim)]
            + [pxnn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers - 2)]
            + [pxnn.Linear(hidden_dim, output_dim)]
        )

        # One vode per layer; the last one uses a cross-entropy energy.
        self.vodes = [pxc.Vode() for _ in range(nm_layers - 1)] + [pxc.Vode(energy_fn=pxc.ce_energy)]
        # Freeze the output state: it holds the target, and the training code
        # excludes frozen vode states from the state optimiser.
        self.vodes[-1].h.frozen = True

    def __call__(self, x, y):
        """Run the forward pass; if `y` is not None, set the output vode's `h` to it.

        Returns the output vode's prediction `u`.
        """
        for vode, layer in zip(self.vodes[:-1], self.layers[:-1]):
            x = vode(self.act_fn(layer(x)))

        self.vodes[-1](self.layers[-1](x))

        if y is not None:
            self.vodes[-1].set("h", y)

        return self.vodes[-1].get("u")

In [3]:
@pxf.vmap(
    pxu.M(pxc.VodeParam | pxc.VodeParam.Cache).to((None, 0)), in_axes=(0, 0), out_axes=0
)
def forward(x, y, *, model: Model):
    """Per-sample forward pass, vmapped over axis 0 of `x` and `y`.

    Vode states and caches are mapped along axis 0; every other parameter is
    broadcast (axis `None`). Returns the model's output prediction.
    """
    prediction = model(x, y)
    return prediction


@pxf.vmap(
    pxu.M(pxc.VodeParam | pxc.VodeParam.Cache).to((None, 0)),
    in_axes=(0,),
    out_axes=(None, 0),
    axis_name="batch",
)
def energy(x, *, model: Model):
    """Return `(total_energy, prediction)` for a batch.

    `total_energy` is the model energy summed across the "batch" vmap axis
    (hence the unbatched `None` output axis). `prediction` is per sample. The
    target is not set here (`y=None`), so the output vode keeps its current state.
    """
    prediction = model(x, None)
    total_energy = jax.lax.psum(model.energy(), "batch")
    return total_energy, prediction

In [4]:
@pxf.jit()
def zil_train_on_batch(
    x: jax.Array,
    y: jax.Array,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim,
):
    """Train `model` on one batch with Zero-divergence Inference Learning.

    An INIT forward pass first initialises the vodes and clamps the output to
    `y`. Then, for each inference step `t` in `range(L)` with
    `L = len(model.vodes)`, only two things receive gradients:
      - the weights of layer `L - t - 1`;
      - the state of the hidden vode feeding it (`L - t - 2`), when one exists.

    Args:
        x: input batch; its leading axis is the vmapped batch axis.
        y: one-hot targets for the batch.
        model: the network to train (updated in place).
        optim_w: weight optimiser, stepped once per inference step.
        optim_h: state optimiser; re-initialised for this batch, cleared at the end.
    """
    model.train()

    # Init step: initialise the vodes with a forward pass and clamp the output to `y`.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)
    optim_h.init(pxu.M_hasnot(pxc.VodeParam, frozen=True)(model))

    # Inference steps
    L = len(model.vodes)
    for t in range(L):
        # As defined by the ZIL algorithm, step t updates the weights of layer L - t - 1.
        layer_idx = L - t - 1
        # Index of the hidden vode whose state feeds `layers[layer_idx]`. At the
        # last step (layer_idx == 0) the layer is fed directly by `x`, so there
        # is no hidden vode to update. Without this guard the index would be -1,
        # which wraps around to the frozen output (target) vode.
        vode_idx = layer_idx - 1

        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            # Start from an 'empty' mask, with all the parameters set to False.
            mask = pxu.M(None).to([False, False])(model, is_pytree=True)

            # Mask in vode: for ZIL to work, only the vode of the layer being
            # updated may be updated.
            if vode_idx >= 0:
                mask.vodes[vode_idx].h = True

            # Mask in layer: replace the layer's subtree with one where all of its
            # `LayerParam`s are True.
            mask.layers[layer_idx] = pxu.M(pxnn.LayerParam).to([False, True])(
                model.layers[layer_idx], is_pytree=True
            )

            _, g = pxf.value_and_grad({"model": mask}, has_aux=True)(energy)(
                x, model=model
            )

        # Only step the state optimiser when a hidden vode was selected.
        if vode_idx >= 0:
            optim_h.step(model, g["model"])
        optim_w.step(model, g["model"], scale_by=1.0 / x.shape[0])

    optim_h.clear()

In [5]:
def eval_on_batch(x: jax.Array, y: jax.Array, *, model: Model):
    """Return `(accuracy, predicted_labels)` of `model` on one batch.

    The model is put in eval mode and a single INIT-status forward pass is
    run without a target. The predicted label is the argmax of the output.
    """
    model.eval()

    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        outputs = forward(x, None, model=model)
        predicted = outputs.argmax(axis=-1)

    accuracy = (predicted == y).mean()
    return accuracy, predicted


def zil_train(dl, *, model: Model, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Run one ZIL training epoch: one `zil_train_on_batch` call per `(x, y)` in `dl`.

    Integer labels are one-hot encoded over 2 classes before training.
    """
    for batch_x, batch_labels in dl:
        targets = jax.nn.one_hot(batch_labels, 2)
        zil_train_on_batch(
            batch_x, targets, model=model, optim_w=optim_w, optim_h=optim_h
        )


def eval(dl, *, model: Model):
    """Evaluate `model` on every `(x, y)` batch in `dl`.

    NOTE: shadows the builtin `eval`; the name is kept because callers use it.

    Returns:
        The mean of the per-batch accuracies, and all predicted labels
        concatenated in batch order.
    """
    batch_accuracies = []
    batch_predictions = []

    for batch_x, batch_y in dl:
        accuracy, predicted = eval_on_batch(batch_x, batch_y, model=model)
        batch_accuracies.append(accuracy)
        batch_predictions.append(predicted)

    return np.mean(batch_accuracies), np.concatenate(batch_predictions)

In [6]:
from sklearn.datasets import make_moons

# Dataset generation (independent of pcx): a noisy two-moons problem whose
# size is a whole number of batches.
nm_elements = 1024
batch_size = nm_elements // 4
nm_epochs = 8
X, y = make_moons(
    n_samples=(nm_elements // batch_size) * batch_size,
    noise=0.2,
    random_state=42,
)

In [7]:
# Split the training set into batches of `batch_size` samples. Then generate a
# test set of the same size with a different random seed and batch it the same way.
train_dl = list(
    zip(
        X.reshape(-1, batch_size, 2),
        y.reshape(-1, batch_size),
    )
)

X_test, y_test = make_moons(
    n_samples=batch_size * (nm_elements // batch_size),
    noise=0.2,
    random_state=0,
)
test_dl = tuple(
    zip(
        X_test.reshape(-1, batch_size, 2),
        y_test.reshape(-1, batch_size),
    )
)

In [8]:
# Number of repeated training runs whose accuracies are averaged (mean ± std).
n_seeds = 8

In [None]:
px.RKG.seed(0)


def zil_test():
    """Train a fresh 3-layer model with ZIL and return its test accuracies.

    Returns:
        `(max_acc, max_acc_at_half_epochs)`:
            - the best test accuracy over all `nm_epochs` epochs;
            - the best test accuracy reached within the first
              `nm_epochs // 2` epochs.
    """
    max_acc = 0
    max_acc_at_half_epochs = 0
    model = Model(
        input_dim=2, hidden_dim=32, output_dim=2, nm_layers=3, act_fn=jax.nn.leaky_relu
    )

    # `pxu.OptimTree` takes three arguments:
    #   - an optax optimiser factory;
    #   - a predicate selecting the sub-modules it applies to (here every `pxnn.Layer`);
    #   - optionally, the parameters to initialise on (here the `LayerParam`s).
    optim_w = pxu.OptimTree(
        lambda: optax.sgd(0.2, momentum=0.9, nesterov=True),
        lambda x: isinstance(x, pxnn.Layer),
        pxu.M(pxnn.LayerParam)(model),
    )

    # We only create the state optimizer `optim_h` without initialising it, since its state
    # is batch-dependent and we want to re-initialise it for each new batch.
    optim_h = pxu.OptimTree(
        lambda: optax.sgd(1.0), lambda x: isinstance(x, pxc.Vode)
    )

    for epoch in range(nm_epochs):
        zil_train(train_dl, model=model, optim_w=optim_w, optim_h=optim_h)
        a, _ = eval(test_dl, model=model)

        if a > max_acc:
            max_acc = a

        # `epoch` is 0-based and evaluation runs after training, so `epoch + 1`
        # epochs have been completed at this point.
        if epoch + 1 == nm_epochs // 2:
            max_acc_at_half_epochs = max_acc

    return max_acc, max_acc_at_half_epochs


zil_accs = [zil_test() for _ in range(n_seeds)]

avg_a = np.mean(zil_accs, axis=0)
std_a = np.std(zil_accs, axis=0)
print(f"ZIL: {avg_a[0]:.2%} ± {std_a[0]:.2%}")
print(f"ZIL[{nm_epochs // 2} epochs]: {avg_a[1]:.2%} ± {std_a[1]:.2%}")

## Hybrid training
Here we combine ZIL with PC by first performing a weight update for each weight when $t=l$ (as in ZIL) and then, after $T$ steps, performing a complete weight update, as we would do in PC.

Note that ZIL and PC require significantly different learning rates, so that aspect could still be refined by creating a custom optimiser. However, even as it is, hybrid training significantly outperforms ZIL (and thus BP).

In [10]:
@pxf.jit(static_argnums=0)
def hybrid_train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim,
):
    """Train `model` on one batch with the hybrid ZIL + PC scheme.

    For each of the `T` inference steps, the selected vode states are updated.
    While `t <= L - 1` (with `L = len(model.vodes)`), the weights of layer
    `L - t - 1` are also updated, as in ZIL. After the inference steps, one
    final weight update over all `pxnn.LayerParam`s is performed, as in PC.

    Args:
        T: number of inference steps; a static argument of `jit` (`static_argnums=0`).
        x: input batch; its leading axis is the vmapped batch axis.
        y: one-hot targets for the batch.
        model: the network to train (updated in place).
        optim_w: weight optimiser.
        optim_h: state optimiser; re-initialised for this batch and cleared
            before the final weight update.
    """
    model.train()

    # Init step: initialise the vodes with a forward pass and clamp the output to `y`.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)
    optim_h.init(pxu.M_hasnot(pxc.VodeParam, frozen=True)(model))

    # Inference steps
    L = len(model.vodes)
    for t in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            # Start from a mask that (presumably) enables every non-frozen vode state.
            mask = pxu.M_hasnot(pxc.VodeParam, frozen=True).to([False, True])(
                model, is_pytree=True
            )

            # Mask out vodes for better efficiency: only the vodes from index
            # L - t - 2 upwards are updated at step t.
            for j in range(0, L - t - 2):
                mask.vodes[j].h = False

            # Mask in layer: we update the weights only of layer l, such that l = L - t - 1,
            # as defined by the ZIL algorithm. If t > L - 1, we do not update any layer as
            # we have already done the "backpropagation" pass, and we simply keep updating
            # the state to then perform a "predictive coding" update.
            if L - t - 1 >= 0:
                mask.layers[L - t - 1] = pxu.M(pxnn.LayerParam).to([False, True])(
                    model.layers[L - t - 1], is_pytree=True
                )

            (e, y_), g = pxf.value_and_grad({"model": mask}, has_aux=True)(energy)(
                x, model=model
            )

        optim_h.step(model, g["model"])
        # Weights only receive gradients during the ZIL phase (t <= L - 1).
        if L - t - 1 >= 0:
            optim_w.step(model, g["model"], scale_by=1.0 / x.shape[0])

    optim_h.clear()

    # Weight update step: PC-style update of all layer weights at the final state.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        (e, y_), g = pxf.value_and_grad(
            pxu.M(pxnn.LayerParam).to([False, True]), has_aux=True
        )(energy)(x, model=model)

    optim_w.step(model, g["model"], scale_by=1.0 / x.shape[0])


def hybrid_train(dl, T, *, model: Model, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Run one hybrid training epoch: one `hybrid_train_on_batch` call per `(x, y)` in `dl`.

    `T` is the number of inference steps per batch. Integer labels are one-hot
    encoded over 2 classes before training.
    """
    for batch_x, batch_labels in dl:
        targets = jax.nn.one_hot(batch_labels, 2)
        hybrid_train_on_batch(
            T, batch_x, targets, model=model, optim_w=optim_w, optim_h=optim_h
        )

In [None]:
px.RKG.seed(0)

def zil_test(T):
    """Train a fresh 3-layer model with hybrid ZIL + PC and return its test accuracies.

    Note: this shadows the earlier, ZIL-only `zil_test()`. The name is kept
    unchanged for compatibility.

    Args:
        T: number of inference steps per batch, forwarded to `hybrid_train`.

    Returns:
        `(max_acc, max_acc_at_half_epochs)`:
            - the best test accuracy over all `nm_epochs` epochs;
            - the best test accuracy reached within the first
              `nm_epochs // 2` epochs.
    """
    max_acc = 0
    max_acc_at_half_epochs = 0
    model = Model(
        input_dim=2, hidden_dim=32, output_dim=2, nm_layers=3, act_fn=jax.nn.leaky_relu
    )

    optim_w = pxu.OptimTree(
        lambda: optax.sgd(0.2, momentum=0.9, nesterov=True),
        lambda x: isinstance(x, pxnn.Layer),
        pxu.M(pxnn.LayerParam)(model),
    )

    # State optimiser: SGD with `optax.linear_schedule(1.0, 0.1, 1)`, i.e. a learning
    # rate of 1.0 for the optimiser's first update and 0.1 afterwards.
    # `hybrid_train_on_batch` calls `optim_h.init` on every batch, so the schedule
    # presumably restarts per batch — TODO confirm against `pxu.OptimTree`.
    optim_h = pxu.OptimTree(
        lambda: optax.sgd(
            optax.linear_schedule(1.0, 0.1, 1)
        ), lambda x: isinstance(x, pxc.Vode)
    )

    for epoch in range(nm_epochs):
        hybrid_train(train_dl, T, model=model, optim_w=optim_w, optim_h=optim_h)
        a, _ = eval(test_dl, model=model)

        if a > max_acc:
            max_acc = a

        # `epoch` is 0-based and evaluation runs after training, so `epoch + 1`
        # epochs have been completed at this point.
        if epoch + 1 == nm_epochs // 2:
            max_acc_at_half_epochs = max_acc

    return max_acc, max_acc_at_half_epochs


T = 8
zil_accs = [zil_test(T) for _ in range(n_seeds)]

avg_a = np.mean(zil_accs, axis=0)
std_a = np.std(zil_accs, axis=0)
print(f"HIL: {avg_a[0]:.2%} ± {std_a[0]:.2%}")
print(f"HIL[{nm_epochs // 2} epochs]: {avg_a[1]:.2%} ± {std_a[1]:.2%}")