In [None]:
# Set working directory one level up (as if mnist_poc folder never existed)
import os

# Anchor on the directory the kernel started in, so re-running this cell is
# idempotent instead of climbing one more level on every execution.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

## Define architecture

In [None]:
from torch import nn
from lightning import seed_everything

# Seed every RNG before any layer exists: MyLinear's orthogonal init draws random numbers.
seed_everything(seed=42)


# Use proper initialization for Linear
class MyLinear(nn.Linear):
    def reset_parameters(self):
        gain = nn.init.calculate_gain("linear")
        # nn.init.xavier_uniform_(self.weight, gain)
        nn.init.orthogonal_(self.weight, gain)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


# Deep linear network of 20 bias-free layers:
# 784 -> 128, 18 x (128 -> 128), 128 -> 10.
# The 28 * 28 input is presumably a flattened 28x28 EMNIST image (TODO confirm).
INPUT_DIM = 28 * 28
HIDDEN_DIM = 128
N_HIDDEN_TO_HIDDEN = 18  # number of 128 -> 128 layers between the first and last layer
OUTPUT_DIM = 10

layer_dims = [INPUT_DIM] + [HIDDEN_DIM] * (N_HIDDEN_TO_HIDDEN + 1) + [OUTPUT_DIM]
# Layers are built strictly input-to-output, so the seeded RNG draws of the
# orthogonal init always happen in the same order.
architecture = [
    MyLinear(d_in, d_out, bias=False)
    for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
]

## Pretrain the architecture briefly
Overall, this deep linear architecture is not a good model: it has no activation functions at all (not even at the output), which leads to poor results and a strong tendency toward instability. That's why we only pretrain it briefly.

In [None]:
from datamodules import EMNIST
from lightning import Trainer
from pc_variants import BackpropMSE

# 0: load dataset as Lightning DataModule
datamodule = EMNIST(batch_size=64)
print("Training on", datamodule.dataset_name)

# 1: Set up Lightning trainer
trainer = Trainer(
    accelerator="cpu",
    devices=1,
    logger=False,
    max_epochs=2,
    inference_mode=False,  # inference_mode would interfere with the state backward pass
    limit_predict_batches=1,  # enable 1-batch prediction
)

# 2: Train model weights with backprop (fast & neutral method)
pc = BackpropMSE(architecture, iters=None, e_lr=None, w_lr=0.001)
trainer.fit(pc, datamodule=datamodule)
trainer.test(pc, datamodule=datamodule)

## Get a single batch `(x, y)`

In [None]:
# Re-seed so the batch drawn below is reproducible (assuming the loader's
# shuffling goes through the seeded RNGs).
seed_everything(42)

batch_size = 64
dm = EMNIST(batch_size)
print("Training on", dm.dataset_name)

# Prepare the training split, take its first batch and run the datamodule's
# post-transfer hook on it with dataloader_idx=0.
dm.setup("fit")
dl = dm.train_dataloader()
batch = dm.on_after_batch_transfer(next(iter(dl)), 0)
x = batch["img"]
y = batch["y"]
print(x.shape, y.shape)

## Calculate analytical solution

In [None]:
from analytical_solution import get_final_states

true_optimum = get_final_states(architecture, x, y)
# Report the shape of each analytically optimal state, labelled x^{1}, x^{2}, ...
# Doubled braces emit literal braces; the old `{ {i} }` only worked because a
# one-element set happens to print as `{i}`.
for i, s_i in enumerate(true_optimum, start=1):
    print(f"x^{{{i}}} shape: {s_i.shape}")

## Mini hyperparameter search for the best error and state learning rates

In [None]:
import torch
from tracked_pce import TrackedPCE


def get_conv_score(states, optimum=None):
    """Score how close the last logged states are to the analytical optimum.

    For every layer, the per-sample relative error ``||o - s|| / ||o||`` (norm
    over dim 1) is averaged over the batch. The per-layer means are summed and
    divided by the number of optimum layers. Lower is better; 0 means exact.

    Parameters:
    - states: Per-timestep log. Only the final entry ``states[-1]`` (one tensor
      per layer) is scored. Entries beyond the optimum's layers are ignored by ``zip``.
    - optimum: Per-layer analytical solutions. Defaults to the notebook-level
      ``true_optimum`` so existing ``get_conv_score(log)`` calls keep working.

    Returns a Python float. A sample whose optimum row has zero norm yields
    inf/nan, which propagates into the score.

    Raises:
    - ValueError: if ``optimum`` is empty or the final timestep holds no states.
    """
    if optimum is None:
        optimum = true_optimum
    if len(optimum) == 0:
        raise ValueError("Need at least one optimum layer to score against.")

    final_states = states[-1]
    per_layer = [
        (torch.norm(o - s, dim=1) / torch.norm(o, dim=1)).mean()
        for o, s in zip(optimum, final_states)
    ]
    if not per_layer:
        # Without this, an empty timestep would silently score as a perfect 0.
        raise ValueError("The final timestep contains no states to score.")
    return float(sum(per_layer)) / len(optimum)


