### **Does NGD converge to more complex minima according to the RLCT (Real Log Canonical Threshold)?**

This notebook investigates the question above using developmental interpretability methods.

We hypothesise that NGD will consistently converge to minima with a higher RLCT than SGD, because it premultiplies the gradient used for gradient descent by the inverse of the Fisher Information Matrix (FIM).

SLT (singular learning theory) proposes that models converge to singularities, where the Fisher Information Matrix is non-invertible. Hence, near a singularity, the inverse of the FIM blows up because its determinant is close to 0. Therefore, NGD will "jump away" from the singularity, instead favouring more complex, less singular minima.
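
The "blow up" argument above can be illustrated with a toy calculation. The sketch below assumes a diagonal 2x2 Fisher matrix and holds the gradient fixed while the smallest eigenvalue shrinks; both are simplifications made for illustration (in a real model the gradient along a degenerate direction may shrink as well), and `natural_gradient_step` is a hypothetical helper, not part of the `NGD` module.

```python
def natural_gradient_step(grad, fisher_diag, damping=0.0):
    """Return F^{-1} g for a diagonal Fisher matrix F, with optional damping added to F."""
    return [g / (f + damping) for g, f in zip(grad, fisher_diag)]


def norm(v):
    """Euclidean norm of a list of floats."""
    return sum(x * x for x in v) ** 0.5


grad = [1.0, 1.0]  # fixed gradient, an assumption made for illustration
for smallest_eig in (1.0, 1e-2, 1e-4, 1e-6):
    fisher_diag = [1.0, smallest_eig]  # the FIM approaches singularity as this shrinks
    raw = norm(natural_gradient_step(grad, fisher_diag))
    damped = norm(natural_gradient_step(grad, fisher_diag, damping=1e-3))
    print(f"min eigenvalue={smallest_eig:.0e}  |step|={raw:.3e}  |damped step|={damped:.3e}")
```

The undamped step norm grows without bound as the smallest eigenvalue goes to 0, while the damped step is capped at roughly `1 / damping`. This is why a small constant is usually added to the FIM in practice.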

#### **Methodology**
- Choose between DNN (dense neural network) and CNN (convolutional neural network)
- Vary hidden nodes and hidden layers for DNN (depth and width)
- Vary number of convolutional layers for CNN.
- Estimate Hessian rank using `PyHessian` module.
- Estimate RLCT using `devinterp` library.
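
To give some intuition for the RLCT estimation step, here is a toy, one-dimensional sketch of a localized SGLD estimator. It is not the `devinterp` API: `estimate_llc` is a hypothetical function, the loss is evaluated exactly rather than on minibatches, and the inverse temperature is assumed to be $ \beta = 1 / \log n $. The estimate used is $ \hat{\lambda} = n \beta \, (\mathbb{E}[L(w)] - L(w^*)) $, where the expectation is over SGLD samples held near $ w^* $ by the localization term.

```python
import math
import random


def estimate_llc(loss, grad, w_star, n=1000, num_draws=50_000,
                 sampler_lr=1e-3, localization=1.0, seed=0):
    """Toy 1-D localized SGLD estimate of the local learning coefficient at w_star."""
    rng = random.Random(seed)
    beta = 1.0 / math.log(n)  # assumed inverse temperature
    w = w_star
    running_loss = 0.0
    for _ in range(num_draws):
        # Drift combines the tempered loss gradient and the pull back towards w_star.
        drift = -n * beta * grad(w) - localization * (w - w_star)
        w += 0.5 * sampler_lr * drift + math.sqrt(sampler_lr) * rng.gauss(0.0, 1.0)
        running_loss += loss(w)
    mean_loss = running_loss / num_draws
    return n * beta * (mean_loss - loss(w_star))


# Regular minimum (theoretical RLCT 1/2) versus a more degenerate one (theoretical RLCT 1/4).
llc_quadratic = estimate_llc(lambda w: 0.5 * w**2, lambda w: w, w_star=0.0)
llc_quartic = estimate_llc(lambda w: 0.25 * w**4, lambda w: w**3, w_star=0.0)
print(f"quadratic minimum: {llc_quadratic:.3f}")
print(f"quartic minimum:   {llc_quartic:.3f}")
```

The flatter quartic minimum should give the lower estimate, which is the sense in which a lower RLCT indicates a more singular (less complex) minimum. The estimates are noisy and slightly biased by the step size and the localization strength.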

#### **Instructions**

To produce your own results, go to `args_ngd_sgd.json`. The following parameters can be adjusted:
- `model_type (str)`  : `ffnn` (feed-forward network with ReLU activations), `dlnn` (the same architecture built with `relu=False`), `cnn` (convolutional network)
- `experiment_type (str)` : `standard` (train models independently using SGD and NGD), `swap` (train to convergence with SGD, then continue with SGD + NGD)

<!-- -->

- `optimizer (str)` : `sgd`, `ngd`, or `both` (performs analysis on both optimizers) - SET TO `both` IF DOING `swap` EXPERIMENT
- `hessian (bool)` : `true`, `false` (does Hessian rank analysis if enabled)

<!-- -->

- `hidden_nodes (List : int)` : e.g. `[32, 32, 64, 64, 128, 128]`, the hidden nodes to use in DNN
- `hidden_layers (List: int)` : e.g. `[1, 1, 2, 2, 3, 3]`, the hidden layers to use in DNN. Must have same length as `hidden_nodes`.
- `hidden_conv_layers (List : int)` : array of number of hidden convolutional layers to use for CNN, e.g. `[1, 2, 3]`.

<!-- -->

- `cut_off_epochs (int)` : the epoch number at which the optimiser is swapped if using `experiment_type = "swap"`
- `num_epochs (int)` : total number of training epochs

<!-- -->

- `sgd_lr (float)` : sgd learning rate for training
- `ngd_lr (float)` : ngd learning rate for training

<!-- -->

- `alpha (float)` : smoothing constant for estimation of FIM for NGD
- `eta (float)` : $ F_{n} = \eta S_n + (1 - \eta) F_{n-1} $ where $ S_n $ is the estimated Fisher matrix from the current batch.
- `epsilon (float)` : added to $ F $ to stop it from becoming singular
- `delta (float)` : constant preventing $ F $ from becoming singular

<!-- -->

- `momentum (float)` : momentum used in SGD and NGD
- `nesterov (bool)` : enable Nesterov momentum for NGD / SGD

<!-- -->

- `seed (int)` : the random seed used for `torch.manual_seed()`

<!-- -->

- `batch_size (int)` : batch size for train / validation dataloaders
- `num_workers (int)` : number of worker processes used for data loading (around 6 is a reasonable starting point; adjust to your hardware)
- `dataset (str)` : `mnist`, `fashion-mnist`, `cifar10` (dataset to use for training)
- `num_hessian_batches (int)` : number of batches used for estimation of Hessian

<!-- -->

- `sampler (str)` : `sgld`, `sgnht` (optimiser to use for RLCT estimation)
- `num_chains (int)` : number of chains to use in RLCT estimation (higher leads to more accurate RLCT estimate)
- `num_draws (int)` : number of optimizer steps in RLCT estimation (should be high enough such that chain RLCT converges - check convergence plot)
- `localization (float)` : higher localization more strongly restricts optimizer to neighbourhood of model weights (stops RLCT estimator from going straight to minima)
- `sampler_lr (float)` : learning rate for sampler for RLCT estimation
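
Some of the parameters above constrain each other. The sketch below checks them before a long run; `validate_args` is a hypothetical helper that is not part of this repository, and the check that `cut_off_epochs` does not exceed `num_epochs` is an assumption rather than a documented rule.

```python
def validate_args(hp):
    """Hypothetical helper: return a list of problems found in the hyperparameter dict."""
    errors = []
    # Dense models pair hidden_nodes[i] with hidden_layers[i], so the lengths must match.
    if hp["model_type"] != "cnn" and len(hp["hidden_nodes"]) != len(hp["hidden_layers"]):
        errors.append("hidden_nodes and hidden_layers must have the same length")
    if hp["experiment_type"] == "swap":
        # The swap experiment compares both optimizers from a shared checkpoint.
        if hp["optimizer"] != "both":
            errors.append("optimizer must be 'both' for the swap experiment")
        # Assumption: the swap point cannot lie beyond the end of training.
        if hp["cut_off_epochs"] > hp["num_epochs"]:
            errors.append("cut_off_epochs must not exceed num_epochs")
    return errors


example_hp = {
    "model_type": "ffnn",
    "experiment_type": "swap",
    "optimizer": "sgd",          # deliberately wrong for a swap experiment
    "hidden_nodes": [32, 64],
    "hidden_layers": [1],        # deliberately too short
    "cut_off_epochs": 30,
    "num_epochs": 80,
}
for problem in validate_args(example_hp):
    print("Config problem:", problem)
```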

You can run the notebook from the terminal using the following command:
```bash
jupyter nbconvert --to notebook --execute --inplace ./experiments/eval_ngd_sgd.ipynb
```

#### **0. Import libraries**

The `NGD` module is used for implementing Natural Gradient Descent efficiently. It approximates the Fisher Information Matrix to do this.
The `devinterp` library is used for estimation of the LLC (local learning coefficient).
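
As a rough illustration of the running Fisher estimate controlled by `eta` and the damping controlled by `epsilon`, the sketch below uses a diagonal empirical Fisher (squared gradients). This is an assumption made for brevity: the approximation actually used inside the `NGD` module is not shown in this notebook, and `update_fisher` and `precondition` are hypothetical names.

```python
def update_fisher(fisher, batch_fisher, eta):
    """Exponential moving average: F_n = eta * S_n + (1 - eta) * F_{n-1}."""
    return [eta * s + (1.0 - eta) * f for s, f in zip(batch_fisher, fisher)]


def precondition(grad, fisher, epsilon):
    """Apply (F + epsilon * I)^{-1} to the gradient for a diagonal F."""
    return [g / (f + epsilon) for g, f in zip(grad, fisher)]


fisher = [0.0, 0.0]
batch_grads = [[1.0, 2.0], [0.5, 1.5], [0.8, 1.0]]  # made-up gradients for illustration
for grad in batch_grads:
    batch_fisher = [g * g for g in grad]  # diagonal empirical Fisher from this batch
    fisher = update_fisher(fisher, batch_fisher, eta=0.9)
    step = precondition(grad, fisher, epsilon=1e-10)
    print("fisher:", fisher, "preconditioned gradient:", step)
```

A larger `eta` makes the estimate follow the current batch more closely, while a smaller `eta` averages over more past batches.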

In [49]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

from PyHessian.pyhessian import *
from PyHessian.density_plot import *
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

from utils_general import *
from utils_hessian_fim import *
from architectures import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt

device = "cuda" if t.cuda.is_available() else "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline


Device in use: cuda
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### **1. Import data for training**

We import our dataset for training using the helper function `build_data_loaders`, which lets us choose between MNIST, Fashion-MNIST and CIFAR10.

We specify three dictionaries, `hp` = hyperparameters, `data_args` = arguments for dataloading, `devinterp_args` = arguments for LLC and WBIC estimation.

In [50]:
# Load experiment args

with open("args_ngd_sgd.json", "r") as file:
    args = json.load(file) 

hp, data_args, devinterp_args = args

print("======HYPERPARAMETERS======")
pprint.pprint(hp, sort_dicts=False)

print("======DATA ARGS======")
pprint.pprint(data_args, sort_dicts=False)

print("======DEVINTERP ARGS======")
pprint.pprint(devinterp_args, sort_dicts=False)

{'model_type': 'ffnn',
 'experiment_type': 'swap',
 'optimizer': 'both',
 'hessian': False,
 'hidden_nodes': [256],
 'hidden_layers': [2],
 'hidden_conv_layers': [2],
 'cut_off_epochs': 30,
 'num_epochs': 80,
 'sgd_lr': 0.001,
 'ngd_lr': 0.001,
 'alpha': 0.1,
 'eta': 0.9,
 'epsilon': 1e-10,
 'delta': 0.0005,
 'momentum': 0.9,
 'nesterov': True,
 'seed': 2}
{'batch_size': 128,
 'num_workers': 22,
 'dataset': 'fashion-mnist',
 'num_hessian_batches': 1}
{'sampler': 'sgld',
 'num_chains': 2,
 'num_draws': 1000,
 'localization': 100.0,
 'sampler_lr': 0.0001}


In [51]:
# Set random seed for weight initialisation (for reproducibility of results and to ensure NGD/SGD start from same point in loss landscape)

t.manual_seed(hp["seed"])

<torch._C.Generator at 0x7faba5368e50>

In [52]:
train_loader, test_loader = build_data_loaders(data_args)

#### **2. Training models**

Choose to train either a DNN or CNN.

This code produces a dictionary `models` where each key describes the model itself, e.g. "FFNN 2 HL, 256 HN".

Each value in this dictionary, iterated over as `model`, is another dictionary keyed by optimizer. Each of its values is a list that starts with the model at its initial weights. During training, a copy of the model is appended after every epoch, so the list records the model history.

```
models={
    "architecture1":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    },
    "architecture2":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    }
}
```
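
As a small usage sketch of this layout, the snippet below walks the nested dictionary and picks out the initial and final checkpoints. The strings stand in for real model objects and are placeholders only.

```python
# Placeholder strings stand in for the deep-copied model checkpoints.
models = {
    "FFNN 2 HL, 256 HN": {
        "sgd": ["sgd_epoch0", "sgd_epoch1", "sgd_epoch2"],
        "ngd": ["ngd_epoch0", "ngd_epoch1", "ngd_epoch2"],
    },
}

for title, model in models.items():
    for optim, checkpoints in model.items():
        initial, final = checkpoints[0], checkpoints[-1]
        # Index i holds the model after i epochs of training.
        print(f"{title} | {optim}: {len(checkpoints) - 1} epochs, {initial} -> {final}")
```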

In [53]:
# Initialise models dependent on arguments

models = {}
optimizers = ["ngd", "sgd"] if hp["optimizer"] == "both" else [hp["optimizer"]]

if hp["model_type"] == "ffnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"FFNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist" or data_args["dataset"] == "fashion-mnist":
            model = NeuralNet(relu=True,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=True,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "dlnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"DLNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist" or data_args["dataset"] == "fashion-mnist":
            model = NeuralNet(relu=False,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=False,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "cnn":
    hidden_conv_layers = hp["hidden_conv_layers"]
    for hidden_conv_layer in hidden_conv_layers:
        title = f"CNN {hidden_conv_layer} HCL"
        if data_args["dataset"] == "mnist" or data_args["dataset"] == "fashion-mnist":
            model = LeNet(dataset="mnist",extra_layers=hidden_conv_layer).to(device)
        elif data_args["dataset"] == "cifar10":
            model = LeNet(dataset="cifar10",extra_layers=hidden_conv_layer).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}


For experiment 1 (`standard`), make sure `cut_off_epochs` is equal to `num_epochs`.

In [54]:
# EXPERIMENT 1: Train models independently for num_epochs with SGD and NGD

if hp["experiment_type"] == "standard":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        # Store list for SGD losses, NGD losses, for train and val
        model_train_losses = {optim : [] for optim in optimizers}
        model_val_losses = {optim : [] for optim in optimizers}
        model_update_norms = {optim : [] for optim in optimizers}
        model_all_update_norms = {optim : [] for optim in optimizers}

        for optim in optimizers:
            state = copy.deepcopy(model[optim][0])
            if optim == "sgd":
                optimizer = t.optim.SGD(
                    params=state.parameters(),
                    lr=hp["sgd_lr"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            elif optim == "ngd":
                optimizer = NGD(
                    params=state.parameters(),
                    lr=hp["ngd_lr"],
                    alpha=hp["alpha"],
                    eta=hp["eta"],
                    epsilon=hp["epsilon"],
                    delta=hp["delta"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            print(f"TRAINING MODEL: {title} | OPTIMISER: {optim}")
            initial_train_loss = evaluate(state, train_loader, criterion, device)
            initial_val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses[optim].append(initial_train_loss)
            model_val_losses[optim].append(initial_val_loss)
        
            for epoch in range(1, hp["num_epochs"]+1):
                train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
                val_loss = evaluate(state, test_loader, criterion, device)
                model_train_losses[optim].append(train_loss)
                model_update_norms[optim].append(update_norm)
                model_val_losses[optim].append(val_loss)
                #note that epoch_update norms is a list with the update_norm of each minibatch, this will concat the list to previous list
                model_all_update_norms[optim] += epoch_update_norms
                model[optim].append(copy.deepcopy(state))
                print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


Run the cell below only ONCE after each run of the initialisation cell. Otherwise, models will keep being appended to the existing checkpoint lists.
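
One way to catch an accidental re-run is a guard that checks every checkpoint list still holds only the initial model. `assert_fresh` is a hypothetical helper, not part of the notebook's utilities. Placeholder strings stand in for the models here.

```python
def assert_fresh(models):
    """Hypothetical guard: raise if any checkpoint list holds more than the initial model."""
    for title, model in models.items():
        for optim, checkpoints in model.items():
            if len(checkpoints) != 1:
                raise RuntimeError(
                    f"{title} / {optim} already has {len(checkpoints)} checkpoints; "
                    "re-run the initialisation cell before training again."
                )


fresh = {"FFNN 2 HL, 256 HN": {"sgd": ["init"], "ngd": ["init"]}}
assert_fresh(fresh)  # passes silently

stale = {"FFNN 2 HL, 256 HN": {"sgd": ["init", "epoch1"], "ngd": ["init"]}}
try:
    assert_fresh(stale)
except RuntimeError as err:
    print("Guard triggered:", err)
```

Calling such a guard at the top of the training cell would turn a silent bookkeeping error into an immediate failure.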

In [55]:
# EXPERIMENT 2: Train to convergence with SGD, then continue with SGD and NGD and compare

if hp["experiment_type"] == "swap":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():
        
        #delete the first model for NGD as we recycle the SGD stuff
        #del model["ngd"][0]
        #replace this first model with a bunch of Nones
        model["ngd"]=[None for i in range(hp["cut_off_epochs"]+1)]

        #fill in first few with None, for NGD
        model_train_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_val_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_update_norms = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_all_update_norms = {"sgd" : [], "ngd" : [None for i in range(len(train_loader)*hp["cut_off_epochs"])]}

        state = copy.deepcopy(model["sgd"][0])
        optimizer = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
    
        print(f"TRAINING MODEL: {title} | OPTIMISER: SGD")
        initial_train_loss = evaluate(state, train_loader, criterion, device)
        initial_val_loss = evaluate(state, test_loader, criterion, device)
        model_train_losses["sgd"].append(initial_train_loss)
        model_val_losses["sgd"].append(initial_val_loss)
        for epoch in range(1, hp["cut_off_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")


        # Train with SGD
        print("Training from model checkpoint with SGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_sgd = t.optim.SGD(
            params=state.parameters(),
            #multiply this to vary SGD
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_sgd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        
        # Train with NGD
        print("Training from model checkpoint with NGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_ngd = NGD(
            params=state.parameters(),
            lr=hp["ngd_lr"],
            alpha=hp["alpha"],
            eta=hp["eta"],
            epsilon=hp["epsilon"],
            delta=hp["delta"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_ngd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["ngd"].append(train_loss)
            model_update_norms["ngd"].append(update_norm)
            model_val_losses["ngd"].append(val_loss)
            model_all_update_norms["ngd"] += epoch_update_norms
            model["ngd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


TRAINING MODEL: FFNN 2 HL, 256 HN | OPTIMISER: SGD


100%|██████████| 469/469 [00:02<00:00, 210.41it/s]


Epoch 1/80: train_loss=1.2705, val_loss=0.7461


100%|██████████| 469/469 [00:02<00:00, 205.93it/s]


Epoch 2/80: train_loss=0.6507, val_loss=0.6102


100%|██████████| 469/469 [00:02<00:00, 228.12it/s]


Epoch 3/80: train_loss=0.5597, val_loss=0.5498


100%|██████████| 469/469 [00:02<00:00, 220.48it/s]


Epoch 4/80: train_loss=0.5118, val_loss=0.5174


100%|██████████| 469/469 [00:02<00:00, 223.52it/s]


Epoch 5/80: train_loss=0.4819, val_loss=0.4953


100%|██████████| 469/469 [00:02<00:00, 218.56it/s]


Epoch 6/80: train_loss=0.4608, val_loss=0.4841


100%|██████████| 469/469 [00:02<00:00, 211.02it/s]


Epoch 7/80: train_loss=0.4454, val_loss=0.4677


100%|██████████| 469/469 [00:02<00:00, 204.21it/s]


Epoch 8/80: train_loss=0.4323, val_loss=0.4585


100%|██████████| 469/469 [00:02<00:00, 208.79it/s]


Epoch 9/80: train_loss=0.4216, val_loss=0.4493


100%|██████████| 469/469 [00:02<00:00, 208.50it/s]


Epoch 10/80: train_loss=0.4116, val_loss=0.4433


100%|██████████| 469/469 [00:02<00:00, 193.13it/s]


Epoch 11/80: train_loss=0.4027, val_loss=0.4350


100%|██████████| 469/469 [00:02<00:00, 205.03it/s]


Epoch 12/80: train_loss=0.3941, val_loss=0.4259


100%|██████████| 469/469 [00:02<00:00, 209.76it/s]


Epoch 13/80: train_loss=0.3870, val_loss=0.4204


100%|██████████| 469/469 [00:01<00:00, 242.70it/s]


Epoch 14/80: train_loss=0.3806, val_loss=0.4154


100%|██████████| 469/469 [00:02<00:00, 232.13it/s]


Epoch 15/80: train_loss=0.3740, val_loss=0.4115


100%|██████████| 469/469 [00:02<00:00, 212.48it/s]


Epoch 16/80: train_loss=0.3682, val_loss=0.4064


100%|██████████| 469/469 [00:02<00:00, 204.28it/s]


Epoch 17/80: train_loss=0.3629, val_loss=0.3999


100%|██████████| 469/469 [00:02<00:00, 208.30it/s]


Epoch 18/80: train_loss=0.3577, val_loss=0.3968


100%|██████████| 469/469 [00:02<00:00, 200.81it/s]


Epoch 19/80: train_loss=0.3522, val_loss=0.3941


100%|██████████| 469/469 [00:02<00:00, 204.04it/s]


Epoch 20/80: train_loss=0.3472, val_loss=0.3879


100%|██████████| 469/469 [00:02<00:00, 202.92it/s]


Epoch 21/80: train_loss=0.3436, val_loss=0.3849


100%|██████████| 469/469 [00:02<00:00, 213.00it/s]


Epoch 22/80: train_loss=0.3383, val_loss=0.3838


100%|██████████| 469/469 [00:02<00:00, 207.75it/s]


Epoch 23/80: train_loss=0.3344, val_loss=0.3790


100%|██████████| 469/469 [00:02<00:00, 204.94it/s]


Epoch 24/80: train_loss=0.3305, val_loss=0.3803


100%|██████████| 469/469 [00:02<00:00, 210.94it/s]


Epoch 25/80: train_loss=0.3267, val_loss=0.3736


100%|██████████| 469/469 [00:02<00:00, 208.07it/s]


Epoch 26/80: train_loss=0.3233, val_loss=0.3729


100%|██████████| 469/469 [00:02<00:00, 201.66it/s]


Epoch 27/80: train_loss=0.3191, val_loss=0.3713


100%|██████████| 469/469 [00:02<00:00, 218.40it/s]


Epoch 28/80: train_loss=0.3158, val_loss=0.3652


100%|██████████| 469/469 [00:02<00:00, 214.77it/s]


Epoch 29/80: train_loss=0.3123, val_loss=0.3691


100%|██████████| 469/469 [00:02<00:00, 226.60it/s]


Epoch 30/80: train_loss=0.3084, val_loss=0.3647
Training from model checkpoint with SGD.


100%|██████████| 469/469 [00:02<00:00, 214.19it/s]


Epoch 31/80: train_loss=0.3060, val_loss=0.3613


100%|██████████| 469/469 [00:02<00:00, 209.33it/s]


Epoch 32/80: train_loss=0.3027, val_loss=0.3573


100%|██████████| 469/469 [00:02<00:00, 209.92it/s]


Epoch 33/80: train_loss=0.2996, val_loss=0.3596


100%|██████████| 469/469 [00:02<00:00, 208.10it/s]


Epoch 34/80: train_loss=0.2969, val_loss=0.3567


100%|██████████| 469/469 [00:02<00:00, 207.39it/s]


Epoch 35/80: train_loss=0.2928, val_loss=0.3585


100%|██████████| 469/469 [00:02<00:00, 216.39it/s]


Epoch 36/80: train_loss=0.2895, val_loss=0.3524


100%|██████████| 469/469 [00:02<00:00, 206.69it/s]


Epoch 37/80: train_loss=0.2869, val_loss=0.3543


100%|██████████| 469/469 [00:02<00:00, 201.87it/s]


Epoch 38/80: train_loss=0.2843, val_loss=0.3546


100%|██████████| 469/469 [00:02<00:00, 201.56it/s]


Epoch 39/80: train_loss=0.2818, val_loss=0.3481


100%|██████████| 469/469 [00:02<00:00, 231.37it/s]


Epoch 40/80: train_loss=0.2788, val_loss=0.3460


100%|██████████| 469/469 [00:02<00:00, 221.12it/s]


Epoch 41/80: train_loss=0.2760, val_loss=0.3525


100%|██████████| 469/469 [00:02<00:00, 216.41it/s]


Epoch 42/80: train_loss=0.2735, val_loss=0.3453


100%|██████████| 469/469 [00:02<00:00, 214.45it/s]


Epoch 43/80: train_loss=0.2704, val_loss=0.3428


100%|██████████| 469/469 [00:02<00:00, 204.44it/s]


Epoch 44/80: train_loss=0.2685, val_loss=0.3432


100%|██████████| 469/469 [00:02<00:00, 207.82it/s]


Epoch 45/80: train_loss=0.2652, val_loss=0.3420


100%|██████████| 469/469 [00:02<00:00, 217.00it/s]


Epoch 46/80: train_loss=0.2631, val_loss=0.3393


100%|██████████| 469/469 [00:02<00:00, 207.32it/s]


Epoch 47/80: train_loss=0.2605, val_loss=0.3378


100%|██████████| 469/469 [00:02<00:00, 205.66it/s]


Epoch 48/80: train_loss=0.2574, val_loss=0.3344


100%|██████████| 469/469 [00:02<00:00, 218.99it/s]


Epoch 49/80: train_loss=0.2551, val_loss=0.3367


100%|██████████| 469/469 [00:02<00:00, 216.89it/s]


Epoch 50/80: train_loss=0.2527, val_loss=0.3379


100%|██████████| 469/469 [00:02<00:00, 222.67it/s]


Epoch 51/80: train_loss=0.2505, val_loss=0.3357


100%|██████████| 469/469 [00:02<00:00, 216.23it/s]


Epoch 52/80: train_loss=0.2480, val_loss=0.3334


100%|██████████| 469/469 [00:02<00:00, 217.59it/s]


Epoch 53/80: train_loss=0.2452, val_loss=0.3394


100%|██████████| 469/469 [00:02<00:00, 214.59it/s]


Epoch 54/80: train_loss=0.2430, val_loss=0.3298


100%|██████████| 469/469 [00:02<00:00, 213.10it/s]


Epoch 55/80: train_loss=0.2408, val_loss=0.3336


100%|██████████| 469/469 [00:02<00:00, 208.58it/s]


Epoch 56/80: train_loss=0.2393, val_loss=0.3310


100%|██████████| 469/469 [00:02<00:00, 209.09it/s]


Epoch 57/80: train_loss=0.2366, val_loss=0.3360


100%|██████████| 469/469 [00:02<00:00, 208.44it/s]


Epoch 58/80: train_loss=0.2341, val_loss=0.3371


100%|██████████| 469/469 [00:02<00:00, 205.87it/s]


Epoch 59/80: train_loss=0.2323, val_loss=0.3369


100%|██████████| 469/469 [00:02<00:00, 199.78it/s]


Epoch 60/80: train_loss=0.2295, val_loss=0.3259


100%|██████████| 469/469 [00:02<00:00, 185.97it/s]


Epoch 61/80: train_loss=0.2275, val_loss=0.3292


100%|██████████| 469/469 [00:02<00:00, 181.41it/s]


Epoch 62/80: train_loss=0.2250, val_loss=0.3283


100%|██████████| 469/469 [00:02<00:00, 188.41it/s]


Epoch 63/80: train_loss=0.2225, val_loss=0.3308


100%|██████████| 469/469 [00:02<00:00, 201.07it/s]


Epoch 64/80: train_loss=0.2206, val_loss=0.3254


100%|██████████| 469/469 [00:02<00:00, 198.84it/s]


Epoch 65/80: train_loss=0.2182, val_loss=0.3322


100%|██████████| 469/469 [00:02<00:00, 212.21it/s]


Epoch 66/80: train_loss=0.2162, val_loss=0.3299


100%|██████████| 469/469 [00:02<00:00, 204.43it/s]


Epoch 67/80: train_loss=0.2144, val_loss=0.3264


100%|██████████| 469/469 [00:02<00:00, 198.40it/s]


Epoch 68/80: train_loss=0.2118, val_loss=0.3276


100%|██████████| 469/469 [00:02<00:00, 201.59it/s]


Epoch 69/80: train_loss=0.2098, val_loss=0.3244


100%|██████████| 469/469 [00:02<00:00, 199.37it/s]


Epoch 70/80: train_loss=0.2076, val_loss=0.3230


100%|██████████| 469/469 [00:02<00:00, 196.09it/s]


Epoch 71/80: train_loss=0.2063, val_loss=0.3273


100%|██████████| 469/469 [00:02<00:00, 211.29it/s]


Epoch 72/80: train_loss=0.2044, val_loss=0.3300


100%|██████████| 469/469 [00:02<00:00, 215.47it/s]


Epoch 73/80: train_loss=0.2019, val_loss=0.3233


100%|██████████| 469/469 [00:02<00:00, 203.05it/s]


Epoch 74/80: train_loss=0.1994, val_loss=0.3295


100%|██████████| 469/469 [00:02<00:00, 215.55it/s]


Epoch 75/80: train_loss=0.1981, val_loss=0.3238


100%|██████████| 469/469 [00:02<00:00, 205.16it/s]


Epoch 76/80: train_loss=0.1948, val_loss=0.3338


100%|██████████| 469/469 [00:02<00:00, 218.98it/s]


Epoch 77/80: train_loss=0.1931, val_loss=0.3295


100%|██████████| 469/469 [00:02<00:00, 210.38it/s]


Epoch 78/80: train_loss=0.1916, val_loss=0.3252


100%|██████████| 469/469 [00:02<00:00, 200.98it/s]


Epoch 79/80: train_loss=0.1898, val_loss=0.3309


100%|██████████| 469/469 [00:02<00:00, 205.15it/s]


Epoch 80/80: train_loss=0.1879, val_loss=0.3240
Training from model checkpoint with NGD.


100%|██████████| 469/469 [00:05<00:00, 81.95it/s]


Epoch 31/80: train_loss=0.3021, val_loss=0.3578


100%|██████████| 469/469 [00:05<00:00, 84.80it/s]


Epoch 32/80: train_loss=0.2960, val_loss=0.3562


100%|██████████| 469/469 [00:05<00:00, 80.07it/s]


Epoch 33/80: train_loss=0.2925, val_loss=0.3543


100%|██████████| 469/469 [00:05<00:00, 80.12it/s]


Epoch 34/80: train_loss=0.2888, val_loss=0.3524


100%|██████████| 469/469 [00:05<00:00, 82.79it/s]


Epoch 35/80: train_loss=0.2856, val_loss=0.3510


100%|██████████| 469/469 [00:05<00:00, 81.36it/s]


Epoch 36/80: train_loss=0.2818, val_loss=0.3495


100%|██████████| 469/469 [00:05<00:00, 85.32it/s]


Epoch 37/80: train_loss=0.2785, val_loss=0.3483


100%|██████████| 469/469 [00:05<00:00, 80.20it/s]


Epoch 38/80: train_loss=0.2751, val_loss=0.3467


100%|██████████| 469/469 [00:05<00:00, 86.51it/s]


Epoch 39/80: train_loss=0.2717, val_loss=0.3447


100%|██████████| 469/469 [00:05<00:00, 89.07it/s]


Epoch 40/80: train_loss=0.2684, val_loss=0.3427


100%|██████████| 469/469 [00:05<00:00, 88.58it/s]


Epoch 41/80: train_loss=0.2647, val_loss=0.3411


100%|██████████| 469/469 [00:05<00:00, 87.10it/s]


Epoch 42/80: train_loss=0.2612, val_loss=0.3393


100%|██████████| 469/469 [00:05<00:00, 85.40it/s]


Epoch 43/80: train_loss=0.2581, val_loss=0.3374


100%|██████████| 469/469 [00:05<00:00, 87.27it/s]


Epoch 44/80: train_loss=0.2546, val_loss=0.3362


100%|██████████| 469/469 [00:05<00:00, 88.17it/s]


Epoch 45/80: train_loss=0.2515, val_loss=0.3358


100%|██████████| 469/469 [00:05<00:00, 91.07it/s]


Epoch 46/80: train_loss=0.2484, val_loss=0.3344


100%|██████████| 469/469 [00:05<00:00, 86.25it/s]


Epoch 47/80: train_loss=0.2454, val_loss=0.3329


100%|██████████| 469/469 [00:05<00:00, 86.45it/s]


Epoch 48/80: train_loss=0.2424, val_loss=0.3315


100%|██████████| 469/469 [00:05<00:00, 85.69it/s]


Epoch 49/80: train_loss=0.2390, val_loss=0.3303


100%|██████████| 469/469 [00:05<00:00, 83.34it/s]


Epoch 50/80: train_loss=0.2356, val_loss=0.3286


100%|██████████| 469/469 [00:05<00:00, 86.51it/s]


Epoch 51/80: train_loss=0.2326, val_loss=0.3268


100%|██████████| 469/469 [00:05<00:00, 86.03it/s]


Epoch 52/80: train_loss=0.2297, val_loss=0.3255


100%|██████████| 469/469 [00:05<00:00, 82.74it/s]


Epoch 53/80: train_loss=0.2264, val_loss=0.3246


100%|██████████| 469/469 [00:05<00:00, 84.98it/s]


Epoch 54/80: train_loss=0.2234, val_loss=0.3241


100%|██████████| 469/469 [00:05<00:00, 85.59it/s]


Epoch 55/80: train_loss=0.2204, val_loss=0.3231


100%|██████████| 469/469 [00:05<00:00, 85.55it/s]


Epoch 56/80: train_loss=0.2171, val_loss=0.3215


100%|██████████| 469/469 [00:05<00:00, 86.21it/s]


Epoch 57/80: train_loss=0.2140, val_loss=0.3215


100%|██████████| 469/469 [00:05<00:00, 83.76it/s]


Epoch 58/80: train_loss=0.2110, val_loss=0.3207


100%|██████████| 469/469 [00:06<00:00, 76.62it/s]


Epoch 59/80: train_loss=0.2082, val_loss=0.3197


100%|██████████| 469/469 [00:05<00:00, 81.34it/s]


Epoch 60/80: train_loss=0.2053, val_loss=0.3197


100%|██████████| 469/469 [00:05<00:00, 84.10it/s]


Epoch 61/80: train_loss=0.2021, val_loss=0.3188


100%|██████████| 469/469 [00:05<00:00, 86.08it/s]


Epoch 62/80: train_loss=0.1995, val_loss=0.3183


100%|██████████| 469/469 [00:05<00:00, 87.12it/s]


Epoch 63/80: train_loss=0.1965, val_loss=0.3173


100%|██████████| 469/469 [00:05<00:00, 88.84it/s]


Epoch 64/80: train_loss=0.1937, val_loss=0.3168


100%|██████████| 469/469 [00:05<00:00, 89.13it/s]


Epoch 65/80: train_loss=0.1905, val_loss=0.3166


100%|██████████| 469/469 [00:05<00:00, 88.88it/s]


Epoch 66/80: train_loss=0.1880, val_loss=0.3148


100%|██████████| 469/469 [00:05<00:00, 85.42it/s]


Epoch 67/80: train_loss=0.1850, val_loss=0.3144


100%|██████████| 469/469 [00:05<00:00, 87.16it/s]


Epoch 68/80: train_loss=0.1826, val_loss=0.3135


100%|██████████| 469/469 [00:05<00:00, 89.75it/s]


Epoch 69/80: train_loss=0.1797, val_loss=0.3134


100%|██████████| 469/469 [00:05<00:00, 88.39it/s]


Epoch 70/80: train_loss=0.1768, val_loss=0.3126


100%|██████████| 469/469 [00:05<00:00, 87.72it/s]


Epoch 71/80: train_loss=0.1742, val_loss=0.3130


100%|██████████| 469/469 [00:05<00:00, 85.92it/s]


Epoch 72/80: train_loss=0.1713, val_loss=0.3120


100%|██████████| 469/469 [00:05<00:00, 86.00it/s]


Epoch 73/80: train_loss=0.1687, val_loss=0.3118


100%|██████████| 469/469 [00:05<00:00, 80.51it/s]


Epoch 74/80: train_loss=0.1658, val_loss=0.3127


100%|██████████| 469/469 [00:05<00:00, 85.69it/s]


Epoch 75/80: train_loss=0.1627, val_loss=0.3138


100%|██████████| 469/469 [00:05<00:00, 83.47it/s]


Epoch 76/80: train_loss=0.1602, val_loss=0.3124


100%|██████████| 469/469 [00:05<00:00, 83.33it/s]


Epoch 77/80: train_loss=0.1575, val_loss=0.3123


100%|██████████| 469/469 [00:05<00:00, 86.10it/s]


Epoch 78/80: train_loss=0.1549, val_loss=0.3124


100%|██████████| 469/469 [00:05<00:00, 86.87it/s]


Epoch 79/80: train_loss=0.1520, val_loss=0.3115


100%|██████████| 469/469 [00:05<00:00, 87.59it/s]


Epoch 80/80: train_loss=0.1494, val_loss=0.3115


#### **3. Visualising training / validation loss results**

Check that the models all converged.

Displays the training and validation loss for each model separately, with 4 traces on each graph.

The traces are: SGD training, SGD testing, NGD training, NGD testing.
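
In the `swap` experiment the NGD lists are padded with `None` up to the cut-off epoch so that both optimizers share the same epoch axis. The sketch below uses made-up loss values to show the alignment. Plotly leaves gaps where the values are `None`, so the NGD trace only appears after the swap.

```python
cut_off_epochs, num_epochs = 3, 6

# Made-up loss values, for illustration only.
sgd_losses = [2.30, 1.20, 0.90, 0.70, 0.60, 0.55, 0.50]          # epochs 0..6
ngd_losses = [None] * (cut_off_epochs + 1) + [0.58, 0.50, 0.45]  # epochs 0..3 padded

epochs = list(range(num_epochs + 1))
# Both lists must match the x-axis length used for plotting.
assert len(sgd_losses) == len(ngd_losses) == len(epochs)

first_ngd_epoch = next(e for e, v in zip(epochs, ngd_losses) if v is not None)
print(f"NGD trace starts at epoch {first_ngd_epoch}")
```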

In [56]:
color_cycle = ['rgb(200, 0, 0)', 'rgb(32, 102, 168)']

In [57]:
# Display training / val data for all models for NGD / SGD

loss_figures = {}

for title in models.keys():
    loss_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=train_losses[title][optim],
            mode="lines",
            line=dict(color=color, width=3),
            name=f"{optim} Train",
        ))

        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=val_losses[title][optim],
            mode="lines",
            line=dict(color=color, dash='dot', width=3),
            name=f"{optim} Validation",
        ))
        
    loss_fig.update_layout(
        title=f"{title} training / validation loss",
        xaxis_title="Epoch",
        yaxis_title="Loss",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=5,
        ),
        yaxis=dict(
            type="linear",
            range=[0, None],  # Start at 0, automatic upper limit
        ),
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    loss_figures[title] = loss_fig
    loss_fig.show()

In [58]:
# Update norms over epochs for NGD and SGD (one value per epoch)

update_norm_figures = {}

for title in models.keys():
    update_norm_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1

        # In swap experiments the NGD list starts with a None placeholder
        # (no update happens before the first epoch), so drop it to keep indices aligned
        if hp["experiment_type"] == "swap" and optim == 'ngd':
            norm_y = update_norms[title][optim][1:]
        else:
            norm_y = update_norms[title][optim]

        update_norm_fig.add_trace(go.Scatter(
            x=np.arange(1, hp["num_epochs"]+1),
            #y=np.convolve(all_update_norms[title][optim], np.ones(30)/30, mode="same"),
            #y=all_update_norms[title][optim],
            y=norm_y,
            mode="lines",
            line=dict(color=color, width=3),
            name=f"{optim} update norm"))
    update_norm_fig.update_layout(
        title=f"{title} update norms over training",
        xaxis_title="Epochs",
        xaxis=dict(
            dtick=5,),
        yaxis_title="Update Size",
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    update_norm_figures[title] = update_norm_fig
    update_norm_fig.show()

#### **4. Perform Hessian rank computation**

As a check on the RLCT results, we compute an approximation of the Hessian for each model at convergence and estimate the rank of this matrix from its eigenspectrum. SLT predicts the following:

$\text{RLCT} \geq \frac{\text{rank}(\mathbf{Hess})}{2}$

We will check whether this holds in our experiments. Hessian computation uses helper functions from `utils_hessian_fim.py`, which wraps the `PyHessian` module (and the `nngeometry` module, for computations involving the Fisher Information Matrix). Note that in the cells below the eigenspectrum and rank steps are commented out, so only the Hessian trace is computed.
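
To make the rank check concrete, here is a minimal sketch of how a rank could be read off a list of Hessian eigenvalues and compared against an RLCT estimate. It is not the method in `utils_hessian_fim.py`: the function names, the relative threshold and the toy spectrum are assumptions for illustration only.

```python
import numpy as np

def estimate_rank(eigenvalues, rel_threshold=1e-3):
    """Count eigenvalues whose magnitude exceeds rel_threshold * largest magnitude."""
    magnitudes = np.abs(np.asarray(eigenvalues, dtype=float))
    if magnitudes.size == 0 or magnitudes.max() == 0.0:
        return 0
    return int(np.sum(magnitudes > rel_threshold * magnitudes.max()))

def satisfies_rank_bound(rlct, hessian_rank):
    """Check the SLT lower bound RLCT >= rank(Hess) / 2."""
    return rlct >= hessian_rank / 2

# Toy spectrum: three clearly non-zero directions, the rest numerically flat
toy_eigs = [4.0, 1.5, 0.2, 1e-7, -3e-8, 0.0]
rank = estimate_rank(toy_eigs)
print(rank, satisfies_rank_bound(2.1, rank), satisfies_rank_bound(1.0, rank))
```

The estimated rank depends strongly on the threshold, so a violation of the bound may reflect the cutoff or a noisy RLCT estimate rather than a failure of the theory.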

In [59]:
# Create Hessians for each model - recall that model is a list containing all its past versions
if hp["hessian"]:
    hessians = {}
    for title, model in models.items():
        hessian = {}
        for optim in optimizers:
            hessians_list = produce_hessians(
                models=model[optim],
                data_loader=train_loader,
                num_batches=10,
                criterion=criterion,
                device=device,
            )
            hessian[optim] = hessians_list
        hessians[title] = hessian

In [60]:
# Compute Hessian trace for each model
if hp["hessian"]:
    hessian_traces = {}
    for title, hessian in tqdm(hessians.items(), desc="Processing Hessian Traces"):
        hessian_trace = {}
        for optim in optimizers:
            trace_list = produce_hessian_traces(
                hessians=hessian[optim],
                tol=1e-05,
                maxIters=50,
                N=5,
            )
            hessian_trace[optim] = trace_list
        hessian_traces[title] = hessian_trace

In [61]:
# Get the eigenspectrum data from the Hessian objects
# if hp["hessian"]:
#     eigenspectra_data = {}
#     eigenspectra_figs = {}
#     for title, hessian in tqdm(hessians.items(), desc="Processing Hessians"):
#         eigenspectrum_data = {}
#         eigenspectrum_figs = {}
#         for optim in optimizers:
#             eigenspectrum_figs_list, eigenspectrum_data_list = produce_eigenspectra(
#                 hessians=hessian[optim],
#                 plot_type="log",
#             )
#             eigenspectrum_figs[optim] = eigenspectrum_figs_list
#             eigenspectrum_data[optim] = eigenspectrum_data_list
#         eigenspectra_data[title] = eigenspectrum_data
#         eigenspectra_figs[title] = eigenspectrum_figs

In [62]:
# Produce the traces of Hessian dimensionality over epochs
# if hp["hessian"]:
#     hessian_ranks = {}
#     for title, eigenspectrum_data in eigenspectra_data.items():
#         hessian_rank = {}
#         for optim in optimizers:
#             hessian_rank_list = find_hessian_dimensionality(eigenspectrum_data[optim])
#             hessian_rank[optim] = hessian_rank_list
#         hessian_ranks[title] = hessian_rank

#### **5. Perform RLCT estimation**

Using the `devinterp` library, we estimate the RLCT (Real Log Canonical Threshold), also known as the LLC (Local Learning Coefficient).

`llc_estimates` is a dictionary keyed by model title. Each value is itself a dictionary containing two lists: one of SGD LLC values over epochs, and one of NGD LLC values over epochs. The same is true for `wbic_estimates`.
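
The nested layout described above can be illustrated with a small sketch. The helper `final_llc_table`, the model title `dnn_32x1` and the numbers are hypothetical placeholders, not outputs of this notebook. The helper collects the last available LLC value for each model and optimizer.

```python
def final_llc_table(llc_estimates):
    """Collect the last non-None LLC value for every (model title, optimizer) pair."""
    table = {}
    for title, per_optim in llc_estimates.items():
        for optim, values in per_optim.items():
            # None entries can occur, e.g. for NGD before the swap epoch
            valid = [v for v in values if v is not None]
            table[(title, optim)] = valid[-1] if valid else None
    return table

# Placeholder structure: {model title: {optimizer: [LLC per epoch]}}
toy = {"dnn_32x1": {"sgd": [1.0, 2.5, 3.0], "ngd": [None, 2.8, 3.6]}}
print(final_llc_table(toy))
```

A table like this makes it easy to compare the final SGD and NGD values side by side for each architecture.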

In [63]:
# Online LLC / WBIC estimation

wbic_histories = {}
llc_histories = {}

wbic_estimates = {}
llc_estimates = {}

for title, model in models.items():
    wbic_estimate = {}
    wbic_history = {}
    llc_estimate = {}
    llc_history = {}
    for optim in optimizers:
        llc_values_list, llc_history_list, wbic_values_list, wbic_history_list = estimate_rlcts_online(
            model[optim], train_loader, criterion, device, devinterp_args, wbic=True
        )
        wbic_estimate[optim] = wbic_values_list
        wbic_history[optim] = wbic_history_list
        llc_estimate[optim] = llc_values_list
        llc_history[optim] = llc_history_list
    wbic_estimates[title] = wbic_estimate
    wbic_histories[title] = wbic_history
    llc_estimates[title] = llc_estimate
    llc_histories[title] = llc_history

Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 294.31it/s]
Chain 1: 100%|██████████| 1000/1000 [00:02<00:00, 399.11it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 322.75it/s]
Chain 1: 100%|██████████| 1000/1000 [00:03<00:00, 294.27it/s]
Chain 0: 100%|██████████| 1000/1000 [00:02<00:00, 351.47it/s]
Chain 1: 100%|██████████| 1000/1000 [00:02<00:00, 403.40it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 328.18it/s]
Chain 1: 100%|██████████| 1000/1000 [00:02<00:00, 367.09it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 297.66it/s]
Chain 1: 100%|██████████| 1000/1000 [00:02<00:00, 382.84it/s]
Chain 0: 100%|██████████| 1000/1000 [00:02<00:00, 369.13it/s]
Chain 1: 100%|██████████| 1000/1000 [00:03<00:00, 307.13it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 313.35it/s]
Chain 1: 100%|██████████| 1000/1000 [00:02<00:00, 349.13it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 313.50it/s]
Chain 1: 100%|██████████| 1000/1000 [00:02<00:00, 341.78it/s]
Chain 0:

#### **6. Visualise RLCT / Hessian rank, Hessian eigenspectra, generalisation loss, RLCT convergence**

We display the following final figures:
- RLCT and Hessian rank evolution for each model, for NGD and SGD
- Hessian eigenspectra for SGD and NGD overlaid on the same plot
- Evolution of generalisation loss for each model, compared to validation loss
- RLCT moving average evolution for each model, for NGD and SGD, to check for convergence
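
As a rough illustration of the moving-average convergence check in the last item, the sketch below smooths a chain of draws and tests whether the smoothed tail has stopped moving. It does not reproduce how `devinterp` computes `llc/moving_avg`. The function names, window size, tolerance and synthetic chains are all assumptions.

```python
import numpy as np

def moving_average(draws, window=50):
    """Simple moving average over a 1-D chain of draws."""
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(draws, dtype=float), kernel, mode="valid")

def chain_has_settled(draws, window=50, rel_tol=0.02):
    """True if the moving average varies by less than rel_tol (relative to its
    mean) over its final `window` values."""
    tail = moving_average(draws, window)[-window:]
    spread = tail.max() - tail.min()
    return bool(spread / max(abs(tail.mean()), 1e-12) < rel_tol)

# Synthetic chains for illustration only
rng = np.random.default_rng(0)
settled = 10.0 + 0.05 * rng.standard_normal(1000)   # noise around a fixed level
drifting = np.linspace(0.0, 20.0, 1000)             # still trending upwards
print(chain_has_settled(settled), chain_has_settled(drifting))
```

A chain that is still drifting at the end of sampling suggests that the number of draws or the sampler step size should be revisited before the estimate is trusted.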

In [64]:
# Visualise RLCT and Hessian rank data
exp_figures = {}
if hp["hessian"]:
    for title in models.keys():
        exp_fig = make_subplots(specs=[[{"secondary_y" : True}]])
        color_index = 0
        for optim in optimizers:
            color = color_cycle[color_index % len(color_cycle)]
            color_index += 1
            exp_fig.add_trace(go.Scatter(
                x=np.arange(0, hp["num_epochs"]+1),
                y=llc_estimates[title][optim],
                mode="lines",
                name=f"{optim} LLC",
                line=dict(color=color, width=3),
            ), secondary_y=False)
            exp_fig.add_trace(go.Scatter(
                x=np.arange(1, hp["num_epochs"]+1),  # trace list is sliced from index 1, so start at epoch 1
                y=hessian_traces[title][optim][1:],
                mode="lines",
                name=f"{optim} Hessian Trace",
                line=dict(color=color, dash="dash", width=3)
            ), secondary_y=True)
        exp_fig.update_layout(
            title=f"{title} LLC / Hessian Trace",
            xaxis_title="Epoch",
            xaxis=dict(
                tickmode='linear',
                tick0=0,
                dtick=10,
            ),
            yaxis_type="linear",
            width=600,
            height=400,
            legend=dict(
                x=0,
                y=1,
                xanchor='left',
                yanchor='bottom',
                orientation='h',
            ),
        )
        exp_fig.update_yaxes(title_text="LLC", secondary_y=False)
        exp_fig.update_yaxes(title_text="Hessian Trace", secondary_y=True)
        exp_figures[title] = exp_fig
        exp_fig.show()   

In [65]:
#visualize LLC and val loss

llc_figures = {}

for title in models.keys():
    llc_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1

        # Skip the first entry (epoch 0) so the LLC trace starts at epoch 1
        llc_fig.add_trace(go.Scatter(
            x=np.arange(1, hp["num_epochs"]+1),
            y=llc_estimates[title][optim][1:],
            mode="lines",
            name=f"{optim} LLC",
            line=dict(color=color, width=3),
        ), secondary_y=False)
        llc_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=val_losses[title][optim],
            mode="lines",
            name=f"{optim} Validation Loss",
            line=dict(color=color, dash="dash", width=3)
        ), secondary_y=True)
    llc_fig.update_layout(
        title=f"{title} LLC / Validation Loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=5,
        ),
        yaxis_type="linear",
        yaxis=dict(
            range=[0, None],
        ),
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    llc_fig.update_yaxes(title_text="LLC", secondary_y=False)
    llc_fig.update_yaxes(title_text="Validation Loss", 
                         secondary_y=True,
                        showgrid=False,  # This removes the grid for the secondary y-axis
                        zeroline=False   # This removes the zero line for the secondary y-axis
                        )
    llc_figures[title] = llc_fig
    llc_fig.show()

In [66]:
# Visualise Hessian trace and validation loss
hess_figures = {}

if hp['hessian']:
    for title in models.keys():
        hess_fig = make_subplots(specs=[[{"secondary_y" : True}]])
        color_index = 0
        for optim in optimizers:
            color = color_cycle[color_index % len(color_cycle)]
            color_index += 1
            hess_fig.add_trace(go.Scatter(
                x=np.arange(1, hp["num_epochs"]+1),
                y=hessian_traces[title][optim][1:],
                mode="lines",
                name=f"{optim} Trace",
                line=dict(color=color, width=3),
            ), secondary_y=False)
            hess_fig.add_trace(go.Scatter(
                x=np.arange(1, hp["num_epochs"]+1),
                y=val_losses[title][optim][1:],
                mode="lines",
                name=f"{optim} Validation Loss",
                line=dict(color=color, dash="dash", width=3)
            ), secondary_y=True)
        hess_fig.update_layout(
            title=f"{title} Hessian Trace / Validation Loss",
            xaxis_title="Epoch",
            xaxis=dict(
                tickmode='linear',
                tick0=0,
                dtick=5,
            ),
            yaxis_type="linear",
            width=600,
            height=400,
            legend=dict(
                x=0,
                y=1,
                xanchor='left',
                yanchor='bottom',
                orientation='h',
            ),
        )
        hess_fig.update_yaxes(title_text="Hessian Trace", secondary_y=False)
        hess_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
        hess_figures[title] = hess_fig
        hess_fig.show()

In [67]:
# Visualise converged eigenspectra for SGD and NGD for each model

# if hp["hessian"]:
#     combined_eigenspectra = {}

#     for title in models.keys():

#         combined_eigenspectrum = go.Figure()
#         final_eigenspectra = {}
#         traces = {}

#         for optim in optimizers:
#             color = color_cycle[color_index % len(color_cycle)]
#             color_index += 1
#             final_eigenspectra[optim] = eigenspectra_figs[title][optim][-2]
#             final_eigenspectra[optim].data[0].name = optim
#             final_eigenspectra[optim].data[0].line.color = color
#             combined_eigenspectrum.add_trace(final_eigenspectra[optim].data[0])

#         combined_eigenspectrum.update_layout(
#             title=f"{title} Hessian eigenspectra at convergence",
#             xaxis_title="Eigenvalue",
#             yaxis_title="Probability density (log scale)",
#             yaxis_type="log",
#         )

#         combined_eigenspectrum.show()
#         combined_eigenspectra[title] = (combined_eigenspectrum)

In [68]:
# Visualise validation loss against WBIC for each model

val_wbic_figs = {}
for title in models.keys():
    val_wbic_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        val_wbic_fig.add_trace(go.Scatter(
            x=np.arange(1, hp["num_epochs"]+1),
            y=val_losses[title][optim][1:],
            mode="lines",
            name=f"{optim} Validation",
            line=dict(color=color, width=3),
            marker=dict(size=8, symbol="x")
        ),secondary_y=False)
        val_wbic_fig.add_trace(go.Scatter(
            x=np.arange(1, hp["num_epochs"]+1),
            y=wbic_estimates[title][optim][1:],
            mode="lines",
            name=f"{optim} WBIC",
            line=dict(color=color, dash="dot", width=4),
            marker=dict(size=8, symbol="x")
        ),secondary_y=True)
    val_wbic_fig.update_layout(
        title=f"{title} Validation Loss / WBIC",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=5,
        ),
        yaxis_type="linear",
        yaxis=dict(
            range=[0,None]
        ),
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    val_wbic_fig.update_yaxes(title="Validation Loss", secondary_y=False)
    val_wbic_fig.update_yaxes(title="WBIC", 
                              secondary_y=True,
                              showgrid=False,
                              )
    val_wbic_figs[title] = val_wbic_fig
    val_wbic_fig.show()

In [69]:
# Visualise the cross-correlation between validation loss and WBIC for each model
from scipy import signal

corr_figs = {}
for title in models.keys():
    corr_fig = make_subplots()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1

        # corr=np.correlate(
        #     val_losses[title][optim][1:], 
        #     wbic_estimates[title][optim][1:], 
        #     mode='full')

        # None entries (e.g. NGD before the swap) become NaN and are dropped,
        # because signal.correlate raises a TypeError on None values
        val = np.array(val_losses[title][optim][1:], dtype=float)
        wbic = np.array(wbic_estimates[title][optim][1:], dtype=float)
        valid = ~(np.isnan(val) | np.isnan(wbic))
        val, wbic = val[valid], wbic[valid]

        corr = signal.correlate(val, wbic, mode='full', method='auto')
        # mode='full' returns len(val) + len(wbic) - 1 values, one per lag
        lags = signal.correlation_lags(len(val), len(wbic), mode='full')

        corr_fig.add_trace(go.Scatter(
            x=lags,
            y=corr,
            mode="lines",
            name=f"{optim} Correlation",
            line=dict(color=color, width=3),
            marker=dict(size=8, symbol="x")
        ))

    corr_fig.update_layout(
        title=f"{title} WBIC and Validation Loss Cross Correlation",
        xaxis_title="Lag (epochs)",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=5,
        ),
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    corr_fig.update_yaxes(title="Cross-Correlation")
    corr_figs[title] = corr_fig
    corr_fig.show()

In [None]:
# Check the LLC chains / WBIC chains converged for each model

rlct_converge_plots = {}
for title in models.keys():
    rlct_converge_plot = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0  # restart the colour cycle for every model, as in the cells above
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        if wbic_histories[title][optim][-1] is not None and llc_histories[title][optim][-1] is not None:
            rlct_converge_plot.add_trace(go.Scatter(
                y=np.mean(llc_histories[title][optim][-1]["llc/moving_avg"],axis=0),
                name=f"{optim} LLC",
                line=dict(color=color, width=3)
            ), secondary_y=False)
            rlct_converge_plot.add_trace(go.Scatter(
                y=wbic_histories[title][optim][-1]["wbic/means"],
                name=f"{optim} WBIC",
                line=dict(color=color, dash="dash", width=6)
            ), secondary_y=True)
    rlct_converge_plot.update_layout(
        title=f"{title} LLC / WBIC Trace",
        xaxis_title="Draws",
        yaxis=dict(
            range=[0,None]
        ),
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    rlct_converge_plot.update_yaxes(title="LLC", secondary_y=False)
    rlct_converge_plot.update_yaxes(title="WBIC",
                                    secondary_y=True,
                                    showgrid=False)
    rlct_converge_plots[title] = rlct_converge_plot
    rlct_converge_plot.show()

In [70]:
# Save the results to an HTML file.

figures = []

combined_args = {**hp, **data_args, **devinterp_args}

if loss_figures:
    figures += list(loss_figures.values())
if update_norm_figures:
    figures += list(update_norm_figures.values())
if exp_figures:
    figures += list(exp_figures.values())
# if combined_eigenspectra:
#     figures += list(combined_eigenspectra.values())
if val_wbic_figs:
    figures += list(val_wbic_figs.values())
if rlct_converge_plots:
    figures += list(rlct_converge_plots.values())

figures+=list(llc_figures.values())
figures+=list(hess_figures.values())
figures+=list(corr_figs.values())

curr_time = datetime.now().strftime("%Y-%m-%d-%H-%M")

write_figs_to_html(
    figs=figures,
    dest=f"./ngd_sgd/{hp['model_type']}_ngd_sgd_llc_{curr_time}.html",
    title="Does NGD converge to minima that are 'more complex' i.e. have a higher RLCT?",
    summary=combined_args,
)