In [None]:
# Set working directory one level up (as if mnist_poc folder never existed)
import os

# Remember the directory the kernel started in. A bare os.chdir("..") is not
# idempotent: every re-run of this cell would climb one more level up and break
# the project-local imports below. Anchoring to the start directory makes the
# cell safe to re-run within the same kernel session.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

## Define architecture

In [None]:
from torch import nn
from lightning import seed_everything

# Seed the random number generators via Lightning before any layer is created,
# since the layers below draw their initial weights randomly.
seed_everything(42)


# Use proper initialization for Linear
class MyLinear(nn.Linear):
    """`nn.Linear` whose weights are initialised orthogonally and whose bias starts at zero."""

    def reset_parameters(self):
        """Re-initialise the layer: orthogonal weight matrix, all-zero bias (if any)."""
        # Orthogonal init scaled by the gain PyTorch recommends for a linear
        # (identity) activation. Xavier-uniform with the same gain was the
        # alternative considered here.
        nn.init.orthogonal_(self.weight, gain=nn.init.calculate_gain("linear"))

        bias = self.bias
        if bias is None:
            return
        nn.init.zeros_(bias)


# Deep linear network as a plain list of 20 bias-free layers:
# 784 -> 128, then 18 x (128 -> 128), then 128 -> 10.
# Layers are created in network order, input side first.
architecture = [
    MyLinear(28 * 28, 128, bias=False),
    *(MyLinear(128, 128, bias=False) for _ in range(18)),
    MyLinear(128, 10, bias=False),
]

## Pretrain architecture for a few steps
Overall, this deep linear architecture is really not great. It has no output activation, which gives it poor results and a strong tendency toward instability. That's why we only pretrain on a few batches.

In [None]:
from datamodules import EMNIST
from lightning import Trainer
from pc_variants import BackpropMSE

# 0: load dataset as Lightning DataModule
datamodule = EMNIST(batch_size=64)
print("Training on", datamodule.dataset_name)

# 1: Set up Lightning trainer
# CPU only, a single device, no experiment logger, and at most two epochs.
trainer = Trainer(
    accelerator="cpu",
    devices=1,
    logger=False,
    max_epochs=2,
    inference_mode=False,  # inference_mode would interfere with the state backward pass
    limit_predict_batches=1,  # enable 1-batch prediction
)

# 2: Train model weights with backprop (fast & neutral method)
# `iters` and `e_lr` are passed as None here; presumably they only matter for the
# iterative PC variants -- TODO confirm against pc_variants.BackpropMSE.
# The wrapper receives the `architecture` list itself, and the later cells pass the
# same list to other wrappers, so they presumably see the weights trained here.
pc = BackpropMSE(architecture, iters=None, e_lr=None, w_lr=0.001)
trainer.fit(pc, datamodule=datamodule)
trainer.test(pc, datamodule=datamodule)

## Get a single batch x,y

In [None]:
# Re-seed before drawing the batch, presumably so the same batch is picked on every run.
seed_everything(42)

batch_size = 64
dm = EMNIST(batch_size)
print("Training on", dm.dataset_name)
dm.setup("fit")
dl = dm.train_dataloader()
# Only the first batch of the training dataloader is used for the experiments below.
batch = next(iter(dl))
# Invoke the datamodule's after-batch-transfer hook by hand (dataloader index 0);
# presumably this applies the same preprocessing a Trainer run would -- TODO confirm.
batch = dm.on_after_batch_transfer(batch, 0)
# The batch is indexed by the keys "img" (taken as inputs x) and "y" (taken as targets y).
x, y = batch["img"], batch["y"]
print(x.shape, y.shape)

## Calculate analytical solution

In [None]:
from analytical_solution import get_final_states

# Analytical solution for the single batch (x, y). It serves as the reference
# ("true optimum") that both iterative optimizers are compared against below.
true_optimum = get_final_states(architecture, x, y)
for i, s_i in enumerate(true_optimum, start=1):
    # Doubled braces emit literal "{" and "}", so this prints e.g. "x^{1} shape: ...".
    # (The previous "{ {i} }" only produced braces via the repr of a one-element set.)
    print(f"x^{{{i}}} shape: {s_i.shape}")

