In [None]:
from typing import Callable

# Core dependencies
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# pcax
import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.functional as pxf
import pcax.utils as pxu

import torchvision
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Model definition
class TwoLayerNN(pxc.EnergyModule):
    """Two-layer predictive-coding classifier: Linear -> act_fn -> Vode -> Linear -> Vode.

    Each Linear layer is followed by a Vode. The hidden vode receives the activated
    output of the first layer. The output vode receives the raw (un-activated) output
    of the last layer.

    The output vode uses ``pxc.ce_energy`` (presumably a cross-entropy energy; confirm
    in pcax). Its ``h`` is marked frozen, so gradient masks that exclude frozen
    params leave it untouched. When a target is given, ``__call__`` clamps it onto
    that ``h``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, act_fn: Callable[[jax.Array], jax.Array]) -> None:
        """Build the two Linear layers (Glorot-uniform weights) and their vodes.

        Args:
            input_dim: length of each input vector.
            hidden_dim: width of the hidden layer and of the hidden vode.
            output_dim: width of the output layer and of the output vode.
            act_fn: activation applied to the first layer's output only.
        """
        super().__init__()
        # Wrapped with px.static so the function is stored as a static attribute
        # (presumably kept out of the parameter tree; confirm in pcax).
        self.act_fn = px.static(act_fn)
        self.layers = [
            pxnn.Linear(input_dim, hidden_dim),
            pxnn.Linear(hidden_dim, output_dim)
        ]

        # create a glorot uniform initializer
        initializer = jax.nn.initializers.glorot_uniform()
        # now apply glorot uniform initialization to the weights only (biases keep
        # their default initialization)
        # the basic syntax is: model.layers[i].nn.weight.set(initializer(key, model.layers[i].nn.weight.shape))
        for l in self.layers:
            l.nn.weight.set(initializer(px.RKG(), l.nn.weight.shape))

        self.vodes = [
            pxc.Vode((hidden_dim,)),
            pxc.Vode((output_dim,), pxc.ce_energy)
        ]
        # Freeze the output vode's h so inference steps (which mask out frozen
        # params) never move it; it only changes when a target is clamped.
        self.vodes[-1].h.frozen = True

    def __call__(self, x, y):
        """Forward pass for a single sample (callers in this notebook vmap it).

        Args:
            x: one input vector.
            y: target for the output vode, or None. When given, it is written into
                the output vode's ``h`` (clamping the output). When None, ``h`` is
                left as it is.

        Returns:
            The output vode's ``"u"`` value (presumably the layer's prediction,
            independent of the clamped ``h``; confirm against pcax's Vode).
        """
        # Hidden layers: linear -> activation -> vode; the vode's output feeds the next layer.
        for v, l in zip(self.vodes[:-1], self.layers[:-1]):
            x = v(self.act_fn(l(x)))
        # Output layer: no activation, the raw layer output goes to the output vode.
        x = self.vodes[-1](self.layers[-1](x))
        if y is not None:
            self.vodes[-1].set("h", y)
        return self.vodes[-1].get("u")


In [None]:
# Collate function that turns a list of dataset samples into NumPy batches, letting the
# PyTorch DataLoader feed JAX without ever producing torch tensors. (A PyTorch-free
# dataloader may replace this in the future.)
def numpy_collate(batch):
    """Recursively collate a list of samples into NumPy arrays.

    * Samples that are ``np.ndarray`` are stacked along a new leading axis.
    * Samples that are tuples/lists (e.g. ``(x, y)`` pairs) are transposed field-wise.
      Each field is collated recursively, giving a list with one batch per field.
    * Anything else (e.g. Python int labels) is converted with ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    return list(map(numpy_collate, zip(*batch)))

