### **Does NGD converge to more complex minima according to the RLCT (Real Log Canonical Threshold)?**

This notebook tests whether natural gradient descent (NGD) converges to more complex minima than stochastic gradient descent (SGD), using developmental interpretability methods.

We hypothesise that NGD will consistently converge to minima with a higher RLCT than SGD, because it premultiplies the gradient used for gradient descent by the inverse of the Fisher Information Matrix (FIM).

Singular learning theory (SLT) proposes that models converge to singularities, where the FIM is non-invertible. Near a singularity the determinant of the FIM is close to 0, so its inverse blows up. NGD should therefore "jump away" from the singularity and instead favour more complex, less singular minima.
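
To make this intuition concrete, here is a minimal sketch with a two-parameter diagonal Fisher matrix. It is a toy illustration rather than the notebook's `NGD` optimizer: the function names, the diagonal form of the FIM, and the `damping` argument are all assumptions made for the example.

```python
# Toy illustration (assumed setup, not the notebook's NGD implementation):
# a diagonal Fisher matrix whose second eigenvalue shrinks towards zero.

def natural_gradient_step(grad, fisher_diag, damping=0.0):
    """Return (F + damping * I)^{-1} grad for a diagonal Fisher matrix F."""
    return [g / (f + damping) for g, f in zip(grad, fisher_diag)]


def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return sum(v * v for v in vec) ** 0.5


grad = [1.0, 1.0]
for small_eig in (1.0, 1e-2, 1e-4):
    undamped = natural_gradient_step(grad, [1.0, small_eig])
    damped = natural_gradient_step(grad, [1.0, small_eig], damping=1e-3)
    print(f"eig={small_eig:g}  undamped={l2_norm(undamped):.1f}  damped={l2_norm(damped):.1f}")
```

As the smallest eigenvalue shrinks, the undamped step grows without bound, while a damping term caps it. The sketch holds the gradient fixed, so it shows the effect of the preconditioner in isolation only.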

#### **Methodology**
- Choose between a dense neural network (`ffnn` or `dlnn`) and a convolutional neural network (`cnn`).
- Vary hidden nodes and hidden layers (width and depth) for the dense networks.
- Vary the number of convolutional layers for the CNN.
- Estimate Hessian rank using `PyHessian` module.
- Estimate RLCT using `devinterp` library.

#### **Instructions**

To produce your own results, go to `args_ngd_sgd.json`. The following parameters can be adjusted:
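- `model_type (str)`  : `ffnn` (dense network with ReLU activations), `dlnn` (dense network built with `relu=False`), `cnn` (convolutional network)
- `experiment_type (str)` : `standard` (train models independently using SGD and NGD), `swap` (train to convergence with SGD for `cut_off_epochs` epochs, then continue from that checkpoint separately with SGD and with NGD)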

<!-- -->

- `optimizer (str)` : `sgd`, `ngd`, or `both` (performs analysis on both optimizers) - SET TO `both` IF DOING `swap` EXPERIMENT
- `hessian (bool)` : `true`, `false` (does Hessian rank analysis if enabled)

<!-- -->

- `hidden_nodes (List : int)` : e.g. `[32, 32, 64, 64, 128, 128]`, the hidden nodes to use in DNN
- `hidden_layers (List: int)` : e.g. `[1, 1, 2, 2, 3, 3]`, the hidden layers to use in DNN. Must have same length as `hidden_nodes`.
- `hidden_conv_layers (List : int)` : array of number of hidden convolutional layers to use for CNN, e.g. `[1, 2, 3]`.

<!-- -->

- `cut_off_epochs (int)` : the epoch number at which the optimiser is swapped if using `experiment_type = "swap"`
- `num_epochs (int)` : total number of training epochs

<!-- -->

- `sgd_lr (float)` : sgd learning rate for training
- `ngd_lr (float)` : ngd learning rate for training

<!-- -->

- `alpha (float)` : smoothing constant for estimation of FIM for NGD
- `eta (float)` : $ F_{n} = \eta S_n + (1 - \eta) F_{n-1} $ where $ S_n $ is the estimated Fisher matrix from the current batch.
- `epsilon (float)` : added to $ F $ to stop it from becoming singular
- `delta (float)` : constant preventing $ F $ from becoming singular

<!-- -->

- `momentum (float)` : momentum used in SGD and NGD
- `nesterov (bool)` : enable Nesterov momentum for NGD / SGD

<!-- -->

- `seed (int)` : the random seed used for `torch.manual_seed()`

<!-- -->

- `batch_size (int)` : batch size for train / validation dataloaders
- `num_workers (int)` : number of worker subprocesses for data loading (keep this at around 6, vary depending on your hardware)
- `dataset (str)` : `mnist`, `cifar10` (dataset to use for training)
- `num_hessian_batches (int)` : number of batches used for estimation of Hessian

<!-- -->

- `sampler (str)` : `sgld`, `sgnht` (optimiser to use for RLCT estimation)
- `num_chains (int)` : number of chains to use in RLCT estimation (higher leads to more accurate RLCT estimate)
- `num_draws (int)` : number of optimizer steps in RLCT estimation (should be high enough such that chain RLCT converges - check convergence plot)
- `localization (float)` : higher localization more strongly restricts optimizer to neighbourhood of model weights (stops RLCT estimator from going straight to minima)
- `sampler_lr (float)` : learning rate for sampler for RLCT estimation

You can run the notebook from the terminal using the following command:
```bash
jupyter nbconvert --to notebook --execute --inplace ./experiments/eval_ngd_sgd.ipynb
```
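
As a sketch of how the parameters above fit together, the following builds the three argument dictionaries and writes them as one JSON file. The values are copied from the example run printed later in this notebook. The top-level layout (a list of three objects) is inferred from the `hp, data_args, devinterp_args = args` unpacking in section 1, and the file is written to a temporary directory here so that a real `args_ngd_sgd.json` is not overwritten.

```python
import json
import tempfile
from pathlib import Path

# Values copied from the example run in this notebook; adjust as needed.
hp = {
    "model_type": "ffnn",
    "experiment_type": "swap",
    "optimizer": "both",
    "hessian": True,
    "hidden_nodes": [64, 64],
    "hidden_layers": [1, 2],
    "hidden_conv_layers": [0],
    "cut_off_epochs": 3,
    "num_epochs": 5,
    "sgd_lr": 0.01,
    "ngd_lr": 0.01,
    "alpha": 0.01,
    "eta": 0.9,
    "epsilon": 1e-10,
    "delta": 0.0005,
    "momentum": 0.9,
    "nesterov": True,
    "seed": 1,
}
data_args = {
    "batch_size": 128,
    "num_workers": 6,
    "dataset": "cifar10",
    "num_hessian_batches": 1,
}
devinterp_args = {
    "sampler": "sgld",
    "num_chains": 1,
    "num_draws": 1000,
    "localization": 100.0,
    "sampler_lr": 0.0001,
}

# Assumed layout: a JSON list holding the three dictionaries in this order.
out_path = Path(tempfile.mkdtemp()) / "args_ngd_sgd.json"
with open(out_path, "w") as file:
    json.dump([hp, data_args, devinterp_args], file, indent=4)

# Read the file back the same way the notebook does.
with open(out_path, "r") as file:
    loaded_hp, loaded_data_args, loaded_devinterp_args = json.load(file)
```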

#### **0. Import libraries**

The `NGD` module is used for implementing Natural Gradient Descent efficiently. It approximates the Fisher Information Matrix to do this.
The `devinterp` library is used for estimation of the LLC (local learning coefficient).
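
The `eta` and `epsilon` parameters from the instructions can be illustrated with a small sketch of a running Fisher estimate, $ F_{n} = \eta S_n + (1 - \eta) F_{n-1} $, followed by damping. This assumes a diagonal Fisher matrix stored as a list; the function names are hypothetical, and the approximation actually used inside the `NGD` module is not reproduced here.

```python
def update_running_fisher(prev_fisher, batch_fisher, eta):
    """F_n = eta * S_n + (1 - eta) * F_{n-1}, element-wise on diagonal entries."""
    return [eta * s + (1.0 - eta) * f for s, f in zip(batch_fisher, prev_fisher)]


def damp(fisher_diag, epsilon):
    """Add epsilon to every diagonal entry so the matrix stays invertible."""
    return [f + epsilon for f in fisher_diag]


# Made-up batch estimates: the second direction never receives any signal.
running = [0.0, 0.0]
for batch_fisher in ([1.0, 0.0], [1.0, 0.0], [1.0, 0.0]):
    running = update_running_fisher(running, batch_fisher, eta=0.9)

# Without damping, the second entry would be exactly zero and not invertible.
inverse = [1.0 / f for f in damp(running, epsilon=1e-10)]
print(running, inverse)
```

A very small `epsilon` keeps the inverse finite but still very large along directions with no Fisher signal, which is the regime the hypothesis above is concerned with.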

In [18]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD, SGNHT
from devinterp.slt import sample, OnlineLLCEstimator
from devinterp.slt.wbic import OnlineWBICEstimator
from devinterp.slt.mala import MalaAcceptanceRate
from devinterp.utils import plot_trace, optimal_temperature

from approxngd import KFAC
from PyHessian.pyhessian import *
from PyHessian.density_plot import *
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

from utils_general import *
from utils_hessian_fim import *
from architectures import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

device = "cuda" if t.cuda.is_available() else "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline


Device in use: cpu
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### **1. Import data for training**

We import our dataset for training. We use a helper function, `build_data_loaders` for this, which allows us to choose between MNIST and CIFAR10. 

We specify three dictionaries, `hp` = hyperparameters, `data_args` = arguments for dataloading, `devinterp_args` = arguments for LLC and WBIC estimation.

In [19]:
# Load experiment args

with open("args_ngd_sgd.json", "r") as file:
    args = json.load(file) 

hp, data_args, devinterp_args = args

print("======HYPERPARAMETERS======")
pprint.pprint(hp, sort_dicts=False)

print("======DATA ARGS======")
pprint.pprint(data_args, sort_dicts=False)

print("======DEVINTERP ARGS======")
pprint.pprint(devinterp_args, sort_dicts=False)

{'model_type': 'ffnn',
 'experiment_type': 'swap',
 'optimizer': 'both',
 'hessian': True,
 'hidden_nodes': [64, 64],
 'hidden_layers': [1, 2],
 'hidden_conv_layers': [0],
 'cut_off_epochs': 3,
 'num_epochs': 5,
 'sgd_lr': 0.01,
 'ngd_lr': 0.01,
 'alpha': 0.01,
 'eta': 0.9,
 'epsilon': 1e-10,
 'delta': 0.0005,
 'momentum': 0.9,
 'nesterov': True,
 'seed': 1}
{'batch_size': 128,
 'num_workers': 6,
 'dataset': 'cifar10',
 'num_hessian_batches': 1}
{'sampler': 'sgld',
 'num_chains': 1,
 'num_draws': 1000,
 'localization': 100.0,
 'sampler_lr': 0.0001}


In [20]:
# Set random seed for weight initialisation (for reproducibility of results and to ensure NGD/SGD start from same point in loss landscape)

t.manual_seed(hp["seed"])

<torch._C.Generator at 0x7f04b905f070>

In [21]:
train_loader, test_loader = build_data_loaders(data_args)

Files already downloaded and verified
Files already downloaded and verified


#### **2. Training models**

Choose to train either a DNN or CNN.

This code produces a dictionary `models` where each key describes the model itself, e.g. "FFNN 2 HL, 64 HN".

Each value in this dictionary, iterated over as `model`, is another dictionary keyed by optimizer. Its values are lists that start with the model at its initial weights. After each training epoch, a copy of the trained model is appended to the list, so the list records the model history.

```
models={
    "architecture1":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    },
    "architecture2":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    }
}
```
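
A minimal sketch of how this structure fills up during training, using plain dictionaries as stand-ins for real models (an assumption for illustration; the notebook stores `nn.Module` instances):

```python
import copy

# Stand-ins for real models: plain dicts holding a fake weight.
initial_model = {"weight": 0.0}
optimizers = ["sgd", "ngd"]

models = {
    "FFNN 1 HL, 64 HN": {optim: [copy.deepcopy(initial_model)] for optim in optimizers},
}

# Simulate two epochs of training: each epoch appends a new checkpoint.
for title, model in models.items():
    for optim in optimizers:
        state = copy.deepcopy(model[optim][0])
        for epoch in range(1, 3):
            state["weight"] += 1.0  # placeholder for a real training epoch
            model[optim].append(copy.deepcopy(state))

initial = models["FFNN 1 HL, 64 HN"]["sgd"][0]   # checkpoint at epoch 0
final = models["FFNN 1 HL, 64 HN"]["sgd"][-1]    # checkpoint after the last epoch
print(len(models["FFNN 1 HL, 64 HN"]["sgd"]), initial, final)
```

Copying with `copy.deepcopy` at every step matters: without it, every list entry would point at the same object and the history would be lost.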

In [22]:
# Initialise models dependent on arguments

models = {}
optimizers = ["sgd", "ngd"] if hp["optimizer"] == "both" else [hp["optimizer"]]

if hp["model_type"] == "ffnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"FFNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist":
            model = NeuralNet(relu=True,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=True,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "dlnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"DLNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist":
            model = NeuralNet(relu=False,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=False,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "cnn":
    hidden_conv_layers = hp["hidden_conv_layers"]
    for hidden_conv_layer in hidden_conv_layers:
        title = f"CNN {hidden_conv_layer} HCL"
        if data_args["dataset"] == "mnist":
            model = LeNet(dataset="mnist",extra_layers=hidden_conv_layer).to(device)
        elif data_args["dataset"] == "cifar10":
            model = LeNet(dataset="cifar10",extra_layers=hidden_conv_layer).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}


If running experiment 1 (`standard`), make sure `cut_off_epochs` is equal to `num_epochs`, because the training loop below only runs up to `cut_off_epochs`.
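
The constraints on the arguments can be checked up front with a small helper. This is a hypothetical sketch, not part of the notebook's utilities: the function name is made up, and the requirement that `cut_off_epochs` is below `num_epochs` for `swap` is an assumption (otherwise there would be no epochs left after the swap).

```python
def check_experiment_args(hp):
    """Return a list of human-readable problems found in the hyperparameters."""
    problems = []
    if len(hp["hidden_nodes"]) != len(hp["hidden_layers"]):
        problems.append("hidden_nodes and hidden_layers must have the same length")
    if hp["experiment_type"] == "standard" and hp["cut_off_epochs"] != hp["num_epochs"]:
        problems.append("standard experiment: cut_off_epochs should equal num_epochs")
    if hp["experiment_type"] == "swap":
        if hp["optimizer"] != "both":
            problems.append("swap experiment: optimizer should be 'both'")
        # Assumption: at least one epoch must remain after the swap.
        if hp["cut_off_epochs"] >= hp["num_epochs"]:
            problems.append("swap experiment: cut_off_epochs should be below num_epochs")
    return problems


example_hp = {
    "experiment_type": "standard",
    "optimizer": "both",
    "hidden_nodes": [64, 64],
    "hidden_layers": [1, 2],
    "cut_off_epochs": 3,
    "num_epochs": 5,
}
for problem in check_experiment_args(example_hp):
    print("WARNING:", problem)
```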

In [23]:
# EXPERIMENT 1: Train models independently for num_epochs with SGD and NGD

if hp["experiment_type"] == "standard":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        # Store list for SGD losses, NGD losses, for train and val
        model_train_losses = {optim : [] for optim in optimizers}
        model_val_losses = {optim : [] for optim in optimizers}
        model_update_norms = {optim : [] for optim in optimizers}
        model_all_update_norms = {optim : [] for optim in optimizers}

        for optim in optimizers:
            state = copy.deepcopy(model[optim][0])
            if optim == "sgd":
                optimizer = t.optim.SGD(
                    params=state.parameters(),
                    lr=hp["sgd_lr"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            elif optim == "ngd":
                optimizer = NGD(
                    params=state.parameters(),
                    lr=hp["ngd_lr"],
                    alpha=hp["alpha"],
                    eta=hp["eta"],
                    epsilon=hp["epsilon"],
                    delta=hp["delta"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            print(f"TRAINING MODEL: {title} | OPTIMISER: {optim}")
            initial_train_loss = evaluate(state, train_loader, criterion, device)
            initial_val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses[optim].append(initial_train_loss)
            model_val_losses[optim].append(initial_val_loss)
            #only train till cutoff epochs
            for epoch in range(1, hp["cut_off_epochs"]+1):
                train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
                val_loss = evaluate(state, test_loader, criterion, device)
                model_train_losses[optim].append(train_loss)
                model_update_norms[optim].append(update_norm)
                model_val_losses[optim].append(val_loss)
                model_all_update_norms[optim] += epoch_update_norms
                model[optim].append(copy.deepcopy(state))
                print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


In [24]:
# EXPERIMENT 2: Train to convergence with SGD, then continue with SGD and NGD and compare

if hp["experiment_type"] == "swap":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        del model["ngd"][0]

        model_train_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
        model_val_losses = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"]+1)]}
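        # Update norms are recorded once per trained epoch (no initial entry), so pad with cut_off_epochs slots
        model_update_norms = {"sgd" : [], "ngd" : [None for i in range(hp["cut_off_epochs"])]}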
        model_all_update_norms = {"sgd" : [], "ngd" : [None for i in range(len(train_loader)*hp["cut_off_epochs"])]}

        state = copy.deepcopy(model["sgd"][0])
        optimizer = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
    
        print(f"TRAINING MODEL: {title} | OPTIMISER: SGD")
        initial_train_loss = evaluate(state, train_loader, criterion, device)
        initial_val_loss = evaluate(state, test_loader, criterion, device)
        model_train_losses["sgd"].append(initial_train_loss)
        model_val_losses["sgd"].append(initial_val_loss)
        for epoch in range(1, hp["cut_off_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")


        # Train with SGD
        print("Training from model checkpoint with SGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_sgd = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_sgd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        
        # Train with NGD
        print("Training from model checkpoint with NGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_ngd = NGD(
            params=state.parameters(),
            lr=hp["ngd_lr"],
            alpha=hp["alpha"],
            eta=hp["eta"],
            epsilon=hp["epsilon"],
            delta=hp["delta"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_ngd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["ngd"].append(train_loss)
            model_update_norms["ngd"].append(update_norm)
            model_val_losses["ngd"].append(val_loss)
            model_all_update_norms["ngd"] += epoch_update_norms
            model["ngd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


TRAINING MODEL: FFNN 1 HL, 64 HN | OPTIMISER: SGD


100%|██████████| 391/391 [00:05<00:00, 77.36it/s]


Epoch 1/5: train_loss=1.6631, val_loss=1.5252


100%|██████████| 391/391 [00:04<00:00, 80.94it/s]


Epoch 2/5: train_loss=1.4866, val_loss=1.4982


100%|██████████| 391/391 [00:04<00:00, 79.69it/s]


Epoch 3/5: train_loss=1.4137, val_loss=1.4502
Training from model checkpoint with SGD.


100%|██████████| 391/391 [00:05<00:00, 77.63it/s]


Epoch 4/5: train_loss=1.3623, val_loss=1.4719


100%|██████████| 391/391 [00:05<00:00, 74.42it/s]


Epoch 5/5: train_loss=1.3244, val_loss=1.4515
Training from model checkpoint with NGD.


100%|██████████| 391/391 [00:07<00:00, 51.13it/s]


Epoch 4/5: train_loss=1.3293, val_loss=1.4287


100%|██████████| 391/391 [00:06<00:00, 58.71it/s]


Epoch 5/5: train_loss=1.2670, val_loss=1.4325
TRAINING MODEL: FFNN 2 HL, 64 HN | OPTIMISER: SGD


100%|██████████| 391/391 [00:05<00:00, 72.42it/s]


Epoch 1/5: train_loss=1.7268, val_loss=1.5496


100%|██████████| 391/391 [00:04<00:00, 80.25it/s]


Epoch 2/5: train_loss=1.5011, val_loss=1.4728


100%|██████████| 391/391 [00:06<00:00, 61.91it/s]


Epoch 3/5: train_loss=1.4101, val_loss=1.4402
Training from model checkpoint with SGD.


100%|██████████| 391/391 [00:04<00:00, 83.32it/s]


Epoch 4/5: train_loss=1.3463, val_loss=1.4196


100%|██████████| 391/391 [00:04<00:00, 82.51it/s]


Epoch 5/5: train_loss=1.2995, val_loss=1.4300
Training from model checkpoint with NGD.


100%|██████████| 391/391 [00:07<00:00, 50.36it/s]


Epoch 4/5: train_loss=1.3057, val_loss=1.3750


100%|██████████| 391/391 [00:07<00:00, 54.83it/s]


Epoch 5/5: train_loss=1.2261, val_loss=1.3871


In [25]:
# If we are doing the swap experiment, fill in the first hp["cut_off_epochs"] models with None

if hp["experiment_type"] == "swap":
    none_models = [None for i in range(hp["cut_off_epochs"]+1)]
    for title, model in models.items():
        model["ngd"] = none_models + model["ngd"]

#### **3. Visualising training / validation loss results**

Check that the models all converged.

Displays training and testing data for each model separately, with 4 traces on each graph.

The traces are: SGD training, SGD testing, NGD training, NGD testing.

In [26]:
# Display training / val data for all models for NGD / SGD

epochs = np.arange(1, hp["num_epochs"]+1)

loss_figures = {}

color_cycle = ['rgb(0, 0, 255)', 'rgb(255, 0, 0)']

for title in models.keys():
    loss_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=train_losses[title][optim],
            mode="lines+markers",
            line=dict(color=color),
            name=f"{optim} Train",
        ), secondary_y=False)
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=val_losses[title][optim],
            mode="lines+markers",
            line=dict(color=color, dash='dot'),
            name=f"{optim} Validation",
        ), secondary_y=True)
    loss_fig.update_layout(
        title=f"{title} training / validation loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    loss_fig.update_yaxes(title_text="Training Loss", secondary_y=False)
    loss_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
    loss_figures[title] = loss_fig
    loss_fig.show()

In [27]:
# Update norms over training steps for NGD and SGD (every 5th step is plotted)

update_norm_figures = {}

color_cycle = ['rgb(0, 0, 255)', 'rgb(255, 0, 0)']

for title in models.keys():
    update_norm_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        update_norm_fig.add_trace(go.Scatter(
            y=all_update_norms[title][optim][::5],
            mode="lines",
            line=dict(color=color),
            name=f"{optim} update norm"))
    update_norm_fig.update_layout(
        title=f"{title} update norms over epochs",
        xaxis_title="Steps",
        yaxis_title="Update Size",
        yaxis_type="linear",
    )
    update_norm_figures[title] = update_norm_fig
    update_norm_fig.show()

#### **4. Perform Hessian rank computation**

To cross-check the RLCT estimates, we compute an approximation of the Hessian for each saved checkpoint of each model. Then, we estimate the rank of this matrix using its eigenspectrum. SLT predicts the following: 

$\text{RLCT} \geq \frac{\text{rank}(\textbf{Hess})}{2}$ 

We will check whether this is true for our experiments. Hessian computation is done using helper functions from `utils_hessian_fim.py` which acts as a wrapper for the `PyHessian` module (and the `nngeometry` module, for doing computations involving the Fisher Information Matrix).
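
The check itself can be sketched as follows, assuming the rank is taken to be the number of eigenvalues above a relative threshold. Both function names and the thresholding rule are assumptions for illustration; the notebook's `find_hessian_dimensionality` helper may use a different criterion.

```python
def estimate_rank(eigenvalues, rel_tol=1e-3):
    """Count eigenvalues whose magnitude exceeds rel_tol times the largest magnitude."""
    magnitudes = [abs(ev) for ev in eigenvalues]
    if not magnitudes or max(magnitudes) == 0.0:
        return 0
    threshold = rel_tol * max(magnitudes)
    return sum(1 for m in magnitudes if m > threshold)


def satisfies_rank_bound(rlct, hessian_rank):
    """Check the bound RLCT >= rank(Hess) / 2."""
    return rlct >= hessian_rank / 2


eigenvalues = [12.0, 3.5, 0.8, 1e-6, -2e-7, 0.0]  # made-up spectrum
rank = estimate_rank(eigenvalues)
print(rank, satisfies_rank_bound(rlct=2.1, hessian_rank=rank))
```

The estimated rank depends strongly on `rel_tol`, so a violated bound may reflect the threshold or a poorly converged RLCT estimate rather than a failure of the theory.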

In [28]:
# Create Hessians for each model - recall that model is a list containing all its past versions
if hp["hessian"]:
    hessians = {}
    for title, model in models.items():
        hessian = {}
        for optim in optimizers:
            hessians_list = produce_hessians(
                models=model[optim],
                data_loader=train_loader,
                num_batches=data_args["num_hessian_batches"],
                criterion=criterion,
                device=device,
            )
            hessian[optim] = hessians_list
        hessians[title] = hessian


Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak. (Triggered internally at ../torch/csrc/autograd/engine.cpp:1151.)



In [29]:
# Get the eigenspectrum data from the Hessian objects
if hp["hessian"]:
    eigenspectra_data = {}
    eigenspectra_figs = {}
    for title, hessian in hessians.items():
        eigenspectrum_data = {}
        eigenspectrum_figs = {}
        for optim in optimizers:
            eigenspectrum_figs_list, eigenspectrum_data_list = produce_eigenspectra(
                hessians=hessian[optim],
                plot_type="log",
            )
            eigenspectrum_figs[optim] = eigenspectrum_figs_list
            eigenspectrum_data[optim] = eigenspectrum_data_list
        eigenspectra_data[title] = eigenspectrum_data
        eigenspectra_figs[title] = eigenspectrum_figs


Casting complex values to real discards the imaginary part



In [30]:
# Produce the traces of Hessian dimensionality over epochs
if hp["hessian"]:
    hessian_ranks = {}
    for title, eigenspectrum_data in eigenspectra_data.items():
        hessian_rank = {}
        for optim in optimizers:
            hessian_rank_list = find_hessian_dimensionality(eigenspectrum_data[optim])
            hessian_rank[optim] = hessian_rank_list
        hessian_ranks[title] = hessian_rank

#### **5. Perform RLCT estimation**

Using the `devinterp` library, we estimate the RLCT (Real Log Canonical Threshold), also known as the LLC (Local Learning Coefficient).

`rlct_estimates` is a dictionary keyed by model title. Each value is itself a dictionary with one list per optimizer (SGD and NGD), holding the RLCT values over epochs. `histories` has the same layout and stores the sampling history behind each estimate.
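
For intuition, a commonly used form of the estimator is $ \hat{\lambda} = n \beta \, (\mathbb{E}[L_n(w)] - L_n(w^*)) $ with $ \beta = 1 / \log n $, where the expectation is over weights drawn by the sampler around the trained weights $ w^* $. The sketch below implements that formula on made-up loss values. We assume this is close to what `devinterp` computes, but the exact temperature and averaging it uses depend on the library version and arguments, and the function names here are hypothetical.

```python
import math


def estimate_llc(sampled_losses, init_loss, num_samples):
    """LLC estimate n * beta * (mean sampled loss - initial loss), with beta = 1 / log(n)."""
    beta = 1.0 / math.log(num_samples)
    mean_loss = sum(sampled_losses) / len(sampled_losses)
    return num_samples * beta * (mean_loss - init_loss)


def llc_moving_average(sampled_losses, init_loss, num_samples):
    """Running LLC estimate after each draw, useful for checking chain convergence."""
    estimates = []
    running_sum = 0.0
    for draw, loss in enumerate(sampled_losses, start=1):
        running_sum += loss
        estimates.append(estimate_llc([running_sum / draw], init_loss, num_samples))
    return estimates


# Made-up numbers: loss at the trained weights and losses along one sampler chain.
init_loss = 1.29
chain_losses = [1.30, 1.32, 1.31, 1.33]
print(estimate_llc(chain_losses, init_loss, num_samples=50000))
print(llc_moving_average(chain_losses, init_loss, num_samples=50000))
```

If the moving average is still drifting at the end of the chain, the final value should not be trusted, which is why the convergence plot later in the notebook is worth checking.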

In [31]:
rlct_estimates = {}
histories = {}

for title, model in models.items():
    history = {}
    rlct_estimate = {}
    for optim in optimizers:
        rlct_list, history_list = estimate_rlcts(
            model[optim], train_loader, criterion, device, devinterp_args,
        )
        rlct_estimate[optim] = rlct_list
        history[optim] = history_list
    rlct_estimates[title] = rlct_estimate 
    histories[title] = history


You are taking more draws than burn-in steps, your LLC estimates will likely be underestimates. Please check LLC chain convergence.


You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)


You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)

Chain 0: 100%|██████████| 1000/1000 [00:08<00:00, 116.72it/s]
Chain 0: 100%|██████████| 1000/1000 [00:08<00:00, 121.74it/s]
Chain 0: 100%|██████████| 1000/1000 [00:07<00:00, 129.52it/s]
Chain 0: 100%|██████████| 1000/1000 [00:07<00:00, 125.35it/s]
Chain 0: 100%|████

In [32]:
# Compute generalisation losses for each model, for SGD and NGD, so we can compare this to the testing loss

gen_losses = {}
for title in models.keys():
    gen_loss = {}
    for optim in optimizers:
        gen_loss_list = []
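        # Include the final epoch: the loss and history lists have num_epochs + 1 entries (epoch 0 is the initial model)
        for i in range(hp["num_epochs"]+1):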
            if histories[title][optim][i] is None:
                gen_loss_list.append(None)
            else:
                gen_loss_list.append(train_losses[title][optim][i] + histories[title][optim][i]["llc/moving_avg"][-1][-1]/data_args["batch_size"])
        gen_loss[optim] = gen_loss_list
    gen_losses[title] = gen_loss

#### **6. Visualise RLCT / Hessian rank, Hessian eigenspectra, generalisation loss, RLCT convergence**

We display the following final figures:
- RLCT and Hessian rank evolution for each model, for NGD and SGD
- Hessian eigenspectra for SGD and NGD overlaid on the same plot
- Evolution of generalisation loss for each model, compared to validation loss
- RLCT moving average evolution for each model, for NGD and SGD, to check for convergence
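
The generalisation loss compared against validation loss is the training loss plus a complexity penalty of the form LLC / n. A minimal sketch of that calculation, with hypothetical function names and made-up numbers, is shown here. Which `n` is appropriate (the batch size, as in the cell above, or the training-set size) is left as a parameter.

```python
def generalisation_loss(train_loss, llc, n):
    """Training loss plus the complexity penalty llc / n."""
    return train_loss + llc / n


def generalisation_curve(train_losses, llc_estimates, n):
    """Apply the formula per epoch, keeping None where an entry is missing."""
    curve = []
    for loss, llc in zip(train_losses, llc_estimates):
        if loss is None or llc is None:
            curve.append(None)
        else:
            curve.append(generalisation_loss(loss, llc, n))
    return curve


# Made-up values; None marks epochs with no model, as in the swap experiment.
train = [None, 1.33, 1.27]
llcs = [None, 12.8, 25.6]
print(generalisation_curve(train, llcs, n=128))
```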

In [33]:
# Visualise RLCT and Hessian rank data

exp_figures = {}

for title in models.keys():
    exp_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        exp_fig.add_trace(go.Scatter(
            x=epochs,
            y=rlct_estimates[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} RLCT",
            line=dict(color=color),
        ), secondary_y=False)
        if hp["hessian"]:
            exp_fig.add_trace(go.Scatter(
                x=epochs,
                y=hessian_ranks[title][optim][1:],
                mode="lines+markers",
                name=f"{optim} Hessian Rank",
                line=dict(color=color, dash="dot"),
            ), secondary_y=True)
    exp_fig.update_layout(
        title=f"{title} RLCT / Hessian Rank (optional) evolution during training",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    exp_fig.update_yaxes(title_text="RLCT", secondary_y=False)
    exp_fig.update_yaxes(title_text="Hessian rank", secondary_y=True)
    exp_figures[title] = exp_fig
    exp_fig.show()   

In [34]:
# Visualise converged eigenspectra for SGD and NGD for each model

combined_eigenspectra = {}

for title in models.keys():

    combined_eigenspectrum = go.Figure()
    final_eigenspectra = {}
    traces = {}

    for optim in optimizers:
        final_eigenspectra[optim] = eigenspectra_figs[title][optim][-2]
        final_eigenspectra[optim].data[0].name = optim
        combined_eigenspectrum.add_trace(final_eigenspectra[optim].data[0])

    combined_eigenspectrum.update_layout(
        title=f"{title} Hessian eigenspectra at convergence",
        xaxis_title="Eigenvalue",
        yaxis_title="Probability density (log scale)",
        yaxis_type="log",
    )

    combined_eigenspectrum.show()
    combined_eigenspectra[title] = (combined_eigenspectrum)

In [36]:
# Visualise generalisation loss vs. testing loss for each model

train_gen_figs = {}
for title in models.keys():
    train_gen_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        train_gen_fig.add_trace(go.Scatter(
            x=epochs,
            y=val_losses[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} Validation",
            line=dict(color=color),
        ))
        train_gen_fig.add_trace(go.Scatter(
            x=epochs,
            y=gen_losses[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} Generalisation",
            line=dict(color=color, dash="dot"),
        ))
    train_gen_fig.update_layout(
        title=f"{title} validation / generalisation loss",
        xaxis_title="Epoch",
        yaxis_title="Loss",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    train_gen_figs[title] = train_gen_fig
    train_gen_fig.show()

In [37]:
# Check the LLC chains converged for each model

rlct_converge_plots = {}
for title in models.keys():
    rlct_converge_plot = go.Figure()
    for epoch in epochs:
        for optim in optimizers:
            if histories[title][optim][epoch] is not None:
                rlct_converge_plot.add_trace(go.Scatter(
                    y=histories[title][optim][epoch]["llc/moving_avg"][0],
                    name=f"{optim} Epoch {epoch}",
                ))
    rlct_converge_plot.update_layout(
        title=f"Evolution of LLC moving average for each model over epochs for {title}",
        xaxis_title="Draws",
        yaxis_title="RLCT",
        yaxis_type="linear",
        legend_title="Epoch"
    )
    rlct_converge_plots[title] = rlct_converge_plot
    rlct_converge_plot.show()

In [89]:
# Save the results to an HTML file.

figures = []

combined_args = {**hp, **data_args, **devinterp_args}

#summary = pprint.pformat(combined_args, sort_dicts=False)

for fig in loss_figures.values():
    figures.append(fig)
for fig in update_norm_figures.values():
    figures.append(fig)
for fig in exp_figures.values():
    figures.append(fig)
for fig in combined_eigenspectra.values():
    figures.append(fig)
for fig in train_gen_figs.values():
    figures.append(fig)
for fig in rlct_converge_plots.values():
    figures.append(fig)

curr_time = datetime.now().strftime("%Y-%m-%d-%H-%M")

write_figs_to_html(
    figs=figures,
    dest=f"./ngd_sgd/dnn_ngd_sgd_rlct_{curr_time}.html",
    title="Does NGD converge to minima that are 'more complex' i.e. have a higher RLCT?",
    summary=combined_args,
)