## Do mini hyperparameter search to find best learning rate for states and errors

In [None]:
import torch
from tracked_pce import TrackedPCE


def get_conv_score(states):
    """Score how close the last logged states are to the analytical optimum.

    For each pair of (optimum, state) tensors, the norm along dim 1 of their
    difference is divided by the norm along dim 1 of the optimum and averaged.
    The per-layer averages are summed and divided by ``len(true_optimum)``.

    Parameters:
    - states: Sequence of logged timesteps; only the last entry is used. Entries
      beyond the length of ``true_optimum`` are ignored.

    Returns:
    - Python float; 0.0 means the last states match the optimum exactly.

    Reads the module-level ``true_optimum``.
    """
    final_states = states[-1]

    total = 0
    for target, reached in zip(true_optimum, final_states):
        distance = torch.norm(target - reached, dim=1)
        scale = torch.norm(target, dim=1)
        total = total + (distance / scale).mean()

    return total.item() / len(true_optimum)


# --- Sweep 1: learning rate for Error Optimization --------------------------------
# Each candidate gets a fresh TrackedPCE; the score is the relative distance of the
# final logged values to the analytical optimum (lower is better).
print("Starting hyperparameter sweep for Error Optimization...")
best_score = float("inf")
best_e_lr = None
for e_lr in [0.01, 0.05, 0.1, 0.3]:
    pc = TrackedPCE(architecture, iters=256, e_lr=e_lr, w_lr=0.001)
    pc.minimize_error_energy(x, y)
    score = get_conv_score(pc.log_errors)
    print(e_lr, score)
    # "<=" lets a later candidate win ties. A NaN score compares False and is
    # therefore never selected.
    if score <= best_score:
        best_score = score
        best_e_lr = e_lr

# Fail loudly instead of silently handing e_lr=None to the next cell.
if best_e_lr is None:
    raise RuntimeError("Error Optimization sweep: no learning rate produced a valid score.")

# --- Sweep 2: learning rate for State Optimization --------------------------------
print("Starting hyperparameter sweep for State Optimization...")
best_score = float("inf")
best_s_lr = None
# NOTE(review): this loop reuses the `pc` left over from the last iteration of the
# sweep above (e_lr=0.3). Confirm that minimize_state_energy does not depend on the
# preceding error optimization or on state logged by earlier s_lr runs.
for s_lr in [0.1, 0.3, 0.5]:
    pc.minimize_state_energy(x, y, iters=4096, s_lr=s_lr)
    score = get_conv_score(pc.log_states)
    print(s_lr, score)
    if score <= best_score:
        best_score = score
        best_s_lr = s_lr

if best_s_lr is None:
    raise RuntimeError("State Optimization sweep: no learning rate produced a valid score.")

## Rerun with optimal hyperparams

In [None]:
# Build a new TrackedPCE with the best error learning rate from the sweep, run
# Error Optimization, then State Optimization on the same object.
# The plotting cells further down read pc.log_errors and pc.log_states.
print(f"Running Error Optimization with {best_e_lr=}")
pc = TrackedPCE(architecture, iters=256, e_lr=best_e_lr, w_lr=0.001)
pc.minimize_error_energy(x, y)
print(f"Running State Optimization with {best_s_lr=}")
pc.minimize_state_energy(x, y, iters=4096, s_lr=best_s_lr)
print("All done here!")

## Make plot to compare activation convergence

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch


def _layer_trajectory(trajectory, layer):
    """Stack one layer's logged values over time into a single numpy array.

    `trajectory` is indexed as [time][layer]; the result has the time axis first
    (presumably (time, batch, state_dim) -- confirm against the logger).
    """
    return torch.stack([timestep[layer] for timestep in trajectory]).detach().cpu().numpy()


def _first_step_below(values, threshold):
    """Return the first index where `values` is below `threshold`, or None if it never is."""
    below = values < threshold
    return int(np.argmax(below)) if below.any() else None


def _format_speedup(ratio):
    """Format a ratio with two significant figures, e.g. 137.4 -> '140' and 8.26 -> '8.3'."""
    return f"{float(f'{ratio:.2g}'):g}"


