# Classifying MNIST with a Hierarchical PCN

## Set Up the Notebook

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler
from torchvision import datasets
from torchvision.transforms import v2
from tqdm.notebook import tqdm

import pyromancy as pyro
from pyromancy.nodes import StandardGaussianNode

### Set the Compute Device and Datatype

In [None]:
# Requested compute device ("auto" picks the best available backend) and the
# floating-point datatype used for every tensor in the notebook.
device: str = "auto"
dtype: torch.dtype = torch.float32

# Resolve "auto" in order of preference: CUDA, then Apple MPS, then CPU.
if device == "auto":
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )

# Sanity check: a tensor of `dtype` can be allocated on `device`, and `dtype`
# is a floating-point type.
assert torch.empty((), device=device, dtype=dtype).is_floating_point()

# Backend flags, derived from the device type with any ":<index>" suffix dropped.
_backend = device.split(":", maxsplit=1)[0].lower()
iscpu = _backend == "cpu"
iscuda = _backend == "cuda"
ismps = _backend == "mps"

print(f"using {device} with {dtype} tensors")

### Load the MNIST Dataset with TorchVision

In [None]:
# Build the training and testing splits with identical settings: files live
# under "data" (downloaded when requested via download=True) and every sample
# is passed through ToImage followed by ToDtype(dtype, scale=True).
# Each split gets its own transform pipeline instance.
train_set, test_set = (
    datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=v2.Compose([v2.ToImage(), v2.ToDtype(dtype, scale=True)]),
    )
    for is_train in (True, False)
)

## Define the Model