# DataLoader preconfigured to feed JAX: batches are collated into NumPy arrays by
# `numpy_collate` instead of torch tensors. The defaults (16 workers, pin_memory=True,
# persistent workers, prefetch_factor=2) target a multi-core host feeding an
# accelerator. The batch size should stay constant during training (presumably so
# the jitted training step is not re-traced for a new shape). For that reason the
# last incomplete batch is dropped whenever batching is driven by `batch_size`.
class TorchDataloader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that yields NumPy batches for JAX.

    Takes a subset of the usual DataLoader arguments. It differs from a plain
    DataLoader in three ways:

    * ``collate_fn`` is always ``numpy_collate``.
    * ``drop_last`` is True when ``batch_sampler`` is None (constant batch size). It
      is False otherwise, because a custom batch sampler controls batching itself.
    * With ``num_workers == 0`` the multiprocessing-only options are neutralized:
      ``persistent_workers`` is forced to False and ``prefetch_factor`` is not
      forwarded. DataLoader rejects both when there are no worker processes.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        batch_sampler=None,
        num_workers=16,
        pin_memory=True,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=True,
        prefetch_factor=2,
    ):
        uses_workers = num_workers > 0
        # persistent_workers / prefetch_factor are only valid with worker processes.
        # Omitting prefetch_factor (rather than passing None) works on torch 1.x
        # (default 2) and on torch 2.x (default None; other values raise when
        # num_workers == 0).
        worker_options = {"persistent_workers": persistent_workers and uses_workers}
        if uses_workers:
            worker_options["prefetch_factor"] = prefetch_factor
        # Zero-argument super(): super(self.__class__, self) recurses forever as soon
        # as this class is subclassed.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            # DataLoader forbids drop_last=True together with a batch_sampler.
            drop_last=batch_sampler is None,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            **worker_options,
        )

# This function returns the MNIST dataloaders for training and testing.
def get_dataloaders(batch_size: int, root: str = "~/tmp/mnist/", num_workers: int = 16):
    """Return ``(train_dataloader, test_dataloader)`` for MNIST.

    Preprocessing, in order:

    1. ``ToTensor``.
    2. ``Normalize((0.1307,), (0.3081,))`` (the usual MNIST mean/std).
    3. Flattening into a 1-D NumPy array, which is the input format the JAX model
       consumes.

    The datasets are downloaded into ``root`` if missing.

    Both loaders are ``TorchDataloader`` instances driven by ``batch_size``. The final
    incomplete batch of each split is therefore dropped, so up to ``batch_size - 1``
    test images are never evaluated.

    Args:
        batch_size: samples per batch for both loaders.
        root: dataset directory (default: the previously hard-coded ``"~/tmp/mnist/"``).
        num_workers: worker processes per loader (default: the previously hard-coded 16).
    """
    # NOTE(review): lambdas (and functions defined in a notebook) are not picklable
    # under the 'spawn' multiprocessing start method (default on macOS/Windows), so
    # multi-worker loading may fail there. Verify on those platforms.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        lambda x: x.view(-1).numpy(),  # Flatten the image to a vector
    ])

    def make_loader(train: bool, shuffle: bool):
        # One split: build the dataset, then wrap it in a NumPy-collating loader.
        dataset = torchvision.datasets.MNIST(
            root,
            transform=transform,
            download=True,
            train=train,
        )
        return TorchDataloader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_dataloader = make_loader(train=True, shuffle=True)
    test_dataloader = make_loader(train=False, shuffle=False)
    return train_dataloader, test_dataloader

In [None]:
# Training and evaluation functions
# Batched forward pass. pxf.vmap maps the per-sample model call over the leading axis
# of x and y (in_axes=(0, 0), out_axes=0). The Mask over VodeParam / VodeParam.Cache
# with (None, 0) presumably makes vode state per-sample while the other model state
# is shared; confirm these axis semantics in pcax.
@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0, 0), out_axes=0)
def forward(x, y, *, model: TwoLayerNN):
    """Run ``model(x, y)`` on every sample of the batch.

    Args:
        x: batch of inputs (leading axis is the batch).
        y: batch of targets to clamp on the output vode, or None for an unclamped
            pass (as in evaluation and model initialization).
        model: the predictive-coding model.

    Returns:
        The batched output of ``model(x, y)`` (its output vode's ``"u"``).
    """
    return model(x, y)