def plot_optimization_comparison(true_optimum, method1, method2, layers):
    """
    Plot any number of layers side-by-side, comparing analytical optimum and two iterative methods.

    For every requested layer, the distance of each logged value to the analytical
    optimum is computed, and its median and interquartile range over the second-to-last
    axis (presumably the batch) are drawn on log-log axes. If both curves drop below
    the convergence threshold, an arrow between the two convergence steps and a caption
    with the resulting speed-up (ratio of step counts) are added.

    Parameters:
    - true_optimum: List of analytical solutions for each layer.
    - method1: List of intermediate values for Error optim (indexed by time).
      The last entry of every timestep is dropped before plotting.
    - method2: List of intermediate values for State optim (indexed by time).
    - layers: List of integers specifying the layers to be plotted.

    Raises:
    - ValueError: if `layers` is empty.
    """
    if len(layers) == 0:
        raise ValueError("You must specify at least one layer to plot.")

    # Distances below this value count as "converged"; it is also the lower y-limit.
    conv_threshold = 7e-5

    # Create a figure with subplots for each layer
    fig, axes = plt.subplots(1, len(layers), figsize=(4 * len(layers), 4), sharey=True)
    axes = np.atleast_1d(axes)  # Ensure axes is always iterable, even for one layer

    # Error-optim timesteps carry one trailing entry that has no counterpart in
    # true_optimum; drop it so that both methods are indexed alike.
    method1 = [timestep[:-1] for timestep in method1]

    for ax, layer in zip(axes, layers):
        # Extract the true optimum for the current layer
        optimum = true_optimum[layer].squeeze(0).detach().cpu().numpy()

        # Distance to the optimum: norm across state components (last axis).
        # Only the requested layer is converted to numpy, one stack per layer.
        method1_layer_values = np.linalg.norm(
            np.subtract(_layer_trajectory(method1, layer), optimum, dtype=np.float64),
            ord=2,
            axis=-1,
        )
        method2_layer_values = np.linalg.norm(
            np.subtract(_layer_trajectory(method2, layer), optimum, dtype=np.float64),
            ord=2,
            axis=-1,
        )

        # If Error optim logged fewer steps, hold its final value so that both
        # curves span the same range of update steps.
        missing = method2_layer_values.shape[0] - method1_layer_values.shape[0]
        if missing > 0:
            method1_layer_values = np.pad(
                method1_layer_values, ((0, missing), (0, 0)), mode="edge"
            )

        # Calculate medians and quartiles over the batch size
        method1_q1 = np.percentile(method1_layer_values, 25, axis=-1)
        method1_median = np.percentile(method1_layer_values, 50, axis=-1)
        method1_q3 = np.percentile(method1_layer_values, 75, axis=-1)

        method2_q1 = np.percentile(method2_layer_values, 25, axis=-1)
        method2_median = np.percentile(method2_layer_values, 50, axis=-1)
        method2_q3 = np.percentile(method2_layer_values, 75, axis=-1)

        # Plot median and IQR. Each curve gets its own time axis so that the plot
        # also works when Error optim logged more steps than State optim.
        method1_steps = np.arange(method1_median.shape[0])
        method2_steps = np.arange(method2_median.shape[0])
        ax.plot(method2_steps, method2_median, "C3-", label="State Optimization")
        ax.fill_between(
            method2_steps,
            method2_q1,
            method2_q3,
            color="C3",
            alpha=0.3,
            label="IQR (State Optimization)",
        )
        ax.plot(method1_steps, method1_median, "b-", label="Error Optimization (ours)")
        ax.fill_between(
            method1_steps,
            method1_q1,
            method1_q3,
            color="b",
            alpha=0.3,
            label="IQR (Error Optimization)",
        )

        # Add titles and labels (raw strings: the LaTeX backslashes are not escapes)
        ax.set_title(rf"Convergence of $\mathbf{{s_{{{layer}}}}}$ to optimum", fontsize=15)
        ax.set_xlabel("Update steps T (log scale)")
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_ylim(conv_threshold, 1)

        # Emphasize faster convergence
        eo_converged = _first_step_below(method1_median, conv_threshold)
        so_converged = _first_step_below(method2_median, conv_threshold)

        # Skip the annotation when a method never converged (None) or was already
        # converged at step 0, which cannot be placed on a log axis.
        if eo_converged and so_converged:
            ax.annotate(
                "",
                xy=(eo_converged, 9e-5),
                xycoords="data",
                xytext=(so_converged, 9e-5),
                textcoords="data",
                arrowprops=dict(arrowstyle="<->", color="black", lw=2),
            )

            # Caption above the (geometric) midpoint of the arrow; the speed-up is
            # derived from the data instead of being hard-coded per subplot.
            how_much_faster = _format_speedup(so_converged / eo_converged)
            ax.text(
                (eo_converged * so_converged) ** 0.5,
                1e-4,
                "Same equilibrium,\n~" + how_much_faster + "× faster",
                ha="center",
                va="bottom",
                fontsize=10,
            )

    # Set shared y-axis label
    axes[0].set_ylabel(r"$\|\mathbf{s_i} - \mathbf{s_i^*}\|$")

    # Add legend
    handles = [
        plt.Line2D([0], [0], color="b", linestyle="-"),
        plt.Line2D([0], [0], color="C3", linestyle="-"),
    ]
    fig.legend(
        handles,
        ["Error-based PC (ours)", "State-based PC"],
        loc="upper center",
        ncol=2,
        prop={"size": 12},
    )

    plt.tight_layout(rect=[0, 0, 1, 0.90])  # Make space for the global legend
    plt.show()

