### **Does NGD converge to more complex minima according to the RLCT (Real Log Canonical Threshold)?**

This notebook tests the question above using developmental interpretability methods.

We hypothesise that NGD (Natural Gradient Descent) will consistently converge to minima with a higher RLCT than SGD, because it premultiplies the gradient used for gradient descent by the inverse of the Fisher Information Matrix.

SLT proposes that models converge to singularities, where the Fisher Information Matrix (FIM) is non-invertible. Near a singularity the determinant of the FIM is close to 0, so its inverse blows up. We therefore expect NGD to "jump away" from the singularity, instead favouring more complex, less singular minima.
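
A small numerical illustration of this argument, using a hypothetical 2×2 Fisher matrix. The matrices, the gradient and the helper names below are made up for illustration and are not taken from the `NGD` module.

```python
def inverse_2x2(m):
    """Invert a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def natural_gradient(fisher, grad):
    """Return F^{-1} g for a 2x2 Fisher matrix and a length-2 gradient."""
    inv = inverse_2x2(fisher)
    return [
        inv[0][0] * grad[0] + inv[0][1] * grad[1],
        inv[1][0] * grad[0] + inv[1][1] * grad[1],
    ]


def norm(v):
    """Euclidean norm of a vector given as a list."""
    return sum(x * x for x in v) ** 0.5


grad = [0.1, 0.1]  # the same raw gradient is used in both cases

well_conditioned = [[1.0, 0.0], [0.0, 1.0]]  # determinant 1
near_singular = [[1.0, 0.0], [0.0, 1e-6]]    # determinant 1e-6

step_regular = natural_gradient(well_conditioned, grad)
step_singular = natural_gradient(near_singular, grad)

print(norm(step_regular))   # same size as the raw gradient
print(norm(step_singular))  # dominated by the near-degenerate direction
```

With the identity Fisher matrix the natural gradient equals the raw gradient, while the near-singular matrix scales the second component by a factor of 10^6. In practice a damping term is often added to the FIM before inversion to limit this effect.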

#### **Methodology**
- Choose between DNN (dense neural network) and CNN (convolutional neural network)
- Vary hidden nodes and hidden layers for DNN (depth and width)
- Vary number of convolutional layers for CNN.
- Estimate Hessian rank using `PyHessian` module.
- Estimate RLCT using `devinterp` library.

#### **Instructions**

To produce your own results, go to `args_ngd_sgd.json`. The following parameters can be adjusted:
- `model_type (str)`  : `dnn` (dense network), `cnn` (convolutional network)
- `optimizer (str)` : `sgd`, `ngd`, or `both` (performs analysis on both optimizers)
- `hessian (bool)` : `true`, `false` (does Hessian rank analysis if enabled)
- `num_epochs (int)` : total number of training epochs
- `cut_off_epochs (int)` : the epoch number at which the optimiser is swapped (make this equal to `num_epochs` if you want no swapping)
- `sgd_lr (float)` : sgd learning rate for training
- `ngd_lr (float)` : ngd learning rate for training
- `alpha (float)` : smoothing constant for estimation of FIM for NGD
- `momentum (float)` : momentum used in SGD and NGD
- `nesterov (bool)` : enable Nesterov momentum for NGD / SGD
- `batch_size (int)` : batch size for train / validation dataloaders
- `num_workers (int)` : number of worker subprocesses used for data loading (keep this at around 6, vary depending on your hardware)
- `dataset (str)` : `mnist`, `cifar10` (dataset to use for training)
- `num_hessian_batches (int)` : number of batches used for estimation of Hessian
- `sampler (str)` : `sgld`, `sgnht` (optimiser to use for RLCT estimation)
- `num_chains (int)` : number of chains to use in RLCT estimation (higher leads to more accurate RLCT estimate)
- `num_draws (int)` : number of optimizer steps in RLCT estimation (should be high enough such that chain RLCT converges - check convergence plot)
- `localization (float)` : higher localization more strongly restricts optimizer to neighbourhood of model weights (stops RLCT estimator from going straight to minima)
- `sampler_lr (float)` : learning rate for sampler for RLCT estimation
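
The notebook unpacks the loaded JSON into three dictionaries (`hp`, `data_args`, `devinterp_args`), which suggests that `args_ngd_sgd.json` holds a list of three objects. Under that assumption, a sketch of how such a file could be generated is shown below. The values are the ones printed by the run recorded in this notebook; the exact file layout is inferred, not confirmed.

```python
import json

# Hyperparameters, as printed by the run recorded in this notebook.
hp = {
    "model_type": "dnn",
    "optimizer": "both",
    "experiment_type": "swap",
    "hessian": True,
    "num_epochs": 30,
    "cut_off_epochs": 10,
    "sgd_lr": 0.01,
    "ngd_lr": 0.01,
    "alpha": 0,
    "eta": 0.9,
    "momentum": 0.9,
    "nesterov": True,
    "seed": 43,
    "hidden_layers": [2],
    "hidden_nodes": [128],
    "hidden_conv_layers": [1, 2, 3],
}
data_args = {
    "dataset": "mnist",
    "batch_size": 128,
    "num_workers": 6,
    "num_hessian_batches": 1,
}
devinterp_args = {
    "sampler": "sgld",
    "num_chains": 1,
    "num_draws": 1000,
    "localization": 100.0,
    "sampler_lr": 0.0001,
}

# Assumed layout: a JSON list of three objects, in the order the notebook unpacks them.
config_text = json.dumps([hp, data_args, devinterp_args], indent=2)
print(config_text)

# Round trip: this mirrors `hp, data_args, devinterp_args = args` in the notebook.
loaded_hp, loaded_data_args, loaded_devinterp_args = json.loads(config_text)
```

Writing `config_text` to `args_ngd_sgd.json` would then give a file the loading cell can unpack directly.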

You can run the notebook from the terminal using the following command:
```bash
jupyter nbconvert --to notebook --execute --inplace ./experiments/eval_ngd_sgd.ipynb
```

#### **0. Import libraries**

The `NGD` module is used for implementing Natural Gradient Descent efficiently. It approximates the Fisher Information Matrix to do this.
The `devinterp` library is used for estimation of the LLC (local learning coefficient).

In [13]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD, SGNHT
from devinterp.slt import sample, OnlineLLCEstimator
from devinterp.slt.wbic import OnlineWBICEstimator
from devinterp.slt.mala import MalaAcceptanceRate
from devinterp.utils import plot_trace, optimal_temperature

from approxngd import KFAC
from PyHessian.pyhessian import *
from PyHessian.density_plot import *
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

from utils_general import *
from utils_hessian_fim import *
from networks import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

device = "cuda" if t.cuda.is_available() else "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline

Device in use: cuda
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### **1. Import data for training**

We import the dataset for training using a helper function, `build_data_loaders`, which allows us to choose between MNIST and CIFAR10.

We specify three dictionaries: `hp` (hyperparameters), `data_args` (arguments for data loading) and `devinterp_args` (arguments for LLC and WBIC estimation).

In [14]:
# Load experiment args

with open("args_ngd_sgd.json", "r") as file:
    args = json.load(file) 

hp, data_args, devinterp_args = args

print("HYPERPARAMETERS")
pprint.pprint(hp)

print("DATA ARGS")
pprint.pprint(data_args)

print("DEVINTERP ARGS")
pprint.pprint(devinterp_args)

HYPERPARAMETERS
{'alpha': 0,
 'cut_off_epochs': 10,
 'eta': 0.9,
 'experiment_type': 'swap',
 'hessian': True,
 'hidden_conv_layers': [1, 2, 3],
 'hidden_layers': [2],
 'hidden_nodes': [128],
 'model_type': 'dnn',
 'momentum': 0.9,
 'nesterov': True,
 'ngd_lr': 0.01,
 'num_epochs': 30,
 'optimizer': 'both',
 'seed': 43,
 'sgd_lr': 0.01}
DATA ARGS
{'batch_size': 128,
 'dataset': 'mnist',
 'num_hessian_batches': 1,
 'num_workers': 6}
DEVINTERP ARGS
{'localization': 100.0,
 'num_chains': 1,
 'num_draws': 1000,
 'sampler': 'sgld',
 'sampler_lr': 0.0001}


In [15]:
# Set random seed for weight initialisation (for reproducibility of results and to ensure NGD/SGD start from same point in loss landscape)

t.manual_seed(hp["seed"])

<torch._C.Generator at 0x1e230ec58b0>

In [16]:
train_loader, test_loader = build_data_loaders(data_args)

#### **2. Training models**

Choose to train either a DNN or CNN.

This code produces a dictionary where each key describes the model itself, e.g. "DNN 4 HL, 256 HN".

Each value is a dictionary that maps an optimiser name (`sgd` or `ngd`) to a list that initially contains a copy of the untrained model. After each training epoch, a copy of the trained model is appended to this list, so the list records the model's history.
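
The nesting can be sketched with plain dictionaries standing in for the torch models. The placeholder model and its `weights` field are hypothetical and only illustrate the shape of the structure.

```python
import copy

# Placeholder standing in for an untrained torch model (hypothetical).
initial_model = {"weights": [0.0, 0.0]}

optimizers = ["sgd", "ngd"]
title = "DNN 2 HL, 128 HN"

# One history list per optimiser, each starting from its own copy of the same model.
models = {title: {optim: [copy.deepcopy(initial_model)] for optim in optimizers}}

# After an epoch of training, a snapshot is appended to that optimiser's history.
trained = copy.deepcopy(models[title]["sgd"][-1])
trained["weights"] = [0.3, -0.1]
models[title]["sgd"].append(trained)

print(len(models[title]["sgd"]))  # initial model plus one epoch
print(len(models[title]["ngd"]))  # still only the initial model
```

Because every history starts from a deep copy, training under one optimiser never mutates the starting point used by the other.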

In [4]:
# Initialise models dependent on arguments

models = {}
optimizers = ["sgd", "ngd"] if hp["optimizer"] == "both" else [hp["optimizer"]]

if hp["model_type"] == "dnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"DNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist":
            model = LinearMNIST(hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = LinearCIFAR10(hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "cnn":
    hidden_conv_layers = hp["hidden_conv_layers"]
    for hidden_conv_layer in hidden_conv_layers:
        title = f"CNN {hidden_conv_layer} HCL"
        if data_args["dataset"] == "mnist":
            model = CnnMNIST(hidden_conv_layers=hidden_conv_layer).to(device)
        elif data_args["dataset"] == "cifar10":
            model = CnnCIFAR10(hidden_conv_layers=hidden_conv_layer).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}

In [None]:
# EXPERIMENT 1: Train models independently for num_epochs with SGD and NGD

if hp["experiment_type"] == "standard":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        # Store list for SGD losses, NGD losses, for train and val
        model_train_losses = {optim : [] for optim in optimizers}
        model_val_losses = {optim : [] for optim in optimizers}
        model_update_norms = {optim : [] for optim in optimizers}
        model_all_update_norms = {optim : [] for optim in optimizers}

        for optim in optimizers:
            state = copy.deepcopy(model[optim][0])
            if optim == "sgd":
                optimizer = t.optim.SGD(
                    params=state.parameters(),
                    lr=hp["sgd_lr"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            elif optim == "ngd":
                optimizer = NGD(
                    params=state.parameters(),
                    lr=hp["ngd_lr"],
                    alpha=hp["alpha"],
                    eta=hp["eta"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            print(f"TRAINING MODEL: {title} | OPTIMISER: {optim}")
            initial_train_loss = evaluate(state, train_loader, criterion, device)
            initial_val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses[optim].append(initial_train_loss)
            model_val_losses[optim].append(initial_val_loss)
            for epoch in range(1, hp["num_epochs"]+1):
                train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
                val_loss = evaluate(state, test_loader, criterion, device)
                model_train_losses[optim].append(train_loss)
                model_update_norms[optim].append(update_norm)
                model_val_losses[optim].append(val_loss)
                model_all_update_norms[optim] += epoch_update_norms
                model[optim].append(copy.deepcopy(state))
                print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


In [5]:
# EXPERIMENT 2: Train to convergence with SGD, then continue with SGD and NGD and compare

if hp["experiment_type"] == "swap":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        del model["ngd"][0]

        model_train_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_val_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_update_norms = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"])]}
        model_all_update_norms = {"sgd" : [], "ngd" : [None for i in range(len(train_loader)*hp["cut_off_epochs"])]}

        state = copy.deepcopy(model["sgd"][0])
        optimizer = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
    
        print(f"TRAINING MODEL: {title} | OPTIMISER: SGD")
        initial_train_loss = evaluate(state, train_loader, criterion, device)
        initial_val_loss = evaluate(state, test_loader, criterion, device)
        model_train_losses["sgd"].append(initial_train_loss)
        model_val_losses["sgd"].append(initial_val_loss)
        for epoch in range(1, hp["cut_off_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")


        # Train with SGD
        print("Training from model checkpoint with SGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_sgd = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_sgd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        
        # Train with NGD
        print("Training from model checkpoint with NGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_ngd = NGD(
            params=state.parameters(),
            lr=hp["ngd_lr"],
            alpha=hp["alpha"],
            eta=hp["eta"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_ngd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["ngd"].append(train_loss)
            model_update_norms["ngd"].append(update_norm)
            model_val_losses["ngd"].append(val_loss)
            model_all_update_norms["ngd"] += epoch_update_norms
            model["ngd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


TRAINING MODEL: DNN 2 HL, 128 HN | OPTIMISER: SGD
Epoch 1/30: train_loss=0.4351, val_loss=0.1924
Epoch 2/30: train_loss=0.1624, val_loss=0.1306
Epoch 3/30: train_loss=0.1111, val_loss=0.1048
Epoch 4/30: train_loss=0.0835, val_loss=0.0878
Epoch 5/30: train_loss=0.0666, val_loss=0.0800
Epoch 6/30: train_loss=0.0540, val_loss=0.0756
Epoch 7/30: train_loss=0.0444, val_loss=0.0711
Epoch 8/30: train_loss=0.0357, val_loss=0.0679
Epoch 9/30: train_loss=0.0303, val_loss=0.0700
Epoch 10/30: train_loss=0.0255, val_loss=0.0693
Training from model checkpoint with SGD.
Epoch 11/30: train_loss=0.0210, val_loss=0.0767
Epoch 12/30: train_loss=0.0169, val_loss=0.0730
Epoch 13/30: train_loss=0.0143, val_loss=0.0707
Epoch 14/30: train_loss=0.0117, val_loss=0.0706
Epoch 15/30: train_loss=0.0093, val_loss=0.0717
Epoch 16/30: train_loss=0.0083, val_loss=0.0743
Epoch 17/30: train_loss=0.0065, val_loss=0.0771
Epoch 18/30: train_loss=0.0052, val_loss=0.0747
Epoch 19/30: train_loss=0.0044, val_loss=0.0768
Epoch 20/30: train_loss=0.0037, val_loss=0.0764
Epoch 21/30: train_loss=0.0029, val_loss=0.0748
Epoch 22/30: train_loss=0.0024, val_loss=0.0791
Epoch 23/30: train_loss=0.0021, val_loss=0.0781
Epoch 24/30: train_loss=0.0018, val_loss=0.0784
Epoch 25/30: train_loss=0.0016, val_loss=0.0791
Epoch 26/30: train_loss=0.0014, val_loss=0.0797
Epoch 27/30: train_loss=0.0013, val_loss=0.0795
Epoch 28/30: train_loss=0.0012, val_loss=0.0811
Epoch 29/30: train_loss=0.0011, val_loss=0.0815
Epoch 30/30: train_loss=0.0011, val_loss=0.0828
Training from model checkpoint with NGD.

	add(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add(Tensor other, *, Number alpha) (Triggered internally at ..\torch\csrc\utils\python_arg_parser.cpp:1630.)
  d_p = d_p.add(momentum, buf)
Epoch 11/30: train_loss=171455814.5441, val_loss=2.3011
Epoch 12/30: train_loss=2.3013, val_loss=2.3013
Epoch 13/30: train_loss=2.3014, val_loss=2.3012
Epoch 14/30: train_loss=2.3014, val_loss=2.3011
Epoch 15/30: train_loss=2.3014, val_loss=2.3013
Epoch 16/30: train_loss=2.3014, val_loss=2.3012
Epoch 17/30: train_loss=2.3015, val_loss=2.3011
Epoch 18/30: train_loss=2.3014, val_loss=2.3011
Epoch 19/30: train_loss=2.3014, val_loss=2.3010
Epoch 20/30: train_loss=2.3014, val_loss=2.3011
Epoch 21/30: train_loss=2.3014, val_loss=2.3012
Epoch 22/30: train_loss=2.3013, val_loss=2.3013
Epoch 23/30: train_loss=2.3014, val_loss=2.3010
Epoch 24/30: train_loss=2.3014, val_loss=2.3014
Epoch 25/30: train_loss=2.3014, val_loss=2.3011
Epoch 26/30: train_loss=2.3014, val_loss=2.3010
Epoch 27/30: train_loss=2.3014, val_loss=2.3011
Epoch 28/30: train_loss=2.3014, val_loss=2.3012
Epoch 29/30: train_loss=2.3014, val_loss=2.3011
Epoch 30/30: train_loss=2.3014, val_loss=2.3011
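
A note on reading the NGD log above: after the spike at epoch 11, both losses sit at about 2.30. For a 10-class problem such as MNIST, a classifier that outputs a uniform distribution has a cross-entropy of ln 10, so a quick check of that reference value is useful. Interpreting the plateau as a collapse to near-uniform predictions is an assumption; the notebook does not verify it.

```python
import math

num_classes = 10  # MNIST digits

# Cross-entropy of a uniform prediction over the classes: -log(1 / num_classes)
uniform_loss = -math.log(1.0 / num_classes)
print(f"uniform-prediction loss: {uniform_loss:.4f}")

final_ngd_val_loss = 2.3011  # epoch 30 of the NGD run in the log above
gap = abs(final_ngd_val_loss - uniform_loss)
print(f"gap to NGD validation loss: {gap:.4f}")
```

If the gap is small, the NGD phase of this run should probably be treated as having diverged rather than converged, and settings such as `ngd_lr` or `alpha` may need revisiting before its RLCT is compared against SGD.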


In [6]:
# If we are doing the swap experiment, fill in the first hp["cut_off_epochs"] models with None

if hp["experiment_type"] == "swap":
    none_models = [None for i in range(hp["cut_off_epochs"]+1)]
    for title, model in models.items():
        model["ngd"] = none_models + model["ngd"]
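
The effect of this padding can be seen on a toy history. The epoch counts and string placeholders below are illustrative only.

```python
cut_off_epochs = 2
num_epochs = 5

# Index 0 is the untrained model, so the SGD history has num_epochs + 1 entries.
sgd_history = [f"sgd@{epoch}" for epoch in range(num_epochs + 1)]

# NGD only trains after the swap, so it has one entry per post-swap epoch.
ngd_history = [f"ngd@{epoch}" for epoch in range(cut_off_epochs + 1, num_epochs + 1)]

# Pad the front so that ngd_history[epoch] lines up with sgd_history[epoch].
ngd_history = [None] * (cut_off_epochs + 1) + ngd_history

print(len(sgd_history) == len(ngd_history))
print(ngd_history[cut_off_epochs + 1])
```

After padding, both lists can be indexed by epoch number, and downstream cells can skip the `None` entries for epochs where NGD had not yet started.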

#### **3. Visualising training / validation loss results**

Check that all models converged.

Each figure shows the training and validation loss for one model, with four traces: SGD training, SGD validation, NGD training and NGD validation.

In [8]:
# Display training / val data for all models for NGD / SGD

epochs = np.arange(1, hp["num_epochs"]+1)

loss_figures = {}

color_cycle = ['rgb(0, 0, 255)', 'rgb(255, 0, 0)']

for title in models.keys():
    loss_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=train_losses[title][optim],
            mode="lines+markers",
            line=dict(color=color),
            name=f"{optim} Train",
        ), secondary_y=False)
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=val_losses[title][optim],
            mode="lines+markers",
            line=dict(color=color, dash='dot'),
            name=f"{optim} Validation",
        ), secondary_y=True)
    loss_fig.update_layout(
        title=f"{title} training / validation loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    loss_fig.update_yaxes(title_text="Training Loss", secondary_y=False)
    loss_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
    loss_figures[title] = loss_fig
    loss_fig.show()

In [None]:
# Update norms over training steps for NGD and SGD

update_norm_figures = {}

color_cycle = ['rgb(0, 0, 255)', 'rgb(255, 0, 0)']

for title in models.keys():
    update_norm_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        update_norm_fig.add_trace(go.Scatter(
            y=all_update_norms[title][optim][::5],
            mode="lines",
            line=dict(color=color),
            name=f"{optim} update norm"))
    update_norm_fig.update_layout(
        title=f"{title} update norms over epochs",
        xaxis_title="Steps",
        yaxis_title="Update Size",
        yaxis_type="linear",
    )
    update_norm_figures[title] = update_norm_fig
    update_norm_fig.show()

#### **4. Perform Hessian rank computation**

As a way of verifying results produced by the RLCT, we will compute an approximation of the Hessian for each model at convergence. Then, we'll estimate the rank of this matrix using its eigenspectrum. SLT predicts the following: 

$\text{RLCT} \geq \frac{\text{rank}(\textbf{Hess})}{2}$ 

We will check whether this is true for our experiments. Hessian computation is done using helper functions from `utils_hessian_fim.py` which acts as a wrapper for the `PyHessian` module (and the `nngeometry` module, for doing computations involving the Fisher Information Matrix).
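
As a rough sketch of the idea, the rank can be estimated by counting eigenvalues that are not negligible relative to the largest one, and the bound then follows directly. The function name, the relative tolerance and the eigenvalues below are hypothetical; `find_hessian_dimensionality` in `utils_hessian_fim.py` may use a different rule.

```python
def estimate_rank(eigenvalues, rel_tol=1e-3):
    """Count eigenvalues whose magnitude exceeds rel_tol times the largest magnitude."""
    largest = max(abs(ev) for ev in eigenvalues)
    if largest == 0:
        return 0
    return sum(1 for ev in eigenvalues if abs(ev) > rel_tol * largest)


# Made-up eigenvalues: three clearly non-zero, three numerically negligible.
eigenvalues = [12.0, 3.5, 0.8, 2e-4, -1e-5, 0.0]

rank = estimate_rank(eigenvalues)
rlct_lower_bound = rank / 2

print(rank)
print(rlct_lower_bound)
```

The estimate is sensitive to the choice of tolerance, so it is worth checking how the rank changes as `rel_tol` is varied before comparing it against the RLCT.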

In [None]:
# Create Hessians for each model - recall that model is a list containing all its past versions
if hp["hessian"]:
    hessians = {}
    for title, model in models.items():
        hessian = {}
        for optim in optimizers:
            hessians_list = produce_hessians(
                models=model[optim],
                data_loader=train_loader,
                num_batches=data_args["num_hessian_batches"],
                criterion=criterion,
                device=device,
            )
            hessian[optim] = hessians_list
        hessians[title] = hessian

In [None]:
# Get the eigenspectrum data from the Hessian objects
if hp["hessian"]:
    eigenspectra_data = {}
    eigenspectra_figs = {}
    for title, hessian in hessians.items():
        eigenspectrum_data = {}
        eigenspectrum_figs = {}
        for optim in optimizers:
            eigenspectrum_figs_list, eigenspectrum_data_list = produce_eigenspectra(
                hessians=hessian[optim],
                plot_type="log",
            )
            eigenspectrum_figs[optim] = eigenspectrum_figs_list
            eigenspectrum_data[optim] = eigenspectrum_data_list
        eigenspectra_data[title] = eigenspectrum_data
        eigenspectra_figs[title] = eigenspectrum_figs

In [None]:
# Produce the traces of Hessian dimensionality over epochs
if hp["hessian"]:
    hessian_ranks = {}
    for title, eigenspectrum_data in eigenspectra_data.items():
        hessian_rank = {}
        for optim in optimizers:
            hessian_rank_list = find_hessian_dimensionality(eigenspectrum_data[optim])
            hessian_rank[optim] = hessian_rank_list
        hessian_ranks[title] = hessian_rank

#### **5. Perform RLCT estimation**

Using the `devinterp` library, we estimate the RLCT (Real Log Canonical Threshold), also known as the LLC (Local Learning Coefficient).

`rlct_estimates` is a dictionary keyed by model title. Each value is a dictionary containing two lists: one of SGD RLCT values over epochs and one of NGD RLCT values over epochs. The same is true for `wbic_estimates`.
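
For intuition, the commonly used sampling-based LLC estimator compares the average loss of weights sampled near the trained model with the loss at the trained model itself: $\hat{\lambda} = n\beta\,(\mathbb{E}[L_n(w)] - L_n(w^*))$ with $\beta = 1/\log n$. The sketch below implements that formula directly in plain Python. It is not the `devinterp` API, and the losses are made up for illustration.

```python
import math


def llc_estimate(sampled_losses, init_loss, num_samples):
    """n * beta * (mean sampled loss - loss at the trained weights), with beta = 1 / log(n)."""
    beta = 1.0 / math.log(num_samples)
    mean_loss = sum(sampled_losses) / len(sampled_losses)
    return num_samples * beta * (mean_loss - init_loss)


# Made-up losses from a hypothetical sampling chain around a trained model.
sampled_losses = [0.031, 0.033, 0.032, 0.034]
estimate = llc_estimate(sampled_losses, init_loss=0.030, num_samples=60000)
print(estimate)
```

If the sampler drifts to a region of lower loss than the starting point, the estimate becomes negative. This is one reason the `localization` parameter and the convergence plots matter.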

In [None]:
rlct_estimates = {}
histories = {}

for title, model in models.items():
    history = {}
    rlct_estimate = {}
    for optim in optimizers:
        rlct_list, history_list = estimate_rlcts(
            model[optim], train_loader, criterion, device, devinterp_args,
        )
        rlct_estimate[optim] = rlct_list
        history[optim] = history_list
    rlct_estimates[title] = rlct_estimate 
    histories[title] = history

In [None]:
# Compute generalisation losses for each model, for SGD and NGD, so we can compare this to the testing loss

gen_losses = {}
for title in models.keys():
    gen_loss = {}
    for optim in optimizers:
        gen_loss_list = []
        for i in range(hp["num_epochs"]+1):
            if histories[title][optim][i] is None:
                gen_loss_list.append(None)
            else:
                gen_loss_list.append(train_losses[title][optim][i] + histories[title][optim][i]["llc/moving_avg"][-1][-1]/data_args["batch_size"])
        gen_loss[optim] = gen_loss_list
    gen_losses[title] = gen_loss
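
The quantity computed in the cell above is the training loss plus a complexity penalty of the form $\hat{\lambda}/n$. A minimal standalone version is sketched below; the function name and the numbers are illustrative. Note that the notebook passes the batch size as `n`, whereas the asymptotic SLT expression uses the number of training samples. Which is appropriate here depends on how the LLC estimate was scaled.

```python
def generalisation_loss(train_loss, llc, n):
    """Training loss plus the complexity penalty llc / n."""
    return train_loss + llc / n


# Illustrative numbers only; n = 128 mirrors the notebook's use of the batch size.
example = generalisation_loss(train_loss=0.0255, llc=12.8, n=128)
print(example)
```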

#### **6. Visualise RLCT / Hessian rank, Hessian eigenspectra, generalisation loss, RLCT convergence**

We display the following final figures:
- RLCT and Hessian rank evolution for each model, for NGD and SGD
- Hessian eigenspectra for SGD and NGD overlaid on the same plot
- Evolution of generalisation loss for each model, compared to validation loss
- RLCT moving average evolution for each model, for NGD and SGD, to check for convergence

In [None]:
# Visualise RLCT and Hessian rank data

exp_figures = {}

for title in models.keys():
    exp_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        exp_fig.add_trace(go.Scatter(
            x=epochs,
            y=rlct_estimates[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} RLCT",
            line=dict(color=color),
        ), secondary_y=False)
        if hp["hessian"]:
            exp_fig.add_trace(go.Scatter(
                x=epochs,
                y=hessian_ranks[title][optim][1:],
                mode="lines+markers",
                name=f"{optim} Hessian Rank",
                line=dict(color=color, dash="dot"),
            ), secondary_y=True)
    exp_fig.update_layout(
        title=f"{title} RLCT / Hessian Rank (optional) evolution during training",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    exp_fig.update_yaxes(title_text="RLCT", secondary_y=False)
    exp_fig.update_yaxes(title_text="Hessian rank", secondary_y=True)
    exp_figures[title] = exp_fig
    exp_fig.show()   

In [None]:
# Visualise converged eigenspectra for SGD and NGD for each model

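combined_eigenspectra = {}

# Eigenspectra are only available when the Hessian analysis is enabled
eigenspectra_titles = models.keys() if hp["hessian"] else []

for title in eigenspectra_titles: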

    combined_eigenspectrum = go.Figure()
    final_eigenspectra = {}
    traces = {}

    for optim in optimizers:
        final_eigenspectra[optim] = eigenspectra_figs[title][optim][-2]
        final_eigenspectra[optim].data[0].name = optim
        combined_eigenspectrum.add_trace(final_eigenspectra[optim].data[0])

    combined_eigenspectrum.update_layout(
        title=f"{title} Hessian eigenspectra at convergence",
        xaxis_title="Eigenvalue",
        yaxis_title="Probability density (log scale)",
        yaxis_type="log",
    )

    combined_eigenspectrum.show()
    combined_eigenspectra[title] = (combined_eigenspectrum)

In [None]:
# Visualise generalisation loss vs. testing loss for each model

train_gen_figs = {}
for title in models.keys():
    train_gen_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        train_gen_fig.add_trace(go.Scatter(
            x=epochs,
            y=val_losses[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} Validation",
            line=dict(color=color),
        ))
        train_gen_fig.add_trace(go.Scatter(
            x=epochs,
            y=gen_losses[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} Generalisation",
            line=dict(color=color, dash="dot"),
        ))
    train_gen_fig.update_layout(
        title=f"{title} validation / generalisation loss",
        xaxis_title="Epoch",
        yaxis_title="Loss",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    train_gen_figs[title] = train_gen_fig
    train_gen_fig.show()

In [None]:
# Check the LLC chains converged for each model

rlct_converge_plots = {}
for title in models.keys():
    rlct_converge_plot = go.Figure()
    for epoch in epochs:
        for optim in optimizers:
            if histories[title][optim][epoch] is not None:
                rlct_converge_plot.add_trace(go.Scatter(
                    y=histories[title][optim][epoch]["llc/moving_avg"][0],
                    name=f"{optim} Epoch {epoch}",
                ))
    rlct_converge_plot.update_layout(
        title=f"Evolution of LLC moving average for each model over epochs for {title}",
        xaxis_title="Draws",
        yaxis_title="RLCT",
        yaxis_type="linear",
        legend_title="Epoch"
    )
    rlct_converge_plots[title] = rlct_converge_plot
    rlct_converge_plot.show()

In [None]:
# Save the results to an HTML file.

figures = []

combined_args = {**hp, **data_args, **devinterp_args}

summary = pprint.pformat(combined_args)

for fig in loss_figures.values():
    figures.append(fig)
for fig in update_norm_figures.values():
    figures.append(fig)
for fig in exp_figures.values():
    figures.append(fig)
for fig in combined_eigenspectra.values():
    figures.append(fig)
for fig in train_gen_figs.values():
    figures.append(fig)
for fig in rlct_converge_plots.values():
    figures.append(fig)

curr_time = datetime.now().strftime("%Y-%m-%d-%H-%M")

write_figs_to_html(
    figs=figures,
    dest=f"./ngd_sgd/{hp['model_type']}_ngd_sgd_rlct_{curr_time}.html",
    title="Does NGD converge to minima that are 'more complex' i.e. have a higher RLCT?",
    summary=summary,
)