In [None]:
class PCN(nn.Module):
    """Hierarchical predictive coding network over a chain of Gaussian nodes.

    The network is a stack of ``StandardGaussianNode`` layers. Consecutive
    nodes are joined by an edge, ``ReLU -> Linear``, which maps the state of
    the lower node to a prediction for the node above it.

    Args:
        sizes: Size passed to each node's constructor, ordered from the input
            layer to the output layer. At least two entries are required. The
            default ``(784, 256, 256, 10)`` is the MNIST architecture used in
            this notebook.

    Raises:
        ValueError: If fewer than two layer sizes are given.
    """

    def __init__(self, sizes: tuple = (784, 256, 256, 10)) -> None:
        sizes = tuple(sizes)
        if len(sizes) < 2:
            raise ValueError(
                f"PCN needs at least an input and an output layer, got sizes={sizes}"
            )

        super().__init__()
        self.nodes = nn.ModuleList(StandardGaussianNode(n) for n in sizes)
        # edge[ell] predicts node[ell + 1] from node[ell]
        self.edges = nn.ModuleList(
            nn.Sequential(nn.ReLU(), nn.Linear(lower.size, upper.size))
            for lower, upper in zip(self.nodes, self.nodes[1:])
        )

    def reset(self) -> None:
        """Clear all gradients of the module and call ``reset()`` on every node."""
        self.zero_grad()
        for node in self.nodes:
            node.reset()

    @torch.no_grad()
    def init_x(self, x: torch.Tensor) -> None:
        """Reset the network and initialize every node from an input.

        The first node is initialized with ``x``; each following node is
        initialized with the edge output computed from whatever the previous
        node's ``init`` returned (presumably that node's new state -- confirm
        against the pyromancy documentation).

        Args:
            x: Input passed to the first node's ``init``.
        """
        self.reset()
        z = self.nodes[0].init(x)
        for node, edge in zip(self.nodes[1:], self.edges):
            z = node.init(edge(z))

    @torch.no_grad()
    def init_xy(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Reset the network and initialize it from an input/target pair.

        Works like :meth:`init_x` for every node except the last one, which is
        initialized directly with ``y``; the final edge is not evaluated.

        Args:
            x: Input passed to the first node's ``init``.
            y: Target passed to the last node's ``init``.
        """
        self.reset()
        z = self.nodes[0].init(x)
        for node, edge in zip(self.nodes[1:-1], self.edges[:-1]):
            z = node.init(edge(z))
        _ = self.nodes[-1].init(y)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate ``x`` through every edge in order and return the result.

        Node states are neither read nor modified.
        """
        mu = x
        for edge in self.edges:
            mu = edge(mu)
        return mu

    def energy(self) -> torch.Tensor:
        """Return the total energy of the current node states.

        For every node above the first, the prediction is computed by the
        incoming edge from the *current value of the node below* and handed to
        that node's ``energy``. The per-node results are summed.

        Returns:
            Tensor with one entry per row of the first node's value.
        """
        vfe = self.nodes[0].value.new_zeros(self.nodes[0].value.size(0))

        below = self.nodes[0].value
        for node, edge in zip(self.nodes[1:], self.edges):
            vfe.add_(node.energy(edge(below)))
            # The next prediction must come from this node's own state, not
            # from the prediction it just received. Chaining predictions would
            # make the energy independent of the hidden states (apart from
            # their own error term), leaving inference on them with nothing to
            # minimize after initialization.
            below = node.value

        return vfe

# build the network with its default architecture, then move it to the
# selected device and datatype
pcn = PCN()
pcn = pcn.to(device=device, dtype=dtype)

## Configure the Training Procedure

In [None]:
# training schedule
epochs: int = 10  # passes over the training set
batch_size: int = 500  # samples per batch
num_esteps: int = 32  # E-step updates performed on each batch

# E-step optimizer: plain SGD over the E-step parameters, leaving out the first
# and last nodes (the training loop initializes those from the data and the
# labels).
_clamped_nodes = (pcn.nodes[0], pcn.nodes[-1])
e_opt = optim.SGD(pyro.get_estep_params(pcn, exclude=_clamped_nodes), lr=0.2)

# M-step optimizer: Adam over the M-step parameters.
m_opt = optim.Adam(pyro.get_mstep_params(pcn), lr=0.001)

## Training/Testing Loop

In [None]:
# number of full batches per epoch (the loader drops the incomplete last batch)
nbatches = len(train_set) // batch_size
# test accuracy measured after each epoch
accs = []

for _ in tqdm(range(epochs), desc="Epoch", initial=0, total=epochs, position=0):
    # set training mode
    pcn.train()

    # load and sample training set (fresh permutation every epoch)
    sampler = RandomSampler(
        train_set,
        replacement=False,
    )
    loader = DataLoader(
        train_set,
        batch_size,
        sampler=sampler,
        drop_last=True,
        pin_memory=iscuda,
        pin_memory_device="" if not iscuda else device,
    )

    # training loop
    for x, y in tqdm(
        loader, desc="Batch", initial=0, total=nbatches, leave=False, position=1
    ):
        # prepare data; batches are pinned when running on CUDA, so the
        # host-to-device copies can be issued asynchronously
        x = x.to(device=device, non_blocking=iscuda).flatten(1)
        y = y.to(device=device, non_blocking=iscuda)

        # initialize pcn with data (labels are one-hot encoded as targets)
        pcn.init_xy(x, F.one_hot(y, 10).to(dtype=dtype))

        # perform E-steps: only the E-step optimizer's parameters receive
        # gradients and are updated
        for _ in range(num_esteps):
            pcn.zero_grad()
            pcn.energy().mean().backward(inputs=e_opt.param_groups[0]["params"])
            e_opt.step()

        # perform M-step: a single update of the M-step optimizer's parameters
        pcn.zero_grad()
        pcn.energy().mean().backward(inputs=m_opt.param_groups[0]["params"])
        m_opt.step()

    # set inference mode
    pcn.eval()

    # load testing set as a single batch
    loader = DataLoader(
        test_set,
        len(test_set),
        shuffle=False,
        pin_memory=iscuda,
        pin_memory_device="" if not iscuda else device,
    )
    x, y = next(iter(loader))

    # prepare data
    x = x.to(device=device, non_blocking=iscuda).flatten(1)
    y = y.to(device=device, non_blocking=iscuda)

    # forward inference; no gradients are needed for evaluation, so avoid
    # recording an autograd graph over the whole test set
    with torch.no_grad():
        ypred = pcn(x)
    accs.append((y == ypred.argmax(1)).float().mean().item())

# report the test accuracy recorded after each epoch (epochs numbered from 1)
print("Epoch    Accuracy")
for epoch_number, accuracy in enumerate(accs, start=1):
    accuracy_text = format(accuracy, ".5f")
    print(f"{epoch_number:>5}    {accuracy_text:<8}")