In [None]:
# Compare both optimizers against the analytical optimum at three layer indices
# (0, 9 and 18), using the logs of the rerun with the selected learning rates.
plot_optimization_comparison(
    true_optimum,
    pc.log_errors,
    pc.log_states,
    layers=[0, 9, 18],
)

## Plot to compare gradient similarities between EO, SO and backprop

In [None]:
def from_states_to_grads(states_list):
    """Compute the parameter gradients implied by each set of states.

    For every entry of `states_list`, the energy ``pc.E_states_only(x, y, states)``
    divided by ``batch_size`` is backpropagated, and the resulting gradients of
    ``pc.layers`` are collected.

    Parameters:
    - states_list: Iterable of state collections, one per timestep.

    Returns:
    - List (one entry per timestep) of lists of flattened gradient tensors, one per
      parameter of ``pc.layers`` that received a gradient.

    Reads the module-level ``pc``, ``x``, ``y`` and ``batch_size``. State tracking on
    ``pc`` is switched off while running and restored to its previous setting
    afterwards, even if an error occurs.
    """
    result = []
    # Remember the caller's setting instead of assuming it was True.
    previous_log_everything = getattr(pc, "log_everything", True)
    pc.log_everything = False  # disable state tracking
    try:
        for states in states_list:
            pc.zero_grad()
            E = pc.E_states_only(x, y, states) / batch_size
            E.backward()
            # clone(): the next zero_grad()/backward() would otherwise overwrite the values.
            allgrads = [
                p.grad.clone().ravel() for p in pc.layers.parameters() if p.grad is not None
            ]
            result.append(allgrads)
    finally:
        # Reset everything back to how it was before...
        pc.zero_grad()
        pc.log_everything = previous_log_everything

    return result

In [None]:
# Each result is a nested list indexed as [timestep][parameter], holding flattened
# gradient tensors (not a single stacked tensor).
true_pc_grads = from_states_to_grads([true_optimum])  # 1 timestep: the analytical optimum
eo_grads = from_states_to_grads(pc.log_errors)  # one entry per logged Error Optimization step
so_grads = from_states_to_grads(pc.log_states)  # one entry per logged State Optimization step

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from torch.nn.functional import cosine_similarity


