### **Does NGD converge to more complex minima according to the RLCT (Real Log Canonical Threshold)?**

This notebook investigates this question using developmental interpretability methods.

We hypothesise that NGD will consistently converge to minima with a higher RLCT than SGD, because it premultiplies the gradient by the inverse of the Fisher Information Matrix (FIM).

Singular learning theory (SLT) proposes that models converge to singularities, where the FIM is non-invertible. Near a singularity the FIM has eigenvalues close to 0 (its determinant is close to 0), so its inverse blows up along the corresponding directions. We therefore expect NGD to "jump away" from singularities, favouring more complex, less singular minima.
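
To make the "blow up" concrete, here is a toy NumPy illustration (not the `NGD` module used in this notebook). With a diagonal FIM whose second eigenvalue shrinks toward 0, the natural gradient step $F^{-1} g$ grows without bound along that direction. Adding a damping term $\lambda I$ caps every component at $|g_i| / \lambda$; treating damping as a stand-in for the notebook's `epsilon`/`delta` parameters is an assumption.

```python
import numpy as np

def natural_gradient_step(fisher, grad, damping=0.0):
    """Return (F + damping * I)^{-1} g, computed with a linear solve rather than an explicit inverse."""
    d = fisher.shape[0]
    return np.linalg.solve(fisher + damping * np.eye(d), grad)

grad = np.array([1.0, 1.0])
for small_eig in [1e-1, 1e-3, 1e-6]:
    # Toy diagonal FIM: one well-determined direction, one nearly degenerate direction
    fisher = np.diag([1.0, small_eig])
    step = natural_gradient_step(fisher, grad)
    damped = natural_gradient_step(fisher, grad, damping=1e-3)
    print(f"eig={small_eig:.0e}  |F^-1 g|={np.linalg.norm(step):.3e}  damped={np.linalg.norm(damped):.3e}")
```

The undamped step norm scales roughly like $1 / \text{eig}$, which is the mechanism behind the hypothesis above.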

#### **Methodology**
- Choose between fully connected networks (`ffnn` with ReLU activations, `dlnn` deep linear networks without ReLU) and CNNs (convolutional neural networks)
- Vary the width (hidden nodes) and depth (hidden layers) of the fully connected networks
- Vary number of convolutional layers for CNN.
- Estimate Hessian rank using `PyHessian` module.
- Estimate RLCT using `devinterp` library.

#### **Instructions**

To produce your own results, edit `args_ngd_sgd.json`. The following parameters can be adjusted:
- `model_type (str)` : `ffnn` (feedforward network with ReLU), `dlnn` (deep linear network, no ReLU), `cnn` (LeNet-style convolutional network)
- `experiment_type (str)` : `standard` (train models independently using SGD and NGD), `swap` (train with SGD up to `cut_off_epochs`, then continue separately with SGD and with NGD from that checkpoint)

<!-- -->

- `optimizer (str)` : `sgd`, `ngd`, or `both` (performs analysis on both optimizers); must be set to `both` for the `swap` experiment
- `hessian (bool)` : `true`, `false` (does Hessian rank analysis if enabled)

<!-- -->

- `hidden_nodes (List : int)` : e.g. `[32, 32, 64, 64, 128, 128]`, the number of hidden nodes per layer for each fully connected model
- `hidden_layers (List : int)` : e.g. `[1, 1, 2, 2, 3, 3]`, the number of hidden layers for each fully connected model. Must have the same length as `hidden_nodes`; the two lists are paired element-wise.
- `hidden_conv_layers (List : int)` : the numbers of hidden convolutional layers to use for the CNN, e.g. `[1, 2, 3]`.

<!-- -->

- `cut_off_epochs (int)` : the epoch number at which the optimiser is swapped if using `experiment_type = "swap"`
- `num_epochs (int)` : total number of training epochs

<!-- -->

- `sgd_lr (float)` : sgd learning rate for training
- `ngd_lr (float)` : ngd learning rate for training

<!-- -->

- `alpha (float)` : smoothing constant for estimation of FIM for NGD
- `eta (float)` : $F_n = \eta S_n + (1 - \eta) F_{n-1}$, where $S_n$ is the Fisher matrix estimated from the current batch
- `epsilon (float)` : added to $F$ to stop it from becoming singular
- `delta (float)` : constant preventing $F$ from becoming singular
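
The `eta` update is an exponential moving average of per-batch Fisher estimates. The NumPy sketch below illustrates it under two assumptions the notebook does not confirm: that $S_n$ is the empirical Fisher (mean outer product of per-example gradients), and that `epsilon` is added to the diagonal of $F$ before solving for the update direction. How the `NGD` module uses `alpha` and `delta` is not modelled here.

```python
import numpy as np

def update_fisher(prev_fisher, batch_grads, eta):
    """EMA update F_n = eta * S_n + (1 - eta) * F_{n-1}.

    batch_grads has shape (batch_size, num_params); S_n is assumed to be the
    empirical Fisher, i.e. the mean outer product of per-example gradients.
    """
    batch_fisher = batch_grads.T @ batch_grads / batch_grads.shape[0]  # S_n
    return eta * batch_fisher + (1.0 - eta) * prev_fisher

def damped_natural_gradient(fisher, grad, epsilon):
    """Assumed use of epsilon: solve (F + epsilon * I) x = grad instead of inverting F."""
    return np.linalg.solve(fisher + epsilon * np.eye(fisher.shape[0]), grad)

rng = np.random.default_rng(0)
num_params = 4
fisher = np.eye(num_params)  # hypothetical initialisation F_0 = I
for _ in range(10):
    per_example_grads = rng.standard_normal((32, num_params))  # stand-in for real gradients
    fisher = update_fisher(fisher, per_example_grads, eta=0.9)

mean_grad = rng.standard_normal(num_params)
step = damped_natural_gradient(fisher, mean_grad, epsilon=1e-10)
```

A larger `eta` weights the current batch more heavily; a smaller one gives a smoother but slower-reacting estimate of $F$.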

<!-- -->

- `momentum (float)` : momentum used in SGD and NGD
- `nesterov (bool)` : enable Nesterov momentum for NGD / SGD

<!-- -->

- `seed (int)` : the random seed used for `torch.manual_seed()`

<!-- -->

- `batch_size (int)` : batch size for train / validation dataloaders
- `num_workers (int)` : number of worker processes for data loading (around 6 works well; adjust depending on your hardware)
- `dataset (str)` : `mnist`, `fashion-mnist`, `cifar10` (dataset to use for training)
- `num_hessian_batches (int)` : number of batches used for estimation of Hessian

<!-- -->

- `sampler (str)` : `sgld`, `sgnht` (sampler used for RLCT estimation)
- `num_chains (int)` : number of chains used in RLCT estimation (more chains give a more accurate RLCT estimate)
- `num_draws (int)` : number of sampler steps per chain (should be high enough for the chain RLCT to converge; check the convergence plot)
- `localization (float)` : higher localization more strongly restricts the sampler to a neighbourhood of the trained weights (stops the chains from drifting away to other minima)
- `sampler_lr (float)` : learning rate for sampler for RLCT estimation
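
The notebook loads this file with `hp, data_args, devinterp_args = args`, so it must contain a JSON list of exactly three dictionaries in that order. As a sketch, the file for the configuration used in this run could be generated from Python as follows; note that running it overwrites any existing `args_ngd_sgd.json`.

```python
import json

# Values mirror the configuration printed in section 1; edit as needed.
hp = {
    "model_type": "ffnn",
    "experiment_type": "standard",
    "optimizer": "both",
    "hessian": True,
    "hidden_nodes": [128],
    "hidden_layers": [2],
    "hidden_conv_layers": [0],
    "cut_off_epochs": 40,
    "num_epochs": 40,
    "sgd_lr": 0.01,
    "ngd_lr": 0.01,
    "alpha": 0.1,
    "eta": 0.9,
    "epsilon": 1e-10,
    "delta": 0.0005,
    "momentum": 0.9,
    "nesterov": True,
    "seed": 1,
}
data_args = {
    "batch_size": 128,
    "num_workers": 6,
    "dataset": "fashion-mnist",
    "num_hessian_batches": 1,
}
devinterp_args = {
    "sampler": "sgld",
    "num_chains": 2,
    "num_draws": 1000,
    "localization": 100.0,
    "sampler_lr": 0.0001,
}

# Order matters: the notebook unpacks the list as (hp, data_args, devinterp_args).
with open("args_ngd_sgd.json", "w") as file:
    json.dump([hp, data_args, devinterp_args], file, indent=4)
```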

You can run the notebook from the terminal using the following command:
```bash
jupyter nbconvert --to notebook --execute --inplace ./experiments/eval_ngd_sgd.ipynb
```

#### **0. Import libraries**

The `NGD` module implements Natural Gradient Descent efficiently by using an approximation of the Fisher Information Matrix instead of the exact matrix.
The `devinterp` library is used to estimate the LLC (local learning coefficient).

In [100]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

from PyHessian.pyhessian import *
from PyHessian.density_plot import *
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

from utils_general import *
from utils_hessian_fim import *
from architectures import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

device = "cuda" if t.cuda.is_available() else "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline


Device in use: cpu
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### **1. Import data for training**

We load our training data using the helper function `build_data_loaders`, which lets us choose between MNIST, Fashion-MNIST and CIFAR10.

The arguments file holds three dictionaries: `hp` (hyperparameters), `data_args` (arguments for data loading) and `devinterp_args` (arguments for LLC and WBIC estimation).

In [101]:
# Load experiment args

with open("args_ngd_sgd.json", "r") as file:
    args = json.load(file) 

hp, data_args, devinterp_args = args

print("======HYPERPARAMETERS======")
pprint.pprint(hp, sort_dicts=False)

print("======DATA ARGS======")
pprint.pprint(data_args, sort_dicts=False)

print("======DEVINTERP ARGS======")
pprint.pprint(devinterp_args, sort_dicts=False)

{'model_type': 'ffnn',
 'experiment_type': 'standard',
 'optimizer': 'both',
 'hessian': True,
 'hidden_nodes': [128],
 'hidden_layers': [2],
 'hidden_conv_layers': [0],
 'cut_off_epochs': 40,
 'num_epochs': 40,
 'sgd_lr': 0.01,
 'ngd_lr': 0.01,
 'alpha': 0.1,
 'eta': 0.9,
 'epsilon': 1e-10,
 'delta': 0.0005,
 'momentum': 0.9,
 'nesterov': True,
 'seed': 1}
{'batch_size': 128,
 'num_workers': 6,
 'dataset': 'fashion-mnist',
 'num_hessian_batches': 1}
{'sampler': 'sgld',
 'num_chains': 2,
 'num_draws': 1000,
 'localization': 100.0,
 'sampler_lr': 0.0001}


In [102]:
# Set random seed for weight initialisation (for reproducibility of results and to ensure NGD/SGD start from same point in loss landscape)

t.manual_seed(hp["seed"])

<torch._C.Generator at 0x7fc2d074e1b0>

In [103]:
train_loader, test_loader = build_data_loaders(data_args)

#### **2. Training models**

Choose the architecture via `model_type`: a feedforward network (`ffnn`), a deep linear network (`dlnn`) or a CNN (`cnn`).

This code produces a dictionary `models` where each key describes the model itself, e.g. "FFNN 2 HL, 128 HN".

Each value, iterated over as `model`, is itself a dictionary keyed by optimizer. Each optimizer maps to a list that initially holds only the model with its initial weights; after every training epoch a copy of the trained model is appended, so the list records the model's full history.

```
models={
    "architecture1":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    },
    "architecture2":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    }
}
```
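
The nested layout can be illustrated with plain dictionaries standing in for the `nn.Module` snapshots. This is a toy sketch: in the notebook the lists hold deep copies of the actual networks.

```python
import copy

optimizers = ["ngd", "sgd"]
initial_model = {"weights": [0.0, 0.0]}  # hypothetical stand-in for an nn.Module
models = {
    "FFNN 2 HL, 128 HN": {optim: [copy.deepcopy(initial_model)] for optim in optimizers},
}

num_epochs = 3
for title, model in models.items():
    for optim in optimizers:
        for epoch in range(1, num_epochs + 1):
            # In the notebook this is model[optim].append(copy.deepcopy(state)) after each epoch
            model[optim].append({"weights": [float(epoch), float(epoch)]})

history = models["FFNN 2 HL, 128 HN"]["sgd"]
initial, final = history[0], history[-1]  # index k is the snapshot after epoch k
```

Both optimizers start from equal but independent copies of the same initial weights. In the `swap` experiment, the first `cut_off_epochs + 1` NGD entries are `None`, so both lists stay indexed by epoch.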

In [104]:
# Initialise models dependent on arguments

models = {}
optimizers = ["ngd", "sgd"] if hp["optimizer"] == "both" else [hp["optimizer"]]

if hp["model_type"] == "ffnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"FFNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist" or data_args["dataset"] == "fashion-mnist":
            model = NeuralNet(relu=True,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=True,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "dlnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"DLNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist" or data_args["dataset"] == "fashion-mnist":
            model = NeuralNet(relu=False,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=False,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "cnn":
    hidden_conv_layers = hp["hidden_conv_layers"]
    for hidden_conv_layer in hidden_conv_layers:
        title = f"CNN {hidden_conv_layer} HCL"
        if data_args["dataset"] == "mnist" or data_args["dataset"] == "fashion-mnist":
            model = LeNet(dataset="mnist",extra_layers=hidden_conv_layer).to(device)
        elif data_args["dataset"] == "cifar10":
            model = LeNet(dataset="cifar10",extra_layers=hidden_conv_layer).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}


For Experiment 1 (`experiment_type = "standard"`), make sure `cut_off_epochs` is equal to `num_epochs`.
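
The constraints stated in the instructions can be checked before training starts. `check_hyperparameters` below is a hypothetical helper (not part of `utils_general`) that raises on the combinations the notebook does not support.

```python
VALID_MODEL_TYPES = {"ffnn", "dlnn", "cnn"}
VALID_EXPERIMENTS = {"standard", "swap"}

def check_hyperparameters(hp):
    """Raise ValueError if `hp` breaks one of the constraints described in the instructions."""
    if hp["model_type"] not in VALID_MODEL_TYPES:
        raise ValueError(f"model_type must be one of {sorted(VALID_MODEL_TYPES)}")
    if hp["experiment_type"] not in VALID_EXPERIMENTS:
        raise ValueError(f"experiment_type must be one of {sorted(VALID_EXPERIMENTS)}")
    if len(hp["hidden_nodes"]) != len(hp["hidden_layers"]):
        raise ValueError("hidden_nodes and hidden_layers must have the same length")
    if hp["experiment_type"] == "swap" and hp["optimizer"] != "both":
        raise ValueError("the swap experiment requires optimizer = 'both'")
    if hp["experiment_type"] == "standard" and hp["cut_off_epochs"] != hp["num_epochs"]:
        raise ValueError("for the standard experiment, set cut_off_epochs equal to num_epochs")
    if hp["cut_off_epochs"] > hp["num_epochs"]:
        raise ValueError("cut_off_epochs cannot exceed num_epochs")
```

Calling `check_hyperparameters(hp)` right after loading the arguments file surfaces configuration mistakes before any training time is spent.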

In [105]:
# EXPERIMENT 1: Train models independently for num_epochs with SGD and NGD

if hp["experiment_type"] == "standard":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        # Store list for SGD losses, NGD losses, for train and val
        model_train_losses = {optim : [] for optim in optimizers}
        model_val_losses = {optim : [] for optim in optimizers}
        model_update_norms = {optim : [] for optim in optimizers}
        model_all_update_norms = {optim : [] for optim in optimizers}

        for optim in optimizers:
            state = copy.deepcopy(model[optim][0])
            if optim == "sgd":
                optimizer = t.optim.SGD(
                    params=state.parameters(),
                    lr=hp["sgd_lr"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            elif optim == "ngd":
                optimizer = NGD(
                    params=state.parameters(),
                    lr=hp["ngd_lr"],
                    alpha=hp["alpha"],
                    eta=hp["eta"],
                    epsilon=hp["epsilon"],
                    delta=hp["delta"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            print(f"TRAINING MODEL: {title} | OPTIMISER: {optim}")
            initial_train_loss = evaluate(state, train_loader, criterion, device)
            initial_val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses[optim].append(initial_train_loss)
            model_val_losses[optim].append(initial_val_loss)
        
            for epoch in range(1, hp["num_epochs"]+1):
                train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
                val_loss = evaluate(state, test_loader, criterion, device)
                model_train_losses[optim].append(train_loss)
                model_update_norms[optim].append(update_norm)
                model_val_losses[optim].append(val_loss)
                # epoch_update_norms holds one update norm per minibatch; extend the running list with it
                model_all_update_norms[optim] += epoch_update_norms
                model[optim].append(copy.deepcopy(state))
                print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


TRAINING MODEL: FFNN 2 HL, 128 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:05<00:00, 78.57it/s]


Epoch 1/40: train_loss=1.7646, val_loss=0.9200


100%|██████████| 469/469 [00:05<00:00, 78.68it/s]


Epoch 2/40: train_loss=0.5763, val_loss=0.4501


100%|██████████| 469/469 [00:05<00:00, 78.31it/s]


Epoch 3/40: train_loss=0.3698, val_loss=0.3752


100%|██████████| 469/469 [00:06<00:00, 72.23it/s]


Epoch 4/40: train_loss=0.3105, val_loss=0.3448


100%|██████████| 469/469 [00:06<00:00, 74.37it/s]


Epoch 5/40: train_loss=0.2761, val_loss=0.3273


100%|██████████| 469/469 [00:06<00:00, 68.25it/s]


Epoch 6/40: train_loss=0.2500, val_loss=0.3174


100%|██████████| 469/469 [00:06<00:00, 75.78it/s]


Epoch 7/40: train_loss=0.2300, val_loss=0.3139


100%|██████████| 469/469 [00:06<00:00, 68.96it/s]


Epoch 8/40: train_loss=0.2106, val_loss=0.3093


100%|██████████| 469/469 [00:05<00:00, 79.07it/s]


Epoch 9/40: train_loss=0.1945, val_loss=0.3065


100%|██████████| 469/469 [00:05<00:00, 80.98it/s]


Epoch 10/40: train_loss=0.1795, val_loss=0.3110


100%|██████████| 469/469 [00:05<00:00, 80.23it/s]


Epoch 11/40: train_loss=0.1645, val_loss=0.3088


100%|██████████| 469/469 [00:06<00:00, 74.06it/s]


Epoch 12/40: train_loss=0.1535, val_loss=0.3136


100%|██████████| 469/469 [00:06<00:00, 76.97it/s]


Epoch 13/40: train_loss=0.1417, val_loss=0.3164


100%|██████████| 469/469 [00:06<00:00, 76.31it/s]


Epoch 14/40: train_loss=0.1305, val_loss=0.3248


100%|██████████| 469/469 [00:05<00:00, 79.24it/s]


Epoch 15/40: train_loss=0.1216, val_loss=0.3356


100%|██████████| 469/469 [00:06<00:00, 75.69it/s]


Epoch 16/40: train_loss=0.1134, val_loss=0.3447


100%|██████████| 469/469 [00:05<00:00, 82.40it/s]


Epoch 17/40: train_loss=0.1075, val_loss=0.3495


100%|██████████| 469/469 [00:05<00:00, 78.54it/s]


Epoch 18/40: train_loss=0.1004, val_loss=0.3516


100%|██████████| 469/469 [00:05<00:00, 79.00it/s]


Epoch 19/40: train_loss=0.0931, val_loss=0.3588


100%|██████████| 469/469 [00:05<00:00, 83.11it/s]


Epoch 20/40: train_loss=0.0874, val_loss=0.3814


100%|██████████| 469/469 [00:05<00:00, 85.19it/s]


Epoch 21/40: train_loss=0.0851, val_loss=0.3734


100%|██████████| 469/469 [00:05<00:00, 85.47it/s]


Epoch 22/40: train_loss=0.0823, val_loss=0.3807


100%|██████████| 469/469 [00:06<00:00, 76.53it/s]


Epoch 23/40: train_loss=0.0774, val_loss=0.4135


100%|██████████| 469/469 [00:05<00:00, 79.06it/s]


Epoch 24/40: train_loss=0.0734, val_loss=0.4047


100%|██████████| 469/469 [00:05<00:00, 80.53it/s]


Epoch 25/40: train_loss=0.0687, val_loss=0.4042


100%|██████████| 469/469 [00:06<00:00, 70.89it/s]


Epoch 26/40: train_loss=0.0665, val_loss=0.4255


100%|██████████| 469/469 [00:07<00:00, 59.66it/s]


Epoch 27/40: train_loss=0.0621, val_loss=0.4335


100%|██████████| 469/469 [00:06<00:00, 68.09it/s]


Epoch 28/40: train_loss=0.0625, val_loss=0.4537


100%|██████████| 469/469 [00:06<00:00, 76.10it/s]


Epoch 29/40: train_loss=0.0632, val_loss=0.4616


100%|██████████| 469/469 [00:06<00:00, 73.76it/s]


Epoch 30/40: train_loss=0.0632, val_loss=0.4699


100%|██████████| 469/469 [00:06<00:00, 77.71it/s]


Epoch 31/40: train_loss=0.0586, val_loss=0.4834


100%|██████████| 469/469 [00:07<00:00, 63.91it/s]


Epoch 32/40: train_loss=0.0584, val_loss=0.4807


100%|██████████| 469/469 [00:06<00:00, 71.20it/s]


Epoch 33/40: train_loss=0.0655, val_loss=0.4999


100%|██████████| 469/469 [00:06<00:00, 72.35it/s]


Epoch 34/40: train_loss=0.0725, val_loss=0.5383


100%|██████████| 469/469 [00:06<00:00, 70.93it/s]


Epoch 35/40: train_loss=0.0750, val_loss=0.5270


100%|██████████| 469/469 [00:07<00:00, 63.65it/s]


Epoch 36/40: train_loss=0.0751, val_loss=0.5316


100%|██████████| 469/469 [00:06<00:00, 70.25it/s]


Epoch 37/40: train_loss=0.0774, val_loss=0.5626


100%|██████████| 469/469 [00:06<00:00, 70.80it/s]


Epoch 38/40: train_loss=0.0832, val_loss=0.5887


100%|██████████| 469/469 [00:06<00:00, 73.91it/s]


Epoch 39/40: train_loss=0.0797, val_loss=0.5951


100%|██████████| 469/469 [00:06<00:00, 72.04it/s]


Epoch 40/40: train_loss=0.0788, val_loss=0.5831
TRAINING MODEL: FFNN 2 HL, 128 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:03<00:00, 121.40it/s]


Epoch 1/40: train_loss=0.6438, val_loss=0.4689


100%|██████████| 469/469 [00:04<00:00, 113.14it/s]


Epoch 2/40: train_loss=0.4143, val_loss=0.4164


100%|██████████| 469/469 [00:04<00:00, 105.03it/s]


Epoch 3/40: train_loss=0.3702, val_loss=0.3940


100%|██████████| 469/469 [00:04<00:00, 108.04it/s]


Epoch 4/40: train_loss=0.3452, val_loss=0.3928


100%|██████████| 469/469 [00:04<00:00, 111.49it/s]


Epoch 5/40: train_loss=0.3267, val_loss=0.3733


100%|██████████| 469/469 [00:04<00:00, 112.38it/s]


Epoch 6/40: train_loss=0.3101, val_loss=0.3638


100%|██████████| 469/469 [00:04<00:00, 116.54it/s]


Epoch 7/40: train_loss=0.2958, val_loss=0.3508


100%|██████████| 469/469 [00:04<00:00, 108.26it/s]


Epoch 8/40: train_loss=0.2834, val_loss=0.3409


100%|██████████| 469/469 [00:04<00:00, 103.31it/s]


Epoch 9/40: train_loss=0.2730, val_loss=0.3329


100%|██████████| 469/469 [00:04<00:00, 107.18it/s]


Epoch 10/40: train_loss=0.2622, val_loss=0.3361


100%|██████████| 469/469 [00:04<00:00, 109.26it/s]


Epoch 11/40: train_loss=0.2538, val_loss=0.3296


100%|██████████| 469/469 [00:03<00:00, 123.28it/s]


Epoch 12/40: train_loss=0.2466, val_loss=0.3384


100%|██████████| 469/469 [00:04<00:00, 100.22it/s]


Epoch 13/40: train_loss=0.2385, val_loss=0.3373


100%|██████████| 469/469 [00:04<00:00, 104.67it/s]


Epoch 14/40: train_loss=0.2302, val_loss=0.3272


100%|██████████| 469/469 [00:04<00:00, 108.72it/s]


Epoch 15/40: train_loss=0.2224, val_loss=0.3260


100%|██████████| 469/469 [00:04<00:00, 99.27it/s] 


Epoch 16/40: train_loss=0.2180, val_loss=0.3320


100%|██████████| 469/469 [00:04<00:00, 100.39it/s]


Epoch 17/40: train_loss=0.2103, val_loss=0.3467


100%|██████████| 469/469 [00:04<00:00, 115.11it/s]


Epoch 18/40: train_loss=0.2050, val_loss=0.3366


100%|██████████| 469/469 [00:04<00:00, 115.73it/s]


Epoch 19/40: train_loss=0.1978, val_loss=0.3374


100%|██████████| 469/469 [00:04<00:00, 109.02it/s]


Epoch 20/40: train_loss=0.1917, val_loss=0.3297


100%|██████████| 469/469 [00:04<00:00, 112.97it/s]


Epoch 21/40: train_loss=0.1874, val_loss=0.3375


100%|██████████| 469/469 [00:04<00:00, 113.86it/s]


Epoch 22/40: train_loss=0.1829, val_loss=0.3426


100%|██████████| 469/469 [00:04<00:00, 109.84it/s]


Epoch 23/40: train_loss=0.1785, val_loss=0.3464


100%|██████████| 469/469 [00:04<00:00, 97.42it/s] 


Epoch 24/40: train_loss=0.1738, val_loss=0.3625


100%|██████████| 469/469 [00:04<00:00, 111.65it/s]


Epoch 25/40: train_loss=0.1671, val_loss=0.3569


100%|██████████| 469/469 [00:04<00:00, 98.69it/s] 


Epoch 26/40: train_loss=0.1637, val_loss=0.3782


100%|██████████| 469/469 [00:05<00:00, 84.48it/s] 


Epoch 27/40: train_loss=0.1591, val_loss=0.3472


100%|██████████| 469/469 [00:04<00:00, 105.89it/s]


Epoch 28/40: train_loss=0.1535, val_loss=0.3634


100%|██████████| 469/469 [00:04<00:00, 101.75it/s]


Epoch 29/40: train_loss=0.1497, val_loss=0.3576


100%|██████████| 469/469 [00:04<00:00, 105.14it/s]


Epoch 30/40: train_loss=0.1463, val_loss=0.3747


100%|██████████| 469/469 [00:04<00:00, 112.92it/s]


Epoch 31/40: train_loss=0.1407, val_loss=0.3658


100%|██████████| 469/469 [00:04<00:00, 104.02it/s]


Epoch 32/40: train_loss=0.1375, val_loss=0.3611


100%|██████████| 469/469 [00:04<00:00, 106.85it/s]


Epoch 33/40: train_loss=0.1355, val_loss=0.3753


100%|██████████| 469/469 [00:04<00:00, 101.58it/s]


Epoch 34/40: train_loss=0.1319, val_loss=0.3642


100%|██████████| 469/469 [00:04<00:00, 110.22it/s]


Epoch 35/40: train_loss=0.1282, val_loss=0.3893


100%|██████████| 469/469 [00:04<00:00, 111.35it/s]


Epoch 36/40: train_loss=0.1244, val_loss=0.3797


100%|██████████| 469/469 [00:04<00:00, 96.25it/s] 


Epoch 37/40: train_loss=0.1214, val_loss=0.3983


100%|██████████| 469/469 [00:04<00:00, 94.05it/s] 


Epoch 38/40: train_loss=0.1161, val_loss=0.4150


100%|██████████| 469/469 [00:04<00:00, 108.73it/s]


Epoch 39/40: train_loss=0.1136, val_loss=0.3962


100%|██████████| 469/469 [00:04<00:00, 100.53it/s]


Epoch 40/40: train_loss=0.1100, val_loss=0.4184


In [106]:
# EXPERIMENT 2: Train to convergence with SGD, then continue with SGD and NGD and compare

if hp["experiment_type"] == "swap":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        del model["ngd"][0]

        model_train_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_val_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_update_norms = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_all_update_norms = {"sgd" : [], "ngd" : [None for i in range(len(train_loader)*hp["cut_off_epochs"])]}

        state = copy.deepcopy(model["sgd"][0])
        optimizer = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
    
        print(f"TRAINING MODEL: {title} | OPTIMISER: SGD")
        initial_train_loss = evaluate(state, train_loader, criterion, device)
        initial_val_loss = evaluate(state, test_loader, criterion, device)
        model_train_losses["sgd"].append(initial_train_loss)
        model_val_losses["sgd"].append(initial_val_loss)
        for epoch in range(1, hp["cut_off_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")


        # Train with SGD
        print("Training from model checkpoint with SGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_sgd = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_sgd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        
        # Train with NGD
        print("Training from model checkpoint with NGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_ngd = NGD(
            params=state.parameters(),
            lr=hp["ngd_lr"],
            alpha=hp["alpha"],
            eta=hp["eta"],
            epsilon=hp["epsilon"],
            delta=hp["delta"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_ngd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["ngd"].append(train_loss)
            model_update_norms["ngd"].append(update_norm)
            model_val_losses["ngd"].append(val_loss)
            model_all_update_norms["ngd"] += epoch_update_norms
            model["ngd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


In [107]:
# If we are doing the swap experiment, pad the first hp["cut_off_epochs"] + 1 NGD entries (initial weights plus pre-swap epochs) with None so both histories are indexed by epoch

if hp["experiment_type"] == "swap":
    none_models = [None for i in range(hp["cut_off_epochs"]+1)]
    for title, model in models.items():
        model["ngd"] = none_models + model["ngd"]

#### **3. Visualising training / validation loss results**

Check that the models all converged.

Displays training and validation loss for each model in a separate figure, with four traces: SGD training, SGD validation, NGD training and NGD validation.
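
Beyond eyeballing the plots, a rough numerical check can flag runs whose loss is still moving. `has_plateaued` is a hypothetical heuristic, not a formal convergence test; the window and tolerance are arbitrary choices.

```python
def has_plateaued(losses, window=5, rel_tol=0.05):
    """Heuristic: True if the last `window` recorded losses vary by less than rel_tol (relative)."""
    # Swap runs pad the NGD history with None before the cut-off epoch, so drop those entries.
    values = [v for v in losses if v is not None]
    if len(values) < window:
        return False
    tail = values[-window:]
    spread = max(tail) - min(tail)
    return spread / max(abs(tail[-1]), 1e-12) < rel_tol

# Possible usage after training:
# for title in models:
#     for optim in optimizers:
#         print(title, optim, has_plateaued(train_losses[title][optim]))
```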

In [108]:
# Display training / val data for all models for NGD / SGD

epochs = np.arange(1, hp["num_epochs"]+1)

loss_figures = {}

color_cycle = ['rgb(200, 0, 0)', 'rgb(32, 102, 168)']

for title in models.keys():
    loss_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=train_losses[title][optim],
            mode="lines",
            line=dict(color=color, width=3),
            name=f"{optim} Train",
        ), secondary_y=False)
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=val_losses[title][optim],
            mode="lines",
            line=dict(color=color, dash='dot', width=4),
            name=f"{optim} Validation",
        ), secondary_y=True)
    loss_fig.update_layout(
        title=f"{title} training / validation loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
        font=dict(
            family="Times New Roman",
            size=25,
            color="black",
        ),
        width=900,
        height=600,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
            font=dict(
                size=20,
            ),
        ),
    )
    loss_fig.update_yaxes(title_text="Training Loss", secondary_y=False)
    loss_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
    loss_figures[title] = loss_fig
    loss_fig.show()

In [109]:
# Update norms over training steps for NGD and SGD (30-step moving average); each step corresponds to one minibatch

update_norm_figures = {}

for title in models.keys():
    update_norm_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
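        # Swap runs pad the NGD history with None; converting to NaN leaves a gap instead of crashing np.convolve
        norms = np.array(
            [np.nan if v is None else v for v in all_update_norms[title][optim]],
            dtype=float,
        )
        update_norm_fig.add_trace(go.Scatter(
            x=np.arange(1, len(norms) + 1),  # start at step 1 so the log-scaled x-axis keeps the first point
            y=np.convolve(norms, np.ones(30)/30, mode="same"),
            mode="lines",
            line=dict(color=color, width=3),
            name=f"{optim} update norm"))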
    update_norm_fig.update_layout(
        title=f"{title} update norms over training steps",
        xaxis_title="Steps",
        xaxis=dict(type="log", dtick="1"),
        yaxis_title="Update Size",
        yaxis_type="linear",
        font=dict(
            family="Times New Roman",
            size=25,
            color="black",
        ),
        width=900,
        height=600,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
            font=dict(
                size=20,
            ),
        ),
    )
    update_norm_figures[title] = update_norm_fig
    update_norm_fig.show()

#### **4. Perform Hessian rank computation**

As a sanity check on the RLCT estimates, we compute an approximation of the Hessian for each saved model checkpoint and estimate the rank of this matrix from its eigenspectrum. SLT predicts the following lower bound:

$\text{RLCT} \geq \frac{\text{rank}(\textbf{Hess})}{2}$ 

We will check whether this bound holds in our experiments. Hessian computation is done using helper functions from `utils_hessian_fim.py`, which act as wrappers around the `PyHessian` module (and the `nngeometry` module, for computations involving the Fisher Information Matrix).
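
Estimating "rank" from an eigenspectrum requires a threshold below which eigenvalues count as zero. The body of `find_hessian_dimensionality` is not shown here, so the sketch below is an independent illustration using a relative threshold (an assumption) on an exact, small Hessian. For the toy loss $K(w) = w_1^2 + w_2^4$ at $w = 0$, the Hessian is $\mathrm{diag}(2, 0)$ with rank 1, while the learning coefficient is $\tfrac{1}{2} + \tfrac{1}{4} = \tfrac{3}{4} \geq \tfrac{1}{2}$, consistent with the bound above.

```python
import numpy as np

def numerical_rank(hessian, rel_tol=1e-6):
    """Count eigenvalues whose magnitude exceeds rel_tol times the largest magnitude."""
    eigenvalues = np.linalg.eigvalsh(hessian)  # the Hessian is symmetric
    threshold = rel_tol * np.max(np.abs(eigenvalues))
    return int(np.sum(np.abs(eigenvalues) > threshold))

# Toy loss K(w) = w1^2 + w2^4 at w = 0: the quartic direction is flat to second order.
hessian = np.diag([2.0, 0.0])
rank = numerical_rank(hessian)
lower_bound = rank / 2          # rank(Hess) / 2
true_learning_coefficient = 0.75  # 1/2 from w1^2 plus 1/4 from w2^4
```

For real networks the eigenspectrum is estimated stochastically, so the chosen threshold can noticeably change the reported rank.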

In [110]:
# Create Hessians for each model - recall that model is a list containing all its past versions
if hp["hessian"]:
    hessians = {}
    for title, model in models.items():
        hessian = {}
        for optim in optimizers:
            hessians_list = produce_hessians(
                models=model[optim],
                data_loader=train_loader,
                num_batches=data_args["num_hessian_batches"],
                criterion=criterion,
                device=device,
            )
            hessian[optim] = hessians_list
        hessians[title] = hessian

In [111]:
# Compute Hessian trace for each model
if hp["hessian"]:
    hessian_traces = {}
    for title, hessian in tqdm(hessians.items(), desc="Processing Hessian Traces"):
        hessian_trace = {}
        for optim in optimizers:
            trace_list = produce_hessian_traces(
                hessians=hessian[optim],
                tol=1e-05,
                maxIters=50,
                N=5,
            )
            hessian_trace[optim] = trace_list
        hessian_traces[title] = hessian_trace

Processing Hessian Traces: 100%|██████████| 1/1 [03:13<00:00, 193.70s/it]
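`PyHessian` estimates the trace with Hutchinson's method. For random vectors $v$ with independent $\pm 1$ entries, $\mathbb{E}[v^\top H v] = \operatorname{tr}(H)$, so the trace can be estimated from Hessian-vector products without ever forming $H$. The pure-Python sketch below illustrates the idea on an explicit matrix, using an early-stopping rule similar in spirit to PyHessian's. The function names are hypothetical. We also assume, without having checked `utils_hessian_fim.py`, that the `tol` and `maxIters` arguments of `produce_hessian_traces` are forwarded to a rule of this kind.

```python
import random
from typing import Callable, List

Vector = List[float]

def hutchinson_trace(hvp: Callable[[Vector], Vector], dim: int,
                     max_iters: int = 100, tol: float = 1e-3, seed: int = 0) -> float:
    """Estimate tr(H) as the running mean of v^T H v over Rademacher vectors v.

    Stops early once the running mean changes by less than `tol` (relative).
    """
    rng = random.Random(seed)
    samples: List[float] = []
    estimate = 0.0
    for _ in range(max_iters):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        hv = hvp(v)  # only Hessian-vector products are needed, never H itself
        samples.append(sum(vi * hvi for vi, hvi in zip(v, hv)))
        new_estimate = sum(samples) / len(samples)
        if abs(new_estimate - estimate) / (abs(estimate) + 1e-6) < tol:
            return new_estimate
        estimate = new_estimate
    return estimate

def dense_hvp(matrix: List[Vector]) -> Callable[[Vector], Vector]:
    """Matrix-vector product for an explicit matrix (stands in for autograd HVPs)."""
    return lambda v: [sum(a * b for a, b in zip(row, v)) for row in matrix]

# For a diagonal H, v^T H v equals tr(H) exactly for every Rademacher v
print(hutchinson_trace(dense_hvp([[2.0, 0.0], [0.0, 3.0]]), dim=2))  # 5.0
```

For non-diagonal Hessians, each sample is noisy because of the off-diagonal terms. A loose `tol` can therefore stop the estimate early on a poor value, which is worth bearing in mind when interpreting the trace curves.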


In [112]:
# Get the eigenspectrum data from the Hessian objects
# if hp["hessian"]:
#     eigenspectra_data = {}
#     eigenspectra_figs = {}
#     for title, hessian in tqdm(hessians.items(), desc="Processing Hessians"):
#         eigenspectrum_data = {}
#         eigenspectrum_figs = {}
#         for optim in optimizers:
#             eigenspectrum_figs_list, eigenspectrum_data_list = produce_eigenspectra(
#                 hessians=hessian[optim],
#                 plot_type="log",
#             )
#             eigenspectrum_figs[optim] = eigenspectrum_figs_list
#             eigenspectrum_data[optim] = eigenspectrum_data_list
#         eigenspectra_data[title] = eigenspectrum_data
#         eigenspectra_figs[title] = eigenspectrum_figs

In [113]:
# Track the estimated Hessian rank (dimensionality) over epochs
# if hp["hessian"]:
#     hessian_ranks = {}
#     for title, eigenspectrum_data in eigenspectra_data.items():
#         hessian_rank = {}
#         for optim in optimizers:
#             hessian_rank_list = find_hessian_dimensionality(eigenspectrum_data[optim])
#             hessian_rank[optim] = hessian_rank_list
#         hessian_ranks[title] = hessian_rank

#### **5. Perform RLCT estimation**

Using the `devinterp` library, we estimate the RLCT (Real Log Canonical Threshold). In practice we estimate its local counterpart at the trained parameters, known as the LLC (Local Learning Coefficient).

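`rlct_estimates` is a nested dictionary keyed first by model title and then by optimiser. Each innermost value is a list of RLCT estimates over epochs, with one list for SGD and one for NGD. The other result dictionaries (`histories`, `llc_estimates`, `wbic_estimates` and their corresponding histories) follow the same structure.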
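For reference, the standard LLC estimator is $\hat{\lambda}(w^*) = n\beta^*\left(\mathbb{E}_w[L_n(w)] - L_n(w^*)\right)$, commonly with $\beta^* = 1/\log n$. Here $n$ is the number of training samples, $L_n$ the average training loss, $w^*$ the trained weights, and the expectation runs over SGLD draws from a posterior localised at $w^*$. The WBIC is the tempered expectation $\mathbb{E}_w[nL_n(w)]$ at the same $\beta^*$. The sketch below turns a list of per-draw losses into both quantities, plus a running LLC trajectory like the one the online estimator tracks. It is a simplified illustration under those assumptions. The helper names are hypothetical, and `devinterp` additionally handles multiple chains, burn-in and the localisation strength internally.

```python
import math
from typing import List, Sequence

def llc_estimate(draw_losses: Sequence[float], init_loss: float, n: int) -> float:
    """LLC: n * beta * (mean loss over draws - loss at w*), with beta = 1 / log n."""
    beta = 1.0 / math.log(n)
    mean_loss = sum(draw_losses) / len(draw_losses)
    return n * beta * (mean_loss - init_loss)

def wbic_estimate(draw_losses: Sequence[float], n: int) -> float:
    """WBIC: mean of n * L_n(w) over draws, assumed to be sampled at beta = 1 / log n."""
    return n * sum(draw_losses) / len(draw_losses)

def running_llc(draw_losses: Sequence[float], init_loss: float, n: int) -> List[float]:
    """LLC estimate after each draw, using the running mean of the sampled losses."""
    beta = 1.0 / math.log(n)
    trajectory: List[float] = []
    total = 0.0
    for t, loss in enumerate(draw_losses, start=1):
        total += loss
        trajectory.append(n * beta * (total / t - init_loss))
    return trajectory

losses = [0.2, 0.3]  # made-up per-draw average training losses
print(round(llc_estimate(losses, init_loss=0.1, n=100), 3))  # 3.257
print(wbic_estimate(losses, n=100))                          # 25.0
```

A flat tail in the `running_llc` trajectory is what the convergence plots further down look for.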

In [114]:
# Standard RLCT estimation

rlct_estimates = {}
histories = {}

for title, model in models.items():
    history = {}
    rlct_estimate = {}
    for optim in optimizers:
        rlct_list, history_list = estimate_rlcts(
            model[optim], train_loader, criterion, device, devinterp_args,
        )
        rlct_estimate[optim] = rlct_list
        history[optim] = history_list
    rlct_estimates[title] = rlct_estimate 
    histories[title] = history


In [115]:
# Online LLC / WBIC estimation

wbic_histories = {}
llc_histories = {}

wbic_estimates = {}
llc_estimates = {}

for title, model in models.items():
    wbic_estimate = {}
    wbic_history = {}
    llc_estimate = {}
    llc_history = {}
    for optim in optimizers:
        llc_values_list, llc_history_list, wbic_values_list, wbic_history_list = estimate_rlcts_online(
            model[optim], train_loader, criterion, device, devinterp_args, wbic=True
        )
        wbic_estimate[optim] = wbic_values_list
        wbic_history[optim] = wbic_history_list
        llc_estimate[optim] = llc_values_list
        llc_history[optim] = llc_history_list
    wbic_estimates[title] = wbic_estimate
    wbic_histories[title] = wbic_history
    llc_estimates[title] = llc_estimate
    llc_histories[title] = llc_history


#### **6. Visualise RLCT / Hessian rank, Hessian eigenspectra, generalisation loss, RLCT convergence**

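We display the following final figures:
- LLC and Hessian trace evolution for each model, for NGD and SGD
- LLC and Hessian trace, each plotted against validation loss
- Hessian eigenspectra for SGD and NGD overlaid on the same plot (currently commented out)
- Evolution of the WBIC (used here as a proxy for generalisation loss) compared with validation loss, together with the cross-correlation between the two
- LLC moving average and WBIC over sampling draws for each model, for NGD and SGD, to check that the chains have converged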

In [120]:
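# Visualise LLC and Hessian trace data
# (Hessian traces only exist when hp["hessian"] is enabled, so skip the loop otherwise)

exp_figures = {}

for title in (models.keys() if hp["hessian"] else []):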
    exp_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        exp_fig.add_trace(go.Scatter(
            x=epochs,
            y=llc_estimates[title][optim],
            mode="lines",
            name=f"{optim} RLCT",
            line=dict(color=color, width=3),
        ), secondary_y=False)
        exp_fig.add_trace(go.Scatter(
            x=epochs,
            y=hessian_traces[title][optim][1:],
            mode="lines",
            name=f"{optim} Hessian Trace",
            line=dict(color=color, dash="dash", width=3)
        ), secondary_y=True)
    exp_fig.update_layout(
        title=f"{title} LLC / Hessian Trace",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=10,
        ),
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    exp_fig.update_yaxes(title_text="LLC", secondary_y=False)
    exp_fig.update_yaxes(title_text="Hessian Trace", secondary_y=True)
    exp_figures[title] = exp_fig
    exp_fig.show()   

In [121]:
# Visualise LLC and validation loss

llc_figures = {}

for title in models.keys():
    llc_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        llc_fig.add_trace(go.Scatter(
            x=epochs,
            y=llc_estimates[title][optim],
            mode="lines",
            name=f"{optim} LLC",
            line=dict(color=color, width=3),
        ), secondary_y=False)
        llc_fig.add_trace(go.Scatter(
            x=epochs,
            y=val_losses[title][optim][1:],
            mode="lines",
            name=f"{optim} Validation Loss",
            line=dict(color=color, dash="dash", width=3)
        ), secondary_y=True)
    llc_fig.update_layout(
        title=f"{title} LLC / Validation Loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=10,
        ),
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    llc_fig.update_yaxes(title_text="LLC", secondary_y=False)
    llc_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
    llc_figures[title] = llc_fig
    llc_fig.show()

In [122]:
# Visualise Hessian trace and validation loss
# (Hessian traces only exist when hp["hessian"] is enabled, so skip the loop otherwise)

hess_figures = {}

for title in (models.keys() if hp["hessian"] else []):
    hess_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        hess_fig.add_trace(go.Scatter(
            x=epochs,
            y=hessian_traces[title][optim][1:],
            mode="lines",
            name=f"{optim} Trace",
            line=dict(color=color, width=3),
        ), secondary_y=False)
        hess_fig.add_trace(go.Scatter(
            x=epochs,
            y=val_losses[title][optim][1:],
            mode="lines",
            name=f"{optim} Validation Loss",
            line=dict(color=color, dash="dash", width=3)
        ), secondary_y=True)
    hess_fig.update_layout(
        title=f"{title} Hessian Trace / Validation Loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=10,
        ),
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    hess_fig.update_yaxes(title_text="Hessian Trace", secondary_y=False)
    hess_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
    hess_figures[title] = hess_fig
    hess_fig.show()

In [119]:
# Visualise converged eigenspectra for SGD and NGD for each model

# if hp["hessian"]:
#     combined_eigenspectra = {}

#     for title in models.keys():

#         combined_eigenspectrum = go.Figure()
#         final_eigenspectra = {}
#         traces = {}

#         for optim in optimizers:
#             color = color_cycle[color_index % len(color_cycle)]
#             color_index += 1
#             final_eigenspectra[optim] = eigenspectra_figs[title][optim][-2]
#             final_eigenspectra[optim].data[0].name = optim
#             final_eigenspectra[optim].data[0].line.color = color
#             combined_eigenspectrum.add_trace(final_eigenspectra[optim].data[0])

#         combined_eigenspectrum.update_layout(
#             title=f"{title} Hessian eigenspectra at convergence",
#             xaxis_title="Eigenvalue",
#             yaxis_title="Probability density (log scale)",
#             yaxis_type="log",
#         )

#         combined_eigenspectrum.show()
#         combined_eigenspectra[title] = (combined_eigenspectrum)

In [127]:
# Visualise validation loss vs. WBIC for each model

val_wbic_figs = {}
for title in models.keys():
    val_wbic_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        val_wbic_fig.add_trace(go.Scatter(
            x=epochs,
            y=val_losses[title][optim][1:],
            mode="lines",
            name=f"{optim} Validation",
            line=dict(color=color, width=3),
            marker=dict(size=8, symbol="x")
        ),secondary_y=False)
        val_wbic_fig.add_trace(go.Scatter(
            x=epochs,
            y=wbic_estimates[title][optim][1:],
            mode="lines",
            name=f"{optim} WBIC",
            line=dict(color=color, dash="dot", width=4),
            marker=dict(size=8, symbol="x")
        ),secondary_y=True)
    val_wbic_fig.update_layout(
        title=f"{title} Validation Loss / WBIC",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=10,
        ),
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    val_wbic_fig.update_yaxes(title="Validation Loss", secondary_y=False)
    val_wbic_fig.update_yaxes(title="WBIC", secondary_y=True)
    val_wbic_figs[title] = val_wbic_fig
    val_wbic_fig.show()

In [135]:
# Visualise the cross-correlation between validation loss and WBIC for each model
from scipy import signal

corr_figs = {}
for title in models.keys():
    corr_fig = make_subplots()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1

        val_series = val_losses[title][optim][1:]
        wbic_series = wbic_estimates[title][optim][1:]

        # mode='full' returns 2N - 1 values, one per lag, so plot against the lags
        # (in checkpoint steps) rather than against the N epochs
        corr = signal.correlate(val_series, wbic_series, mode='full', method='auto')
        lags = signal.correlation_lags(len(val_series), len(wbic_series), mode='full')

        corr_fig.add_trace(go.Scatter(
            x=lags,
            y=corr,
            mode="lines",
            name=f"{optim} Correlation",
            line=dict(color=color, width=3),
            marker=dict(size=8, symbol="x")
        ))

    corr_fig.update_layout(
        title=f"{title} WBIC and Validation Loss Cross Correlation",
        xaxis_title="Lag",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=10,
        ),
        yaxis_type="linear",
        width=600,
        height=400,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
        ),
    )
    corr_fig.update_yaxes(title="Cross-Correlation")
    corr_figs[title] = corr_fig
    corr_fig.show()
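Raw `signal.correlate` output scales with the means and magnitudes of the two series. Validation loss and WBIC can live on very different scales, so peak heights are hard to interpret or compare across optimisers. A common alternative is to z-score each series first, which puts every lag on a Pearson-like $[-1, 1]$ scale. A pure-Python sketch follows. `normalised_xcorr` is a hypothetical helper, and its lag convention is intended to match `scipy.signal.correlation_lags`: a peak at lag $k > 0$ means the first series trails the second by $k$ steps.

```python
import math
from typing import Dict, List, Sequence

def _zscore(x: Sequence[float]) -> List[float]:
    """Standardise a series to zero mean and unit (population) variance."""
    mean = sum(x) / len(x)
    std = math.sqrt(sum((v - mean) ** 2 for v in x) / len(x))
    if std == 0.0:
        raise ValueError("cannot normalise a constant series")
    return [(v - mean) / std for v in x]

def normalised_xcorr(a: Sequence[float], b: Sequence[float]) -> Dict[int, float]:
    """Cross-correlation of two equal-length series at every lag, scaled to [-1, 1].

    The value at lag k is the sum of za[t + k] * zb[t] over all valid t,
    divided by the series length n.
    """
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("series must have equal length of at least 2")
    n = len(a)
    za, zb = _zscore(a), _zscore(b)
    result: Dict[int, float] = {}
    for lag in range(-(n - 1), n):
        overlap = range(max(0, -lag), min(n, n - lag))  # t with 0 <= t + lag < n
        result[lag] = sum(za[t + lag] * zb[t] for t in overlap) / n
    return result

b = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
a = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]  # the same pulse, one step later
xc = normalised_xcorr(a, b)
print(max(xc, key=xc.get))  # 1
```

Dividing by $n$ rather than by the overlap length biases values at large lags towards zero. This is the usual choice because it guarantees $|r| \le 1$ and damps spurious peaks supported by only a few points.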

In [98]:
# Check that the LLC and WBIC chains converged for each model

rlct_converge_plots = {}
for title in models.keys():
    rlct_converge_plot = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0  # reset so SGD/NGD get the same colours as in the other figures
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        if wbic_histories[title][optim][-1] is not None and llc_histories[title][optim][-1] is not None:
            rlct_converge_plot.add_trace(go.Scatter(
                y=np.mean(llc_histories[title][optim][-1]["llc/moving_avg"],axis=0),
                name=f"{optim} LLC",
                line=dict(color=color, width=3)
            ), secondary_y=False)
            rlct_converge_plot.add_trace(go.Scatter(
                y=wbic_histories[title][optim][-1]["wbic/means"],
                name=f"{optim} WBIC",
                line=dict(color=color, dash="dash", width=6)
            ), secondary_y=True)
    rlct_converge_plot.update_layout(
        title=f"{title} LLC / WBIC Trace",
        xaxis_title="Draws",
        font=dict(
            family="Times New Roman",
            size=23,
            color="black",
        ),
        width=900,
        height=600,
        legend=dict(
            x=0,
            y=1,
            xanchor='left',
            yanchor='bottom',
            orientation='h',
            font=dict(
                size=20,
            ),
        ),
    )
    rlct_converge_plot.update_yaxes(title="LLC", secondary_y=False)
    rlct_converge_plot.update_yaxes(title="WBIC", secondary_y=True)
    rlct_converge_plots[title] = rlct_converge_plot
    rlct_converge_plot.show()
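The cell above inspects the LLC moving average and WBIC visually. A numerical complement is the Gelman-Rubin statistic $\hat{R}$, which compares between-chain and within-chain variance. Values close to 1 suggest the chains agree; the traditional rule of thumb is $\hat{R} < 1.1$. The sketch below assumes the per-chain LLC estimates (or sampled losses) for one checkpoint are available as equal-length lists. It computes the basic, non-split version. `gelman_rubin` is a hypothetical helper, not part of `devinterp`.

```python
import math
from typing import Sequence

def gelman_rubin(chains: Sequence[Sequence[float]]) -> float:
    """Basic (non-split) Gelman-Rubin R-hat for m >= 2 chains of equal length n >= 2."""
    m = len(chains)
    if m < 2:
        raise ValueError("need at least 2 chains")
    n = len(chains[0])
    if n < 2 or any(len(c) != n for c in chains):
        raise ValueError("chains must have equal length of at least 2")
    chain_means = [sum(c) / n for c in chains]
    grand_mean = sum(chain_means) / m
    # B: between-chain variance, scaled by the chain length n
    between = n / (m - 1) * sum((mu - grand_mean) ** 2 for mu in chain_means)
    # W: mean of the unbiased within-chain variances
    within = sum(sum((x - mu) ** 2 for x in c) / (n - 1)
                 for c, mu in zip(chains, chain_means)) / m
    if within == 0.0:
        raise ValueError("chains have zero within-chain variance")
    # Pooled estimate of the target variance
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)

# Two chains oscillating around the same value agree (R-hat at or below 1) ...
print(round(gelman_rubin([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]]), 3))  # 0.866
# ... while chains stuck at different levels give a large R-hat
print(gelman_rubin([[0.0, 1.0, 0.0, 1.0], [10.0, 11.0, 10.0, 11.0]]) > 1.1)  # True
```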

In [136]:
# Save the results to an HTML file.

figures = []

combined_args = {**hp, **data_args, **devinterp_args}

if loss_figures:
    figures += list(loss_figures.values())
if update_norm_figures:
    figures += list(update_norm_figures.values())
if exp_figures:
    figures += list(exp_figures.values())
# if combined_eigenspectra:
#     figures += list(combined_eigenspectra.values())
if val_wbic_figs:
    figures += list(val_wbic_figs.values())
if rlct_converge_plots:
    figures += list(rlct_converge_plots.values())

if llc_figures:
    figures += list(llc_figures.values())
if hess_figures:
    figures += list(hess_figures.values())
if corr_figs:
    figures += list(corr_figs.values())

curr_time = datetime.now().strftime("%Y-%m-%d-%H-%M")

write_figs_to_html(
    figs=figures,
    dest=f"./ngd_sgd/{hp['model_type']}_ngd_sgd_rlct_{curr_time}.html",
    title="Does NGD converge to minima that are 'more complex' i.e. have a higher RLCT?",
    summary=combined_args,
)