In [None]:
# Run from the repository root (as if the mnist_poc folder never existed), so the
# root-level modules (pc_e, datamodules, get_arch) import and "mnist_poc/..." paths resolve.
import os


def chdir_out_of(folder: str = "mnist_poc") -> None:
    """Move one directory up only if the current working directory is named ``folder``.

    Idempotent: once the cwd is no longer ``folder``, further calls do nothing, so
    re-running this cell cannot keep climbing the directory tree.
    """
    if os.path.basename(os.getcwd()) == folder:
        os.chdir("..")


chdir_out_of()

In [None]:
# Handy imports
from pc_e import PCE

import torch
import matplotlib.pyplot as plt

# Track layerwise energies

In [None]:
import torch.nn.functional as F


class TrackedEnergies(PCE):
    """``PCE`` subclass that keeps a per-layer history of every energy evaluation.

    Attributes
    ----------
    log_E_errors : list of lists of tensors
        Appended to by ``E``: one entry per call, holding the detached per-layer
        error energies followed by the classification loss.
    log_E_states : list of lists of tensors
        Appended to by ``E_states_only``: one entry per call, holding the detached
        per-layer state energies (the last one being the classification loss).

    Each log is emptied when the matching ``minimize_*`` method starts. That the
    parent's optimisation loops call ``E`` / ``E_states_only`` (and thereby fill
    these logs) is assumed from how the notebook uses them — confirm in ``pc_e``.
    """

    def __init__(self, architecture, iters, e_lr, w_lr):
        super().__init__(architecture, iters, e_lr, w_lr)
        self.log_E_errors, self.log_E_states = [], []

    # ------------------------------------------------------------ error optimisation
    def minimize_error_energy(self, x, y):
        """Empty ``log_E_errors`` in place, then defer to ``PCE.minimize_error_energy``."""
        del self.log_E_errors[:]
        return super().minimize_error_energy(x, y)

    def E_errors_layerwise(self, x: torch.Tensor, y: torch.Tensor):
        """List of energy terms: ``0.5 * ||e||_2^2`` per tensor in ``self.errors``
        (norm over all elements), followed by ``self.class_loss(self.y_pred(x), y)``."""
        terms = []
        for err in self.errors:
            sq_norm = torch.linalg.vector_norm(err, ord=2, dim=None) ** 2
            terms.append(0.5 * sq_norm)
        terms.append(self.class_loss(self.y_pred(x), y))
        return terms

    def E(self, x, y):
        """Total error energy; the detached per-layer terms are logged first."""
        per_layer = self.E_errors_layerwise(x, y)
        self.log_E_errors.append([term.detach() for term in per_layer])
        return sum(per_layer)

    # ------------------------------------------------------------ state optimisation
    def minimize_state_energy(self, x, y, iters, s_lr):
        """Empty ``log_E_states`` in place, then defer to ``PCE.minimize_state_energy``."""
        del self.log_E_states[:]
        return super().minimize_state_energy(x, y, iters, s_lr)

    def E_states_only_layerwise(self, x: torch.Tensor, y: torch.Tensor, states: list[torch.Tensor]):
        """Per-layer energies when every hidden node is a free state.

        The chain ``[x] + states + [y]`` is walked pairwise: ``self.layers[i]`` maps
        node ``i`` to a prediction of node ``i + 1``. Hidden targets are scored with
        ``0.5 * sum-reduced MSE``; the final target ``y`` with ``self.class_loss``.
        Iteration stops at the shorter of ``self.layers`` and the node pairs (``zip``).
        """
        n_hidden = len(states)
        nodes = [x] + states + [y]
        terms = []
        for idx, (layer, src, dst) in enumerate(zip(self.layers, nodes[:-1], nodes[1:])):
            pred = layer(src)
            if idx < n_hidden:
                terms.append(0.5 * F.mse_loss(pred, dst, reduction="sum"))
            else:
                terms.append(self.class_loss(pred, dst))
        return terms

    def E_states_only(self, x, y, states):
        """Total state energy; the detached per-layer terms are logged first."""
        per_layer = self.E_states_only_layerwise(x, y, states)
        self.log_E_states.append([term.detach() for term in per_layer])
        return sum(per_layer)

# Load a single data item

In [None]:
from datamodules import EMNIST
from lightning import seed_everything

SEED = 42
seed_everything(SEED)  # fixed seed => same batch here and same initial weights later on

batch_size = 1
dm = EMNIST(batch_size)
# Read the dataset name now — per the original note, this must happen before setup().
dataset_name = dm.dataset_name
print("Training on", dm.dataset_name)

dm.setup("fit")
dl = dm.train_dataloader()
# First training batch, passed through the datamodule's post-transfer hook (dataloader index 0).
batch = dm.on_after_batch_transfer(next(iter(dl)), 0)
x = batch["img"]
y = batch["y"]
print(x.shape, y.shape)

In [None]:
# Show the first image of the batch as 28x28 pixels; the .T swaps rows and columns
# (presumably to undo EMNIST's transposed storage — confirm against the datamodule).
first_img = x[0].reshape(28, 28).T
plt.imshow(first_img, cmap="gray")
print(y)

# Load untrained model

In [None]:
from get_arch import get_architecture

# Architecture spec handed to TrackedEnergies/PCE below; the flag name suggests a
# cross-entropy output loss — confirm in get_arch.
architecture = get_architecture(
    dataset="EMNIST-deep",
    use_CELoss=True,
)

In [None]:
# Settings for this single-batch demo.
PCE_ITERS, E_LR = 8, 0.1      # `iters` / `e_lr` passed to the PCE constructor
STATE_ITERS, S_LR = 64, 0.1   # iteration budget / step size for state optimisation