def plot_gradient_comparison_BP(true_pc_grads, eo_grads, so_grads, layers):
    """
    Plot cosine similarity of gradients wrt BP (the 1-step EO gradient).
    Uses a broken y-axis: a tall panel showing [0.9, 1.012] on top of a thin panel
    around 0. A horizontal dashed line marks the analytical PC gradients' similarity
    to BP, a dotted line at 1.0 marks BP itself.

    Parameters:
    - true_pc_grads: list of lists of tensors [timesteps][layers] with gradients of the analytical optimum
    - eo_grads: list of lists of tensors [timesteps][layers] with gradients from Error Optimization
      (entry 1 is used as the backprop reference)
    - so_grads: list of lists of tensors [timesteps][layers] with gradients from State Optimization
    - layers: list of integers specifying the layers to be plotted

    Raises:
    - ValueError: if `layers` is empty.
    - IndexError: if `eo_grads` has fewer than two timesteps.
    """
    if len(layers) == 0:
        raise ValueError("You must specify at least one layer to plot.")

    # Create a figure with 2 rows (for broken axis) and len(layers) columns
    fig, axes = plt.subplots(
        2,
        len(layers),
        figsize=(4 * len(layers), 4),
        sharex=True,
        gridspec_kw={"height_ratios": [8, 1]},  # top panel taller
    )
    # For a single layer plt.subplots returns a 1-D array of two axes; restore the
    # (2, n_layers) layout expected below.
    axes = np.asarray(axes).reshape(2, len(layers))

    for i, layer in enumerate(layers):
        ax_top = axes[0, i]
        ax_bottom = axes[1, i]

        # Define BP = 1-step EO
        bp_grads = eo_grads[1][layer]

        # Cosine similarity of every logged gradient wrt bp_grads. The similarity is
        # computed once per real timestep; padding happens on the scalars afterwards.
        eo_sims = np.array(
            [cosine_similarity(step[layer], bp_grads, dim=0).item() for step in eo_grads]
        )
        so_sims = np.array(
            [cosine_similarity(step[layer], bp_grads, dim=0).item() for step in so_grads]
        )

        # If EO logged fewer steps, hold its last value so both curves span the same range
        if len(eo_sims) < len(so_sims):
            eo_sims = np.pad(eo_sims, (0, len(so_sims) - len(eo_sims)), mode="edge")

        # True optimum similarity wrt bp_grads (constant line)
        optimum = true_pc_grads[0][layer].detach()
        optimum_sim = cosine_similarity(optimum, bp_grads, dim=0).item()

        # Plot curves on both axes; each curve uses its own time axis so the plot
        # also works when EO logged more steps than SO.
        eo_steps = np.arange(len(eo_sims))
        so_steps = np.arange(len(so_sims))
        for ax in (ax_top, ax_bottom):
            ax.plot(so_steps, so_sims, "C3-", label="State-based PC")
            ax.plot(eo_steps, eo_sims, "b-", label="Error-based PC (ours)")
            ax.axhline(optimum_sim, color="k", linestyle="--", label="Analytical PC gradients")
            ax.axhline(1.0, color="k", linestyle=":", label="Backpropagation")
            ax.set_xscale("log")

        # Set y-limits for broken axis
        ax_top.set_ylim(0.9, 1.012)
        ax_bottom.set_ylim(-0.1, 0.1)
        ax_bottom.set_yticks([0])
        ax_bottom.set_yticklabels(["0"])

        # Add break marks: short horizontal ticks (in axes coordinates) at the left and
        # right ends of the edge where the two panels meet.
        d = 0.015
        kwargs = dict(transform=ax_top.transAxes, color="k", clip_on=False, linewidth=1)
        ax_top.plot((-d, +d), (0, 0), **kwargs)  # left mark, bottom edge of top panel
        ax_top.plot((1 - d, 1 + d), (0, 0), **kwargs)  # right mark, bottom edge of top panel

        kwargs.update(transform=ax_bottom.transAxes)
        ax_bottom.plot((-d, +d), (1, 1), **kwargs)  # left mark, top edge of bottom panel
        ax_bottom.plot((1 - d, 1 + d), (1, 1), **kwargs)  # right mark, top edge of bottom panel

        # Titles and labels
        ax_top.set_title(f"Gradient orientation at layer {layer}", fontsize=15)
        ax_bottom.set_xlabel("Update steps T (log scale)")

        # Hide spines between top and bottom
        ax_top.spines["bottom"].set_visible(False)
        ax_bottom.spines["top"].set_visible(False)
        ax_top.xaxis.set_ticks_position("none")  # hide bottom ticks of top axis
        ax_bottom.tick_params(top=False)  # hide top ticks of bottom axis

    # Shared y-axis label
    axes[0, 0].set_ylabel("Cosine similarity w.r.t. backprop")

    # Legend (built from proxy handles so every entry appears exactly once)
    handles = [
        plt.Line2D([0], [0], color="b", linestyle="-"),
        plt.Line2D([0], [0], color="C3", linestyle="-"),
        plt.Line2D([0], [0], color="k", linestyle=":"),
        plt.Line2D([0], [0], color="k", linestyle="--"),
    ]
    fig.legend(
        handles,
        [
            "Error-based PC (ours)",
            "State-based PC",
            "Backpropagation",
            "Analytical PC gradients",
        ],
        loc="upper center",
        ncol=4,
        prop={"size": 12},
    )

    plt.tight_layout(rect=[0, 0, 1, 0.90])  # leave room for the global legend
    plt.show()