import math


def _sweep(name, candidates, run_trial):
    """Return the candidate learning rate with the lowest convergence score.

    Each candidate is scored by ``run_trial(candidate) -> float`` and printed with
    its score. Non-finite scores (NaN/inf from a diverged run) are never selected.
    On ties the later candidate wins, matching a ``<=`` scan.

    Raises:
    - RuntimeError: if every candidate produced a non-finite score.
    """
    print(f"Starting hyperparameter sweep for {name}...")
    best_score, best_candidate = float("inf"), None
    for candidate in candidates:
        score = run_trial(candidate)
        print(candidate, score)
        # `inf <= inf` is True, so finiteness must be checked explicitly.
        if math.isfinite(score) and score <= best_score:
            best_score, best_candidate = score, candidate
    if best_candidate is None:
        raise RuntimeError(f"Every {name} run diverged; try smaller learning rates.")
    return best_candidate


def _error_trial(e_lr):
    """Run error optimization on a fresh model and score its final log entry."""
    trial = TrackedPCE(architecture, iters=256, e_lr=e_lr, w_lr=0.001)
    trial.minimize_error_energy(x, y)
    return get_conv_score(trial.log_errors)


def _state_trial(s_lr):
    """Score state optimization on a fresh model, set up like the final rerun cell."""
    # A fresh instance per trial keeps trials independent; reusing one model
    # would let earlier runs leak into later scores.
    trial = TrackedPCE(architecture, iters=256, e_lr=best_e_lr, w_lr=0.001)
    trial.minimize_error_energy(x, y)
    trial.minimize_state_energy(x, y, iters=4096, s_lr=s_lr)
    return get_conv_score(trial.log_states)


best_e_lr = _sweep("Error Optimization", [0.01, 0.05, 0.1, 0.3], _error_trial)
best_s_lr = _sweep("State Optimization", [0.1, 0.3, 0.5], _state_trial)

## Rerun with the selected hyperparameters

In [None]:
# Final run with the selected learning rates: error optimization first, then
# state optimization on the same TrackedPCE instance.
print(f"Running Error Optimization with best_e_lr={best_e_lr!r}")
pc = TrackedPCE(architecture, w_lr=0.001, e_lr=best_e_lr, iters=256)
pc.minimize_error_energy(x, y)

print(f"Running State Optimization with best_s_lr={best_s_lr!r}")
pc.minimize_state_energy(x, y, s_lr=best_s_lr, iters=4096)
print("All done here!")

## Plot how fast each method's states converge to the analytical optimum

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch


def _history_to_numpy(history, drop_last=False):
    """Stack a per-timestep log of per-layer tensors into one NumPy array.

    Parameters:
    - history: Sequence over time; each entry is a sequence of per-layer tensors.
    - drop_last: If True, the final tensor of every timestep is discarded.

    Returns an array of shape (time_steps, layers, *tensor_shape). Tensors are
    detached and moved to CPU first, so logs that still track gradients (or live
    on an accelerator) convert cleanly. ``np.array`` on raw tensors cannot do that.

    Raises:
    - ValueError: if ``history`` is empty.
    """
    if len(history) == 0:
        raise ValueError("Cannot plot an empty optimization history.")
    frames = [torch.stack(timestep[:-1] if drop_last else timestep) for timestep in history]
    return torch.stack(frames).detach().cpu().numpy()


def _pad_to_length(values, length):
    """Extend ``values`` along axis 0 by repeating its last row until it has ``length`` rows."""
    missing = length - values.shape[0]
    if missing <= 0:
        return values
    return np.concatenate([values, np.repeat(values[-1:], missing, axis=0)], axis=0)