pc = TrackedEnergies(architecture, iters=PCE_ITERS, e_lr=E_LR, w_lr=None)
# Each call clears its own energy log before running; return values are not needed here.
pc.minimize_error_energy(x, y)
pc.minimize_state_energy(x, y, iters=STATE_ITERS, s_lr=S_LR)
print("All done here!")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LinearSegmentedColormap


def log10_masked(values):
    """Return ``log10(values)`` as a new float array, with exact zeros turned into NaN.

    Zeros would otherwise become ``-inf``. As NaN they drop out of the colour limits,
    and ``imshow`` paints them with the colormap's "bad" colour (transparent by
    default). The input is copied, never modified, and integer input is accepted.
    """
    arr = np.array(values, dtype=float)
    arr[arr == 0] = np.nan
    return np.log10(arr)


def colorbar_decade_ticks(vmin, vmax, num=6):
    """Tick positions (in log10 units) for a colour bar spanning ``[vmin, vmax]``.

    Whole decades are preferred so every ``10^k`` label is exact. If more than ``num``
    decades fit, only every k-th one is kept. When fewer than two whole decades lie in
    the range, ``num`` evenly spaced (de-duplicated) positions are used instead.
    """
    decades = np.arange(np.ceil(vmin), np.floor(vmax) + 1)
    if decades.size < 2:
        return np.unique(np.linspace(vmin, vmax, num=num))
    if decades.size > num:
        decades = decades[:: int(np.ceil(decades.size / num))]
    return decades


def pow10_tick_label(exp, _pos=None):
    """``FuncFormatter`` callback: render a log10-space tick value as ``10^exp`` mathtext.

    Integer exponents print as integers; any other value keeps one decimal, so a
    label never claims a power of ten the tick does not sit at.
    """
    rounded = np.round(exp)
    if np.isclose(exp, rounded):
        return f"$10^{{{int(rounded)}}}$"
    return f"$10^{{{exp:.1f}}}$"


def make_double_figure(list1, list2, save_path="mnist_poc/fig1.svg", ytick_step=5):
    """Plot two energy logs as side-by-side log10 heatmaps that share one colour scale.

    Parameters
    ----------
    list1, list2 : array-like, shape (n_steps, n_layers)
        Energy logs with one row per optimisation step and one column per layer
        (e.g. ``pc.log_E_states`` / ``pc.log_E_errors``). They are transposed so that
        layers run down the y-axis and time runs along the x-axis. ``list1`` is drawn
        on the left, ``list2`` on the right. Zeros are left blank.
    save_path : str or None
        File the figure is written to before being shown; ``None`` skips saving.
    ytick_step : int
        Spacing of the layer ticks on the left panel. The last layer is always ticked.

    Raises
    ------
    ValueError
        If an input is not 2-D, or if neither input contains a positive value.
    """
    log1 = log10_masked(list1).T
    log2 = log10_masked(list2).T
    for name, log in (("list1", log1), ("list2", log2)):
        if log.ndim != 2:
            raise ValueError(f"{name} must be 2-D (steps x layers); got {log.ndim}-D input")

    # Shared colour limits over every non-NaN cell of both panels.
    combined = np.concatenate([log1[~np.isnan(log1)], log2[~np.isnan(log2)]])
    if combined.size == 0:
        raise ValueError("nothing to plot: neither input contains a positive value")
    vmin, vmax = np.min(combined), np.max(combined)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

    # Drop the brightest 10% of inferno (near-white), presumably so the top of the
    # scale stays distinguishable from a white background.
    base_cmap = plt.get_cmap("inferno")
    cmap = LinearSegmentedColormap.from_list("short_inferno", base_cmap(np.linspace(0.0, 0.9, num=256)))
    imshow_kw = dict(aspect="auto", interpolation="nearest", vmin=vmin, vmax=vmax, cmap=cmap)
    ax1.imshow(log1, **imshow_kw)
    im2 = ax2.imshow(log2, **imshow_kw)

    # ax1.set_title(r"$\bf{State\ Optimization}$ (standard)", fontsize=14)
    # ax2.set_title(r"$\bf{Error\ Optimization}$ (ours)", fontsize=14)
    for ax in (ax1, ax2):
        ax.set_xlabel("Time (optimization steps)")
    ax1.set_ylabel("Input Layer $i$ Output")  # will move these to correct position in Inkscape

    # Layer ticks every `ytick_step` rows plus the last row (20 rows -> 0, 5, 10, 15, 19).
    n_rows = log1.shape[0]
    yticks = list(range(0, n_rows, ytick_step))
    if yticks and yticks[-1] != n_rows - 1:
        yticks.append(n_rows - 1)
    ax1.set_yticks(yticks)
    ax2.tick_params(axis="y", which="both", left=False, labelleft=False)

    # Shared colour bar, attached to the left edge of the right panel (between the plots).
    divider = make_axes_locatable(ax2)
    cax = divider.append_axes("left", size="5%", pad=0.05)
    cbar = fig.colorbar(im2, cax=cax, orientation="vertical", ticks=colorbar_decade_ticks(vmin, vmax))
    cbar.ax.tick_params(labelsize=11)
    # Colours encode log10(energy), so ticks are labelled as powers of ten.
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(pow10_tick_label))
    cax.yaxis.set_ticks_position("left")
    cax.yaxis.set_label_position("left")
    cax.invert_xaxis()

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Left panel: state-optimisation energies (standard); right panel: error-optimisation energies (ours).
make_double_figure(list1=pc.log_E_states, list2=pc.log_E_errors)