In [None]:
# Set working directory one level up (as if mnist_poc folder never existed)
import os

# Remember where the kernel started so that re-running this cell is idempotent:
# a bare `os.chdir("..")` would climb one directory higher on every re-execution.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, ".."))

In [None]:
# Handy imports
import torch
from torch import nn
from lightning import seed_everything

## First: choose CrossEntropyLoss or MSELoss

In [None]:
# Loss switch read by the later cells of this notebook:
#   True  -> the output layer is a bare linear layer and the model's class_loss is
#            replaced by a summed cross-entropy (see use_CrossEntropyLoss);
#   False -> the output layer is followed by a Sigmoid and the model keeps its own
#            class_loss (labelled "MSE" in this notebook).
USE_CROSSENTROPY_INSTEAD_OF_MSE = False

In [None]:
from types import MethodType


def use_CrossEntropyLoss(pc_module):
    """Replace ``pc_module.class_loss`` with a summed cross-entropy loss.

    The notebook switches to cross-entropy to avoid vanishing gradients
    during state optimization.

    Args:
        pc_module: Object whose ``class_loss`` attribute is overridden in place.

    Returns:
        The same ``pc_module`` instance, now carrying the new bound method.
    """

    def class_loss(self, y_pred, y):
        # reduction="sum": the loss is summed, not averaged, over the batch
        summed = nn.functional.cross_entropy(y_pred, y, reduction="sum")
        return summed

    # Bind the function to this one instance only; the class itself is untouched
    bound_loss = MethodType(class_loss, pc_module)
    pc_module.class_loss = bound_loss
    return pc_module

## Define architecture

In [None]:
# Seed before any layer is constructed: MyLinear (below) draws its weights at random
# when it is instantiated, so the seed must be fixed first.
seed_everything(42)  # needed for reproducible weights


# Linear layer with a custom weight initialization
class MyLinear(nn.Linear):
    """``nn.Linear`` whose weights are orthogonally initialized and whose bias starts at zero."""

    def reset_parameters(self):
        """Re-draw the weights (orthogonal, ReLU gain) and zero the bias if there is one."""
        # Orthogonal init is used here; xavier_uniform_ with the same gain was the alternative tried.
        nn.init.orthogonal_(self.weight, nn.init.calculate_gain("relu"))
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)


# Layers are created strictly in order (input, hidden, output) so that the random
# initialization consumes the seeded RNG in a fixed sequence.
architecture = [nn.Sequential(MyLinear(28 * 28, 128), nn.GELU())]
architecture += [nn.Sequential(MyLinear(128, 128), nn.GELU()) for _ in range(18)]
if USE_CROSSENTROPY_INSTEAD_OF_MSE:
    # for CrossEntropy: bare linear output layer
    architecture.append(MyLinear(128, 10))
else:
    # for MSE: linear output layer followed by a sigmoid
    architecture.append(nn.Sequential(MyLinear(128, 10), nn.Sigmoid()))

## Pretrain architecture

In [None]:
from datamodules import EMNIST
from lightning import Trainer
from pc_e import PCE

# 0: the dataset comes as a Lightning DataModule
datamodule = EMNIST(batch_size=64)
print("Training on", datamodule.dataset_name)

# 1: Lightning trainer used for the pretraining run below
trainer = Trainer(
    max_epochs=2,
    accelerator="cpu",  # keep everything on CPU, to make analysis easier...
    devices=1,
    logger=False,
    limit_predict_batches=1,  # enable 1-batch prediction
    inference_mode=False,  # inference_mode would interfere with the state backward pass
)


# 2: PCE variant whose weights are trained with plain backprop
class BackpropMSE(PCE):
    """Backprop-trained PCE: a neutral baseline that favors neither error nor state optimization."""

    def training_step(self, batch, batch_idx):
        """Return the class loss of the prediction for ``batch``, divided by ``self.batch_size``."""
        images = batch["img"]
        targets = batch["y"]
        # The forward pass is run for its side effect: it sets all errors to 0
        self.forward(images)
        summed_loss = self.class_loss(self.y_pred(images), targets)
        return summed_loss / self.batch_size


# 3: Train model weights with backprop (fast & neutral method)
# iters and e_lr are passed as None here; BackpropMSE.training_step does not read them.
pc = BackpropMSE(architecture, iters=None, e_lr=None, w_lr=0.0003)
if USE_CROSSENTROPY_INSTEAD_OF_MSE:
    # use_CrossEntropyLoss patches class_loss on this instance and returns the same object
    pc = use_CrossEntropyLoss(pc)
# Fit first, then evaluate the fitted model using the same datamodule
trainer.fit(pc, datamodule=datamodule)
trainer.test(pc, datamodule=datamodule)

## Find the easiest and the most difficult training sample

In [None]:
from tqdm.auto import tqdm  # notebook widget when available, plain-text bar otherwise

seed_everything(42)  # needed for reproducible batch selection