@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=(0,), out_axes=(None, 0), axis_name="batch")
def energy(x, *, model: TwoLayerNN):
    """Batch-averaged model energy, plus the batched predictions.

    Runs the model per sample with ``y=None``, so any target already clamped on the
    output vode is left in place. Each sample's energy is summed, then averaged
    across the vmapped ``"batch"`` axis with ``jax.lax.pmean``. The result is
    identical for every sample, which is why it is returned unbatched
    (``out_axes=None``) next to the batched predictions.

    Returns:
        ``(mean_energy, predictions)``. The first element is the differentiated
        value; the second is the auxiliary output (``has_aux=True`` at the call sites).
    """
    prediction = model(x, None)
    sample_energy = model.energy().sum()
    batch_mean_energy = jax.lax.pmean(sample_energy, "batch")
    return batch_mean_energy, prediction

@pxf.jit(static_argnums=0)
def train_on_batch(T: int, x: jax.Array, y: jax.Array, *, model: TwoLayerNN, optim_w: pxu.Optim, optim_h: pxu.Optim):
    """One predictive-coding training step on a single batch.

    1. INIT: a forward pass with ``y`` clamped on the output vode.
    2. Inference: ``T`` gradient steps of the energy w.r.t. the non-frozen vode
       parameters, each applied with ``optim_h``.
    3. Learning: one gradient step of the energy w.r.t. the layer parameters
       (``pxnn.LayerParam``), applied with ``optim_w``.

    Args:
        T: number of inference steps. It is a static jit argument
            (``static_argnums=0``), so each distinct value compiles its own version.
        x: input batch (assumed shape (batch, input_dim); TODO confirm).
        y: targets to clamp. ``train()`` passes one-hot labels.
        model: the predictive-coding model; updated in place.
        optim_w: optimizer over the layer weights.
        optim_h: optimizer over the vode states.
    """
    model.train()
    # Initialize vode states from a clamped forward pass; the Cache params are
    # cleared when the step context exits.
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)

    # Inference phase: gradients only for vode params that are not frozen (the
    # output vode's h is frozen, so the clamped target stays fixed).
    # NOTE(review): optim_h's state is never re-initialized between batches. That is
    # harmless for a stateless optimizer such as SGD, but confirm before switching to
    # a stateful one.
    for _ in range(T):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            _, g = pxf.value_and_grad(pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]), has_aux=True)(energy)(x, model=model)
        # NOTE(review): meaning of the positional `True` depends on pcax's Optim.step; confirm.
        optim_h.step(model, g["model"], True)

    # Learning phase: a single gradient step on the layer parameters at the relaxed
    # vode states.
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        _, g = pxf.value_and_grad(pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True)(energy)(x, model=model)
    optim_w.step(model, g["model"])

@pxf.jit()
def eval_on_batch(x: jax.Array, y: jax.Array, *, model: TwoLayerNN):
    """Predict classes for one batch and score them against the labels.

    Runs a single unclamped (``y=None``) INIT-status forward pass in eval mode.

    Args:
        x: input batch.
        y: integer class labels for the batch.
        model: the predictive-coding model.

    Returns:
        ``(accuracy, predictions)``: the fraction of arg-max predictions equal to
        ``y``, and the predicted class index of each sample.
    """
    model.eval()
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        scores = forward(x, None, model=model)
        predicted = scores.argmax(axis=-1)
    hits = predicted == y
    return hits.mean(), predicted

def train(dl, T, *, model: TwoLayerNN, optim_w: pxu.Optim, optim_h: pxu.Optim, num_classes: int = 10):
    """Run one training epoch: a ``train_on_batch`` step for every batch of ``dl``.

    Args:
        dl: iterable of ``(x, y)`` batches, where ``y`` holds integer class labels.
        T: inference steps per batch. It is a static argument of ``train_on_batch``,
            so each distinct value triggers a separate compilation.
        model: the predictive-coding model; updated in place.
        optim_w: optimizer over the layer weights.
        optim_h: optimizer over the vode states.
        num_classes: width of the one-hot target encoding (default 10, as for MNIST).
    """
    for x, y in dl:
        targets = jax.nn.one_hot(y, num_classes)
        train_on_batch(T, x, targets, model=model, optim_w=optim_w, optim_h=optim_h)