def _distance_quartiles(values, optimum):
    """Per-timestep 25th/50th/75th percentiles over the batch of ``||s - s*||_2``.

    ``values`` is (time_steps, batch, dim); the L2 norm is taken over the last
    axis and the percentiles over the batch axis. Returns ``(q1, median, q3)``.
    """
    distances = np.linalg.norm(
        np.subtract(values, optimum, dtype=np.float64), ord=2, axis=-1
    )
    return np.percentile(distances, [25, 50, 75], axis=-1)


def plot_optimization_comparison(true_optimum, method1, method2, layers):
    """
    Plot any number of layers side-by-side, comparing analytical optimum and two iterative methods.

    For every requested layer, the L2 distance between each sample's state and its
    analytical optimum is computed per optimization step. The median (line) and
    interquartile range (band) over the batch are drawn on log-log axes. The
    shorter history is padded with its final value so both curves span the same steps.

    Parameters:
    - true_optimum: List of analytical solutions (tensors) for each layer.
    - method1: Per-timestep logs for Error optim. Each timestep holds one tensor per
      layer plus one trailing entry that is dropped before plotting.
    - method2: Per-timestep logs for State optim. Each timestep holds one tensor per layer.
    - layers: List of integers specifying the layers to be plotted.

    Raises:
    - ValueError: if ``layers`` is empty or either history is empty.
    """
    if len(layers) == 0:
        raise ValueError("You must specify at least one layer to plot.")

    # NOTE(review): every Error-optimization timestep carries one extra trailing
    # entry (presumably the output layer). It is dropped so both logs index the
    # same layers as `true_optimum` — confirm against TrackedPCE.
    method1_array = _history_to_numpy(method1, drop_last=True)  # (T1, layers, batch, dim)
    method2_array = _history_to_numpy(method2)  # (T2, layers, batch, dim)
    n_steps = max(method1_array.shape[0], method2_array.shape[0])
    time_steps = np.arange(n_steps)

    # One subplot per layer; sharey keeps convergence levels directly comparable.
    fig, axes = plt.subplots(1, len(layers), figsize=(4 * len(layers), 4), sharey=True)
    axes = np.atleast_1d(axes)  # a single subplot comes back as a bare Axes

    for ax, layer in zip(axes, layers):
        # Analytical optimum for this layer; broadcasts against every timestep.
        optimum = true_optimum[layer].squeeze(0).detach().cpu().numpy()

        method1_q1, method1_median, method1_q3 = _distance_quartiles(
            _pad_to_length(method1_array[:, layer, :, :], n_steps), optimum
        )
        method2_q1, method2_median, method2_q3 = _distance_quartiles(
            _pad_to_length(method2_array[:, layer, :, :], n_steps), optimum
        )

        # State optimization is drawn first so the Error-optimization curve sits on top.
        ax.plot(time_steps, method2_median, "C3-", label="State Optimization")
        ax.fill_between(
            time_steps,
            method2_q1,
            method2_q3,
            color="C3",
            alpha=0.3,
            label="IQR (State Optimization)",
        )
        ax.plot(time_steps, method1_median, "b-", label="Error Optimization (ours)")
        ax.fill_between(
            time_steps,
            method1_q1,
            method1_q3,
            color="b",
            alpha=0.3,
            label="IQR (Error Optimization)",
        )

        # Raw strings: "\m" and "\|" are invalid escape sequences in normal literals.
        ax.set_title(rf"Convergence of $\mathbf{{s_{{{layer}}}}}$ to optimum", fontsize=15)
        ax.set_xlabel("Optimization steps (log scale)")
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_ylim(7e-5, 1)

    axes[0].set_ylabel(r"$\|\mathbf{s_i} - \mathbf{s_i^*}\|$")

    # One figure-level legend instead of a per-subplot one.
    handles = [
        plt.Line2D([0], [0], color="b", linestyle="-"),
        plt.Line2D([0], [0], color="C3", linestyle="-"),
    ]
    fig.legend(
        handles,
        ["Error Optimization (ours)", "State Optimization"],
        loc="upper center",
        ncol=2,
        prop={"size": 12},
    )

    fig.tight_layout(rect=[0, 0, 1, 0.90])  # leave room at the top for the legend
    plt.show()

In [None]:
# Compare both methods at three depths of the network (state indices 0, 9 and 18).
plot_optimization_comparison(
    true_optimum=true_optimum,
    method1=pc.log_errors,
    method2=pc.log_states,
    layers=[0, 9, 18],
)