# Batch size 1 so that every "batch" is a single training sample
batch_size = 1
dm = EMNIST(batch_size)
dm.setup("fit")
dl = dm.train_dataloader()

# Running extremes: the loss values and the batches that produced them
min_loss = float("inf")
min_batch = None
max_loss = float("-inf")
max_batch = None

with torch.no_grad():  # only loss values are needed, no gradients
    for batch in tqdm(dl):
        # Apply the datamodule's after-transfer hook by hand, as no Trainer drives this loop
        batch = dm.on_after_batch_transfer(batch, 0)
        # Store a plain Python float so comparisons and the summary print do not carry
        # tensor objects around (the loss is a single-element value).
        loss = float(pc.training_step(batch, None))

        if loss < min_loss:
            min_loss = loss
            min_batch = batch
        if loss > max_loss:
            max_loss = loss
            max_batch = batch

# Fail here with a clear message rather than later with "'NoneType' is not subscriptable"
if min_batch is None or max_batch is None:
    raise RuntimeError(
        "No easiest/hardest sample found: the dataloader was empty or every loss was NaN."
    )

print(f"Done! {min_loss=}, {max_loss=}")

## Get a single batch x,y

In [None]:
# Choose the sample to analyse. The hardest sample (`max_batch`) is active;
# substitute `min_batch` on both lines below to study the easiest one instead.
x = max_batch["img"]
y = max_batch["y"]
print(x.shape, y.shape)

## Instantiate PCE model with state / error tracking

In [None]:
from mnist_poc.tracked_pce import TrackedPCE

# Rebind `pc` to a TrackedPCE built from the same `architecture` list that was used for
# pretraining. NOTE(review): presumably TrackedPCE wraps these layer objects rather than
# copying them, so the pretrained weights carry over -- confirm in mnist_poc/tracked_pce.py.
# e_lr is left as None here; it is assigned during the hyperparameter sweep.
pc = TrackedPCE(architecture, iters=256, e_lr=None, w_lr=0.001)
if USE_CROSSENTROPY_INSTEAD_OF_MSE:
    # Patch class_loss on the new instance as well (returns the same object)
    pc = use_CrossEntropyLoss(pc)

## Do mini hyperparameter search to find best learning rate for states and errors

In [None]:
# Sweep 1: error learning rate. For each candidate, run error optimization on the
# selected sample (x, y) and score it with pc.E(x, y); lower is better.
print("Starting hyperparameter sweep for Error Optimization...")
best_score = float("inf")
best_e_lr = None
for e_lr in [0.001, 0.005, 0.01, 0.05, 0.1, 0.3]:
    pc.e_lr = e_lr  # mutates the model; after the loop pc.e_lr holds the last candidate
    pc.minimize_error_energy(x, y)
    score = pc.E(x, y)
    print(e_lr, score)
    # `<=` means ties go to the later (larger) learning rate; a NaN score never wins
    if score <= best_score:
        best_score = score
        best_e_lr = e_lr

# Sweep 2: state learning rate. Same scheme, but the learning rate and iteration count
# are passed per call, and the score is computed from the returned final states.
print("Starting hyperparameter sweep for State Optimization...")
best_score = float("inf")  # reset: `best_score` is reused for the second sweep
best_s_lr = None
for s_lr in [0.01, 0.05, 0.1, 0.3, 0.5]:
    final_states = pc.minimize_state_energy(x, y, iters=4096, s_lr=s_lr)
    score = pc.E_states_only(x, y, final_states)
    print(s_lr, score)
    # Same tie-breaking as above: the later candidate wins on equal score
    if score <= best_score:
        best_score = score
        best_s_lr = s_lr

## Rerun with optimal hyperparams

In [None]:
# Cast everything to float64 (needed for easy inputs)
# The model is cast in place; the inputs are cast into new tensors x64 / y64,
# leaving the original x / y untouched.
# NOTE(review): y is cast to float64 as well -- this assumes y holds floating-point
# targets rather than integer class indices; confirm for the cross-entropy setting.
dtype = torch.float64
pc.to(dtype)
x64 = x.to(dtype)
y64 = y.to(dtype)

print(f"Running Error Optimization with {best_e_lr=}")
pc.e_lr = best_e_lr
pc.iters = 4096 * 4
pc.minimize_error_energy(x64, y64)
print(f"Running State Optimization with {best_s_lr=}")
pc.minimize_state_energy(x64, y64, iters=131072, s_lr=best_s_lr)

# Cast the model back to the default dtype (for further experimenting in the notebook).
# x and y need no cast back: they were never modified (the float64 copies live in
# x64 / y64). The former `x.to(dtype)` / `y.to(dtype)` statements were no-ops anyway,
# because Tensor.to returns a new tensor instead of converting in place.
dtype = torch.get_default_dtype()
pc.to(dtype)
print("All done here!")