def eval(dl, *, model: TwoLayerNN):
    """Evaluate ``model`` on every batch of ``dl``.

    (The name shadows the builtin ``eval``; it is kept for compatibility.)

    Args:
        dl: iterable of ``(x, y)`` batches with integer class labels ``y``.
        model: the predictive-coding model.

    Returns:
        ``(accuracy, predictions)``. ``accuracy`` is computed over all evaluated
        samples, with each batch weighted by its size so a smaller batch cannot skew
        the result. ``predictions`` is the concatenated predicted class indices, in
        loader order.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    batch_accuracies, batch_sizes, predictions = [], [], []
    for x, y in dl:
        accuracy, predicted = eval_on_batch(x, y, model=model)
        batch_accuracies.append(accuracy)
        batch_sizes.append(len(y))
        predictions.append(predicted)
    if not batch_sizes:
        raise ValueError("eval: the dataloader yielded no batches")
    # A plain mean of per-batch accuracies is only exact when every batch has the
    # same size; weighting by batch size gives the true per-sample accuracy.
    overall_accuracy = np.average(np.asarray(batch_accuracies), weights=batch_sizes)
    return overall_accuracy, np.concatenate(predictions)

In [None]:
# Model training and evaluation
# Hyperparameters (tune here).
batch_size = 128
nm_epochs = 10
INPUT_DIM = 784    # flattened 28x28 MNIST images
HIDDEN_DIM = 128
NUM_CLASSES = 10
LR_H = 0.1         # SGD step size for the vode-state (inference) updates
LR_W = 1e-3        # AdamW learning rate for the layer weights

model = TwoLayerNN(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, output_dim=NUM_CLASSES, act_fn=jax.nn.relu)

# Initialize the model and optimizers. An INIT-status forward pass on a dummy batch
# runs first, presumably so the vodes hold batch-shaped values before the
# optimizers are built over the VodeParam / LayerParam subsets.
with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(jnp.zeros((batch_size, INPUT_DIM)), None, model=model)
    optim_h = pxu.Optim(optax.sgd(LR_H), pxu.Mask(pxc.VodeParam)(model))
    optim_w = pxu.Optim(optax.adamw(LR_W), pxu.Mask(pxnn.LayerParam)(model))

train_dataloader, test_dataloader = get_dataloaders(batch_size)

In [None]:
# Train for `nm_epochs` epochs (T=10 inference steps per batch) and report the
# test-set accuracy after each epoch.
for epoch_idx in range(nm_epochs):
    train(train_dataloader, T=10, model=model, optim_w=optim_w, optim_h=optim_h)
    test_acc, test_preds = eval(test_dataloader, model=model)
    print(f"Epoch {epoch_idx + 1}/{nm_epochs} - Test Accuracy: {test_acc * 100:.2f}%")

In [None]:
# Show the model's textual representation.
print(str(model))

In [None]:
# Glorot/Xavier uniform initializer, spelled out as the variance-scaling rule it is
# defined by: scale 1.0, fan_avg mode, uniform distribution.
initializer = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "uniform")

In [None]:
# Mean of the layer-0 weights in their current state. These weights were already
# Glorot-initialized in TwoLayerNN.__init__ and then trained by the epoch loop, so
# this is the mean *before re-initialization*, not before any initialization.
model.layers[0].nn.weight.get().mean()

In [None]:
# Re-initialize the layer-0 weights with a fresh draw from `initializer`.
# WARNING: this overwrites the *trained* layer-0 weights in place. Re-run the model
# setup and training cells afterwards if you need a trained model again.
layer0_weight_param = model.layers[0].nn.weight
layer0_weight_param.set(initializer(px.RKG(), layer0_weight_param.shape))

In [None]:
# Mean of the layer-0 weights after the re-initialization above. A fresh symmetric
# uniform draw should have a mean close to 0.
layer0_weights_after = model.layers[0].nn.weight.get()
layer0_weights_after.mean()