In [None]:
# Gradient direction relative to backprop for three parameter indices (0, 9 and 18).
plot_gradient_comparison_BP(
    true_pc_grads,
    eo_grads,
    so_grads,
    layers=[0, 9, 18],  # no biases, so the 18th param is simply the weight matrix of layer 18
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch


def plot_gradient_norms(true_pc_grads, eo_grads, so_grads, layers):
    """
    Plot L2 norms of gradients over time for EO and SO.
    Horizontal lines mark the norm of bp_grads (the EO gradient at index 1, dotted)
    and of the analytical PC gradients (dashed).

    Parameters:
    - true_pc_grads: list of lists of tensors [timesteps][layers] with gradients of the analytical optimum
    - eo_grads: list of lists of tensors [timesteps][layers] with gradients from Error Optimization
    - so_grads: list of lists of tensors [timesteps][layers] with gradients from State Optimization
    - layers: list of integers specifying the layers to be plotted

    Raises:
    - ValueError: if `layers` is empty.
    - IndexError: if `eo_grads` has fewer than two timesteps.
    """
    if len(layers) == 0:
        raise ValueError("You must specify at least one layer to plot.")

    # Create a figure with subplots for each layer
    fig, axes = plt.subplots(1, len(layers), figsize=(4 * len(layers), 4), sharey=True)
    axes = np.atleast_1d(axes)

    for ax, layer in zip(axes, layers):
        # Define bp_grads = 1-step EO gradient
        bp_grads = eo_grads[1][layer]

        # L2 norm of every logged gradient. Norms are computed once per real timestep;
        # padding happens on the scalars afterwards.
        eo_norms = np.array([torch.norm(step[layer], p=2).item() for step in eo_grads])
        so_norms = np.array([torch.norm(step[layer], p=2).item() for step in so_grads])

        # If EO logged fewer steps, hold its last value so both curves span the same range
        if len(eo_norms) < len(so_norms):
            eo_norms = np.pad(eo_norms, (0, len(so_norms) - len(eo_norms)), mode="edge")

        # Reference norms (constant horizontal lines)
        bp_norm = torch.norm(bp_grads, p=2).item()
        optimum_norm = torch.norm(true_pc_grads[0][layer], p=2).item()

        # Plot curves; each one uses its own time axis so the plot also works when
        # EO logged more steps than SO.
        ax.plot(np.arange(len(so_norms)), so_norms, "C3-", label="State-based PC")
        ax.plot(np.arange(len(eo_norms)), eo_norms, "b-", label="Error-based PC (ours)")

        # Plot reference lines
        ax.axhline(bp_norm, color="k", linestyle=":", label="Backpropagation")
        ax.axhline(optimum_norm, color="k", linestyle="--", label="Analytical PC gradients")

        # Titles and scaling
        ax.set_title(f"Gradient norm at layer {layer}", fontsize=15)
        ax.set_xlabel("Update steps T (log scale)")
        ax.set_xscale("log")

    # Shared y-axis label
    axes[0].set_ylabel("‖grad‖₂")

    # No legend here: this figure relies on the legend of the cosine-similarity figure.
    # The top 10% is still left empty, as in the figures that do carry a legend.
    plt.tight_layout(rect=[0, 0, 1, 0.90])
    plt.show()

In [None]:
# Gradient magnitude over time for the same three parameter indices (0, 9 and 18).
plot_gradient_norms(
    true_pc_grads,
    eo_grads,
    so_grads,
    layers=[0, 9, 18],  # no biases, so the 18th param is simply the weight matrix of layer 18
)