## Make plot to compare activation convergence

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_optimization_comparison(true_optimum, method1, method2, layers, ylim=(7e-8, 3)):
    """
    Plot, per layer, the distance to the optimum over time for two iterative methods.

    One subplot is drawn for every entry of ``layers`` (side by side, shared y-axis,
    log-log scale). Each subplot shows the L2 norm of the difference between the
    layer's values and ``true_optimum[layer]`` at every optimization step.

    Parameters:
    - true_optimum: Per-layer reference solution; ``true_optimum[layer]`` is a tensor
      whose leading dimension is squeezed away before comparison.
    - method1: Error optim log, indexed by time. Each time step is a sequence of
      per-layer tensors; its last entry is dropped before stacking.
    - method2: State optim log, indexed by time. Each time step is a sequence of
      per-layer tensors.
    - layers: List of integers specifying the layers to be plotted.
    - ylim: ``(bottom, top)`` limits of the shared y-axis. Defaults to ``(7e-8, 3)``.

    Raises:
    - ValueError: if ``layers`` is empty.

    The curves are squeezed to 1-D before plotting, which assumes a batch of size one.
    """
    if len(layers) == 0:
        raise ValueError("You must specify at least one layer to plot.")

    # Create a figure with subplots for each layer
    fig, axes = plt.subplots(1, len(layers), figsize=(4 * len(layers), 4), sharey=True)
    axes = np.atleast_1d(axes)  # Ensure axes is always iterable, even for one layer

    # Stack the per-layer tensors of every time step. The trailing entry of each
    # method1 time step is excluded (NOTE(review): presumably the output prediction,
    # which has no counterpart in the state log -- confirm against TrackedPCE).
    method1 = [torch.stack(timestep[:-1]) for timestep in method1]
    method2 = [torch.stack(timestep) for timestep in method2]

    method1_array = np.array(method1)  # Shape: (time_steps_method1, states, batch_size, state_dim)
    method2_array = np.array(method2)  # Shape: (time_steps_method2, states, batch_size, state_dim)

    for ax, layer in zip(axes, layers):
        # Reference solution for the current layer
        optimum = np.array(true_optimum[layer].squeeze(0))  # Shape: (components,)

        # Trajectories of the current layer for both methods
        method1_layer_values = method1_array[:, layer, :, :]
        method2_layer_values = method2_array[:, layer, :, :]

        # If method1 ran for fewer steps than method2, hold its final value so that
        # both curves span the same number of steps.
        missing_steps = method2_layer_values.shape[0] - method1_layer_values.shape[0]
        if missing_steps > 0:
            padding = np.repeat(method1_layer_values[-1:], missing_steps, axis=0)
            method1_layer_values = np.concatenate([method1_layer_values, padding], axis=0)

        # Distance to the optimum: L2 norm across state components, computed in float64
        method1_distance = np.squeeze(
            np.linalg.norm(
                np.subtract(method1_layer_values, optimum, dtype=np.float64), ord=2, axis=-1
            )
        )
        method2_distance = np.squeeze(
            np.linalg.norm(
                np.subtract(method2_layer_values, optimum, dtype=np.float64), ord=2, axis=-1
            )
        )

        # Plot the distance curves (state optimization first, so error optimization is on top)
        ax.plot(method2_distance, "C3-", label="State Optimization")
        ax.plot(method1_distance, "b-", label="Error Optimization (ours)")

        # Add titles and labels. Raw strings keep the LaTeX backslashes from being
        # read as (invalid) Python escape sequences.
        ax.set_title(rf"Convergence of $\mathbf{{s_{{{layer}}}}}$ to optimum", fontsize=15)
        ax.set_xlabel("Optimization steps (log scale)")
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_ylim(*ylim)

    # Set shared y-axis label
    axes[0].set_ylabel(r"$\|\mathbf{s_i} - \mathbf{s_i^*}\|$")

    # One global legend built from proxy lines, instead of a legend per subplot
    handles = [
        plt.Line2D([0], [0], color="b", linestyle="-"),
        plt.Line2D([0], [0], color="C3", linestyle="-"),
    ]
    fig.legend(
        handles,
        ["Error Optimization (ours)", "State Optimization"],
        loc="upper center",
        ncol=2,
        prop={"size": 12},
    )

    plt.tight_layout(rect=[0, 0, 1, 0.90])  # Make space for the global legend
    plt.show()

In [None]:
plot_optimization_comparison(
    true_optimum=pc.log_errors[-1],  # assume error optimization gets to final equilibrium
    method1=pc.log_errors,
    method2=pc.log_states,
    layers=[0, 18],
)

In [None]:
# Sanity check (easy vs hard input): check what Error Optim gives as y_pred
# Prints the last entry of the last logged error-optimization step next to the target y,
# so the two can be compared by eye.
# NOTE(review): assumes the final entry of each log_errors step is the output prediction.
print(pc.log_errors[-1][-1])
print(y)