# Train with scan and momentum but without resetting the momentum at each inference

In [None]:
!pip install tqdm

In [5]:
from typing import Callable

import jax
from optax._src import base 
from optax._src import combine
from optax._src import transform
import optax
from typing import Any, Callable, Optional

import pcx as px
import pcx.predictive_coding as pxc
import pcx.nn as pxnn
import pcx.functional as pxf
import pcx.utils as pxu

from tqdm import tqdm
import numpy as np
from scipy.stats import wasserstein_distance
import matplotlib.pyplot as plt
import numpy as np

# Seed pcx's global random key generator (px.RKG) with 0.
# NOTE(review): only px.RKG is seeded here; the NumPy generator used to draw the
# datasets and the `random` module used to shuffle batches are not seeded in
# this notebook, so runs are presumably not fully reproducible -- confirm.
px.RKG.seed(0)

In [6]:
class Model(pxc.EnergyModule):
    """Chain of linear layers, each feeding a predictive-coding vode.

    The first layer maps ``input_dim -> hidden_dim`` (default bias setting),
    followed by ``nm_layers - 2`` bias-free ``hidden_dim -> hidden_dim``
    layers and a final bias-free ``hidden_dim -> output_dim`` layer.
    ``nm_layers`` vodes are created and the ``h`` parameter of the last one is
    marked frozen at construction time.

    NOTE(review): layers and vodes only pair up one-to-one for
    ``nm_layers >= 2``; smaller values still create two layers.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        nm_layers: int,
        act_fn: Callable[[jax.Array], jax.Array]
    ) -> None:
        super().__init__()

        # Stored through px.static, exactly as handed in by the caller.
        self.act_fn = px.static(act_fn)

        # Layers are instantiated in network order: first, hidden, last.
        stack = [pxnn.Linear(input_dim, hidden_dim)]
        stack.extend(
            pxnn.Linear(hidden_dim, hidden_dim, bias=False)
            for _ in range(nm_layers - 2)
        )
        stack.append(pxnn.Linear(hidden_dim, output_dim, bias=False))
        self.layers = stack

        self.vodes = [pxc.Vode() for _ in range(nm_layers)]

        # The output vode starts out frozen; callers toggle this flag later.
        self.vodes[-1].h.frozen = True

    def __call__(self, x, y):
        """Run the network on ``x`` and return the last vode's ``"u"`` value.

        Every layer but the last is followed by its vode and ``act_fn``; the
        last layer's output goes through the last vode with no activation.
        When ``y`` is not ``None`` it is written to the last vode's ``"h"``.
        """
        hidden_pairs = zip(self.vodes[:-1], self.layers[:-1])
        for vode, layer in hidden_pairs:
            pre_activation = vode(layer(x))
            x = self.act_fn(pre_activation)

        self.vodes[-1](self.layers[-1](x))

        if y is not None:
            self.vodes[-1].set("h", y)

        return self.vodes[-1].get("u")

In [7]:
@pxf.vmap(
    pxu.M(pxc.VodeParam | pxc.VodeParam.Cache).to((None, 0)),
    in_axes=(0, 0),
    out_axes=0,
)
def forward(x, y, *, model: Model):
    """Call ``model(x, y)`` under ``pxf.vmap``.

    Both ``x`` and ``y`` are mapped over axis 0 and the result is stacked on
    axis 0. The mask maps ``VodeParam`` / ``VodeParam.Cache`` entries to axis
    0 and everything else to ``None`` (presumably: per-sample vode state,
    shared weights -- confirm against pcx).
    """
    prediction = model(x, y)
    return prediction


@pxf.vmap(
    pxu.M(pxc.VodeParam | pxc.VodeParam.Cache).to((None, 0)),
    in_axes=(0,),
    out_axes=(None, 0),
    axis_name="batch",
)
def energy(x, *, model: Model):
    """Return ``(summed energy, model output)`` for a batch.

    The model is called with no target (``y=None``); ``model.energy()`` is
    then summed over the ``"batch"`` vmap axis with ``jax.lax.psum``, so the
    first output is unbatched (``out_axes`` entry ``None``) while the model
    output stays batched on axis 0.
    """
    prediction = model(x, None)
    batch_energy = jax.lax.psum(model.energy(), "batch")
    return batch_energy, prediction

In [8]:
@pxf.jit(static_argnums=0)
def train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim
):
    """Run ``T`` inference steps on one batch, then a single weight update.

    Sequence: switch the model to train mode, run an INIT-status forward pass
    with ``x`` and ``y``, initialise ``optim_h`` on the non-frozen
    ``VodeParam``s, scan ``T`` state-update steps, clear ``optim_h``, and
    finally step ``optim_w`` on the ``LayerParam`` gradients of the energy,
    scaled by ``1 / x.shape[0]``.

    ``T`` is a static jit argument (``static_argnums=0``). Returns ``None``.

    NOTE(review): ``optim_h.init`` / ``optim_h.clear`` run on every batch,
    while the notebook title says momentum is not reset between inferences --
    confirm this is the intended behaviour.
    """

    def inference_step(i, x, *, model, optim_h):
        # One update of the non-frozen vode parameters; the energy value and
        # the model output returned alongside the gradients are not used.
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            grad_fn = pxf.value_and_grad(
                pxu.M_hasnot(pxc.VodeParam, frozen=True).to([False, True]),
                has_aux=True
            )(energy)
            _, state_grads = grad_fn(x, model=model)
        optim_h.step(model, state_grads["model"])
        return x, None

    model.train()

    # Initialisation pass: forward with the target so it is set on the last vode.
    with pxu.step(model, (pxc.STATUS.INIT, None), clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)
    optim_h.init(pxu.M_hasnot(pxc.VodeParam, frozen=True)(model))

    # Inference phase: T scanned steps over the vode states.
    step_indices = jax.numpy.arange(T)
    pxf.scan(inference_step, xs=step_indices)(x, model=model, optim_h=optim_h)

    optim_h.clear()

    # Learning phase: gradient of the energy with respect to the layer parameters.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        weight_grad_fn = pxf.value_and_grad(
            pxu.M(pxnn.LayerParam).to([False, True]), has_aux=True
        )(energy)
        _, weight_grads = weight_grad_fn(x, model=model)
    optim_w.step(model, weight_grads["model"], scale_by=1.0/x.shape[0])


def train(dl, T, *, model: Model, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """Train ``model`` for one pass over ``dl``.

    Marks the last vode's ``h`` as frozen, then calls ``train_on_batch`` with
    ``T`` inference steps on every ``(x, y)`` pair of ``dl``, showing a tqdm
    progress bar. Returns ``None``.
    """
    model.vodes[-1].h.frozen = True
    batches = tqdm(dl)
    for batch_x, batch_y in batches:
        train_on_batch(
            T,
            batch_x,
            batch_y,
            model=model,
            optim_w=optim_w,
            optim_h=optim_h,
        )

In [9]:
@pxf.jit(static_argnums=0)
def eval_on_batch(
    T: int,
    x: jax.Array, 
    *, 
    model: Model,
    optim_h: pxu.Optim
    ):
    """Run ``T`` inference steps on ``x`` without a target and without learning.

    The model is put in train mode, initialised with an INIT-status forward
    pass using ``y=None``, ``optim_h`` is initialised on all ``VodeParam``s,
    ``T`` scanned state updates are applied to the non-frozen ``VodeParam``s,
    and ``optim_h`` is cleared. Nothing is returned; the caller reads the
    result from the model's vodes.

    ``T`` is a static jit argument (``static_argnums=0``).
    """

    def inference_step(i, x, *, model, optim_h):
        # One update of the non-frozen vode parameters; energy value and
        # model output are discarded.
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            grad_fn = pxf.value_and_grad(
                pxu.M_hasnot(pxc.VodeParam, frozen=True).to([False, True]),
                has_aux=True
            )(energy)
            _, state_grads = grad_fn(x, model=model)
        optim_h.step(model, state_grads["model"])
        return x, None

    model.train()

    # Warn only; execution continues even when the output vode is frozen.
    # NOTE(review): inside a jitted function this Python-level check
    # presumably runs at trace time only -- confirm.
    output_is_frozen = model.vodes[-1].h.frozen
    if output_is_frozen:
        print("vode[-1] should not be frozen! set frozen=False before calling eval function.")

    # Initialisation pass with no target.
    with pxu.step(model, (pxc.STATUS.INIT, None), clear_params=pxc.VodeParam.Cache):
        forward(x, None, model=model)
    optim_h.init(pxu.M(pxc.VodeParam)(model))

    # Inference phase; the scan's carry and outputs are not used afterwards.
    step_indices = jax.numpy.arange(T)
    _carry, _outputs = pxf.scan(inference_step, xs=step_indices)(
        x, model=model, optim_h=optim_h
    )

    optim_h.clear()


# MCPC evaluation loop for 1D data
def eval(dl, T, *, model: Model, optim_h: pxu.Optim):
    """Sample from the model on every batch of ``dl`` and score the samples.

    The last vode's ``h`` is unfrozen, ``eval_on_batch`` is run with ``T``
    inference steps per batch, and the last vode's ``"h"`` is collected as the
    generated sample for that batch.

    Returns a tuple ``(distance, samples)``: the Wasserstein distance between
    the squeezed concatenated targets and the squeezed concatenated samples,
    and the concatenated samples themselves.

    Note: this function shadows the builtin ``eval`` within the notebook.
    """
    model.vodes[-1].h.frozen = False

    targets = []
    samples = []
    for batch_x, batch_y in dl:
        eval_on_batch(T, batch_x, model=model, optim_h=optim_h)
        targets.append(batch_y)
        samples.append(model.vodes[-1].get("h"))

    all_targets = np.concatenate(targets, axis=0)
    all_samples = np.concatenate(samples, axis=0)

    distance = wasserstein_distance(all_targets.squeeze(), all_samples.squeeze())
    return distance, all_samples

In [10]:
# Number of samples per batch; also used below to shape the datasets.
batch_size = 32

# Smallest configuration: scalar input, hidden and output, two layers,
# and the identity as activation function.
model = Model(
    input_dim=1,
    hidden_dim=1,
    output_dim=1,
    nm_layers=2,
    act_fn=lambda value: value,
)

In [11]:
## define noisy sgd optimiser for MCPC
def sgdld(
    learning_rate: float,
    momentum: Optional[float] = None,
    h_var: float = 1.0,
    gamma: float = 0.,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
    seed: Callable[[], Any] = lambda: px.RKG(1)[0],
) -> Callable[[], base.GradientTransformation]:
    """Build a factory for a noisy SGD optimiser (optionally with momentum).

    The returned zero-argument function creates an optax chain of, in order:
    ``transform.add_noise(eta, gamma, seed())``, an optional
    ``transform.trace`` momentum accumulator, and
    ``transform.scale_by_learning_rate(learning_rate)``.

    ``eta`` is ``2 * h_var / learning_rate``, multiplied by ``(1 - momentum)``
    when momentum is used (presumably so that the injected noise keeps the
    scale required for Langevin-style sampling once it has been scaled by the
    learning rate -- confirm against the optax ``add_noise`` definition).

    Args:
        learning_rate: Scalar step size. Schedules are rejected because the
            value is used as a divisor when computing ``eta``.
        momentum: Decay passed to ``transform.trace``; ``None`` disables the
            momentum stage.
        h_var: Factor entering ``eta`` as described above.
        gamma: Second argument forwarded to ``transform.add_noise``.
        nesterov: Forwarded to ``transform.trace``.
        accumulator_dtype: Forwarded to ``transform.trace``.
        seed: Zero-argument callable evaluated each time the factory is
            called; its result is passed as the third argument of
            ``transform.add_noise``. The default draws from ``px.RKG``.

    Returns:
        A zero-argument callable producing the chained gradient
        transformation.

    Raises:
        TypeError: If ``learning_rate`` is callable (i.e. a schedule).
    """
    # Fail early with a readable message: a schedule would otherwise only
    # blow up later, inside the factory, on the division below.
    if callable(learning_rate):
        raise TypeError(
            "sgdld requires a scalar learning_rate; schedules are not supported "
            "because the noise scale is computed as 2 * h_var / learning_rate."
        )

    def optim_fn():
        eta = 2*h_var*(1-momentum)/learning_rate if momentum is not None else 2*h_var/learning_rate
        # A new seed/key is requested on every construction of the optimiser.
        s = seed()
        return combine.chain(
            transform.add_noise(eta, gamma, s),
            (transform.trace(decay=momentum, nesterov=nesterov,
                            accumulator_dtype=accumulator_dtype)
            if momentum is not None else base.identity()),
            transform.scale_by_learning_rate(learning_rate)
        )
    return optim_fn

In [12]:
# Optimiser factory and hyper-parameters for the vode-state updates.
h_optimiser_fn = sgdld
lr = 1e-1
momentum = 0.5
h_var = 1.0
gamma = 0.
# Learning rate for the weight optimiser.
lr_p = 1e-3

# Mean and variance of the Gaussian the targets are drawn from.
mean = 1
var = 5

# Training set: all-zero inputs and Gaussian targets. The sample count is
# rounded down to a whole number of batches.
nm_elements = 10240
X = np.zeros((batch_size * (nm_elements // batch_size), 1))
y = np.random.randn(batch_size * (nm_elements // batch_size)).reshape(-1,1) * np.sqrt(var) + mean

# Test set, built the same way. The target count uses nm_elements_test so
# that y_test has exactly as many rows as X_test (it previously used
# nm_elements, producing ten times more targets than inputs).
nm_elements_test = 1024
X_test = np.zeros((batch_size * (nm_elements_test // batch_size), 1))
y_test = np.random.randn(batch_size * (nm_elements_test // batch_size)).reshape(-1,1) * np.sqrt(var) + mean


# we split the dataset in training batches and do the same for the generated test set.
# train_dl is a list because it is shuffled in place between epochs.
train_dl = list(zip(X.reshape(-1, batch_size, 1), y.reshape(-1, batch_size, 1)))
test_dl = tuple(zip(X_test.reshape(-1, batch_size, 1), y_test.reshape(-1, batch_size, 1)))

# Number of epochs chosen so that roughly 5120 batches are processed in total.
nm_epochs = 5120 // (nm_elements // batch_size)

In [None]:
# Overlay the normalised histograms of the training and test targets.
for _targets in (y, y_test):
    plt.hist(_targets, alpha=0.5, density=True)
plt.show()

In [None]:
# Driver cell: builds the two optimisers, evaluates the untrained model, then
# alternates training epochs with periodic evaluation and prints a summary.
import random

# One INIT-status forward pass on an all-zero dummy batch before the optimisers
# are created.
# NOTE(review): the status is passed here as `pxc.STATUS.INIT`, whereas
# train_on_batch / eval_on_batch pass `(pxc.STATUS.INIT, None)` -- confirm both
# forms are equivalent in pcx.
with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(jax.numpy.zeros((batch_size, 1)), None, model=model)
    # The output vode is marked frozen while the optimisers are constructed...
    model.vodes[-1].h.frozen = True
    optim_h = pxu.Optim(h_optimiser_fn(lr, momentum, h_var, gamma))
    optim_w = pxu.Optim(lambda: optax.adam(lr_p), pxu.M(pxnn.LayerParam)(model))
    # ...and unfrozen again afterwards. train() and eval() each set the flag
    # they need before running, so this value only matters until the first call.
    model.vodes[-1].h.frozen = False


# Number of inference steps per batch during training and during evaluation.
T = 100
T_eval = 100
# Baseline score of the untrained model.
w, y_ = eval(test_dl, T = T_eval, model=model, optim_h=optim_h)
print(f"Epoch {0}/{nm_epochs} - Wasserstein distance: {w :.2f}")
for e in range(nm_epochs):
    # Reorder the training batches in place (samples within a batch stay together).
    random.shuffle(train_dl)
    train(train_dl, T=T, model=model, optim_w=optim_w, optim_h=optim_h)
    # Evaluate every fifth epoch and after the final one.
    if e %5 == 4 or e == nm_epochs - 1:
        w, y_ = eval(test_dl, T = T_eval, model=model, optim_h=optim_h)
        print(f"Epoch {e + 1}/{nm_epochs} - Wasserstein distance: {w :.2f}")

# y_ holds the samples from the most recent evaluation.
print(f"Learned data distribution has mean {y_.mean():.2f} and var {y_.var():.2f} ")
# The weight is read from the last layer and the bias from the first layer
# (the only layer created without bias=False in Model).
print(f"Learned parameters weight {model.layers[-1].nn.weight.get()[0,0] :.2f} and bias {model.layers[0].nn.bias.get()[0] :.2f}")



In [None]:
# Compare the training targets with the samples from the last evaluation,
# using identical histogram settings for both.
_hist_style = dict(density=True, alpha=0.5, bins=30)
plt.hist(y, label="data", **_hist_style)
plt.hist(y_, label="learned", **_hist_style)
plt.xlabel("y")
plt.ylabel("pdf")
plt.legend()
plt.tight_layout()
plt.show()