### **Does NGD converge to more complex minima according to the RLCT (Real Log Canonical Threshold)?**

This notebook aims to test the above claim, using developmental interpretability methods.

We hypothesise that NGD (natural gradient descent) will consistently converge to minima with a higher RLCT than SGD, because it premultiplies the gradient by the inverse of the Fisher Information Matrix (FIM) before each update.

SLT (singular learning theory) proposes that models converge to singularities, where the Fisher Information Matrix is non-invertible. Near a singularity the determinant of the FIM is close to 0, so its inverse blows up. NGD should therefore "jump away" from the singularity and instead favour more complex, less singular minima.
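
To make this intuition concrete, here is a minimal sketch of a damped natural-gradient step for a two-parameter model. It is not the notebook's `NGD` implementation: the helper `natural_gradient_step`, the diagonal Fisher matrix, and the numbers are assumptions chosen for illustration. It also assumes that the gradient component along the degenerate direction does not vanish at the same rate as the Fisher eigenvalue.

```python
def natural_gradient_step(grad, fisher_diag, lr=0.01, damping=1e-10):
    """Return -lr * (F + damping * I)^{-1} grad for a diagonal Fisher matrix F."""
    return [-lr * g / (f + damping) for g, f in zip(grad, fisher_diag)]


grad = [0.1, 0.1]  # same gradient magnitude in both directions

# Well-conditioned Fisher: the step is comparable to a plain SGD step.
regular_step = natural_gradient_step(grad, fisher_diag=[1.0, 1.0])

# Nearly singular Fisher: the second direction is almost degenerate.
singular_step = natural_gradient_step(grad, fisher_diag=[1.0, 1e-8])

print("regular :", regular_step)
print("singular:", singular_step)
```

With identical gradients, the step along the nearly degenerate direction is several orders of magnitude larger, and only the damping term keeps it finite.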

#### **Methodology**
- Choose between DNN (dense neural network) and CNN (convolutional neural network)
- Vary hidden nodes and hidden layers for DNN (depth and width)
- Vary number of convolutional layers for CNN.
- Estimate Hessian rank using `PyHessian` module.
- Estimate RLCT using `devinterp` library.

#### **Instructions**

To produce your own results, go to `args_ngd_sgd.json`. The following parameters can be adjusted:
- `model_type (str)`  : `ffnn` (dense network with ReLU activations), `dlnn` (dense network without ReLU activations), `cnn` (convolutional network)
- `experiment_type (str)` : `standard` (train models independently using SGD and NGD), `swap` (train to convergence with SGD, then continue with SGD + NGD)

<!-- -->

- `optimizer (str)` : `sgd`, `ngd`, or `both` (performs analysis on both optimizers) - SET TO `both` IF DOING `swap` EXPERIMENT
- `hessian (bool)` : `true`, `false` (does Hessian rank analysis if enabled)

<!-- -->

- `hidden_nodes (List : int)` : e.g. `[32, 32, 64, 64, 128, 128]`, the hidden nodes to use in DNN
- `hidden_layers (List: int)` : e.g. `[1, 1, 2, 2, 3, 3]`, the hidden layers to use in DNN. Must have same length as `hidden_nodes`.
- `hidden_conv_layers (List : int)` : array of number of hidden convolutional layers to use for CNN, e.g. `[1, 2, 3]`.

<!-- -->

- `cut_off_epochs (int)` : the epoch number at which the optimiser is swapped if using `experiment_type = "swap"`
- `num_epochs (int)` : total number of training epochs

<!-- -->

- `sgd_lr (float)` : sgd learning rate for training
- `ngd_lr (float)` : ngd learning rate for training

<!-- -->

- `alpha (float)` : smoothing constant for estimation of FIM for NGD
- `eta (float)` : weight given to the current batch in the running estimate $ F_{n} = \eta S_n + (1 - \eta) F_{n-1} $, where $ S_n $ is the estimated Fisher matrix from the current batch.
- `epsilon (float)` : added to $ F $ to stop it from becoming singular
- `delta (float)` : constant preventing $ F $ from becoming singular

<!-- -->

- `momentum (float)` : momentum used in SGD and NGD
- `nesterov (bool)` : enable Nesterov momentum for NGD / SGD

<!-- -->

- `seed (int)` : the random seed used for `torch.manual_seed()`

<!-- -->

- `batch_size (int)` : batch size for train / validation dataloaders
- `num_workers (int)` : number of DataLoader worker processes (CPU subprocesses, not GPU workers) used for data loading (keep this at around 6, vary depending on your hardware)
- `dataset (str)` : `mnist`, `cifar10` (dataset to use for training)
- `num_hessian_batches (int)` : number of batches used for estimation of Hessian

<!-- -->

- `sampler (str)` : `sgld`, `sgnht` (optimiser to use for RLCT estimation)
- `num_chains (int)` : number of chains to use in RLCT estimation (higher leads to more accurate RLCT estimate)
- `num_draws (int)` : number of optimizer steps in RLCT estimation (should be high enough such that chain RLCT converges - check convergence plot)
- `localization (float)` : higher localization more strongly restricts optimizer to neighbourhood of model weights (stops RLCT estimator from going straight to minima)
- `sampler_lr (float)` : learning rate for sampler for RLCT estimation
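
For intuition about what the RLCT estimate measures, the sketch below evaluates the estimator commonly used for the local learning coefficient, $ \hat{\lambda} = n\beta \, (\mathbb{E}_w[L(w)] - L(w^*)) $, on two one-dimensional toy losses. It uses numerical integration instead of SGLD sampling and is not `devinterp` code. The function `llc_by_quadrature`, the integration grid, and the value $ n\beta = 100 $ are assumptions made for illustration, and the localization term is omitted.

```python
import math


def llc_by_quadrature(loss, n_beta=100.0, half_width=3.0, num_points=200_001):
    """Estimate n_beta * (E[loss(w)] - loss(0)) with p(w) proportional to exp(-n_beta * loss(w))."""
    step = 2 * half_width / (num_points - 1)
    total_weight = 0.0
    weighted_loss = 0.0
    for i in range(num_points):
        w = -half_width + i * step
        weight = math.exp(-n_beta * loss(w))  # unnormalised tempered posterior
        total_weight += weight
        weighted_loss += weight * loss(w)
    expected_loss = weighted_loss / total_weight
    return n_beta * (expected_loss - loss(0.0))


llc_regular = llc_by_quadrature(lambda w: w**2)   # non-degenerate minimum
llc_singular = llc_by_quadrature(lambda w: w**4)  # degenerate minimum

print(f"w^2: {llc_regular:.3f}, w^4: {llc_singular:.3f}")
```

The regular minimum $ w^2 $ gives a value close to 0.5 and the flatter minimum $ w^4 $ gives a value close to 0.25, which matches the idea that more singular minima have a lower RLCT.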

You can run the notebook from the terminal using the following command:
```bash
jupyter nbconvert --to notebook --execute --inplace ./experiments/eval_ngd_sgd.ipynb
```

#### **0. Import libraries**

The `NGD` module is used for implementing Natural Gradient Descent efficiently. It approximates the Fisher Information Matrix to do this.
The `devinterp` library is used for estimation of the LLC (local learning coefficient).
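
As a rough illustration of the running Fisher estimate $ F_{n} = \eta S_n + (1 - \eta) F_{n-1} $ described under `eta`, the sketch below maintains a diagonal estimate built from squared per-sample gradients. The diagonal approximation, the helper name `update_fisher_diag`, and the toy gradients are assumptions for illustration; the `NGD` module may approximate the matrix differently.

```python
def update_fisher_diag(prev_fisher, per_sample_grads, eta=0.9):
    """One step of F_n = eta * S_n + (1 - eta) * F_{n-1} for a diagonal Fisher estimate.

    S_n is the mean of the squared per-sample gradients in the current batch.
    """
    batch_size = len(per_sample_grads)
    num_params = len(prev_fisher)
    batch_fisher = [
        sum(g[j] ** 2 for g in per_sample_grads) / batch_size
        for j in range(num_params)
    ]
    return [eta * s + (1 - eta) * f for s, f in zip(batch_fisher, prev_fisher)]


fisher = [0.0, 0.0]  # F_0
batches = [
    [[1.0, 0.0], [1.0, 0.0]],  # batch 1: only the first parameter receives gradient
    [[2.0, 0.0], [0.0, 0.0]],  # batch 2
]
for grads in batches:
    fisher = update_fisher_diag(fisher, grads, eta=0.9)

print(fisher)
```

The second entry stays at zero because that parameter never receives gradient. This is the kind of degenerate direction that `epsilon` and `delta` are described as guarding against.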

In [22]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD, SGNHT
from devinterp.slt import sample, OnlineLLCEstimator
from devinterp.slt.wbic import OnlineWBICEstimator
from devinterp.slt.mala import MalaAcceptanceRate
from devinterp.utils import plot_trace, optimal_temperature

from PyHessian.pyhessian import *
from PyHessian.density_plot import *
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

from utils_general import *
from utils_hessian_fim import *
from architectures import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.colors
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

device = "cuda" if t.cuda.is_available() else "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline


Device in use: cuda
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### **1. Import data for training**

We import our dataset for training using a helper function, `build_data_loaders`, which allows us to choose between MNIST and CIFAR10.

We specify three dictionaries, `hp` = hyperparameters, `data_args` = arguments for dataloading, `devinterp_args` = arguments for LLC and WBIC estimation.

In [23]:
# Load experiment args

with open("args_ngd_sgd.json", "r") as file:
    args = json.load(file) 

hp, data_args, devinterp_args = args

print("======HYPERPARAMETERS======")
pprint.pprint(hp, sort_dicts=False)

print("======DATA ARGS======")
pprint.pprint(data_args, sort_dicts=False)

print("======DEVINTERP ARGS======")
pprint.pprint(devinterp_args, sort_dicts=False)

{'model_type': 'ffnn',
 'experiment_type': 'standard',
 'optimizer': 'both',
 'hessian': False,
 'hidden_nodes': [64, 128, 256, 64, 128, 256],
 'hidden_layers': [1, 1, 1, 2, 2, 2],
 'hidden_conv_layers': [0],
 'cut_off_epochs': 20,
 'num_epochs': 20,
 'sgd_lr': 0.01,
 'ngd_lr': 0.01,
 'alpha': 0.01,
 'eta': 0.9,
 'epsilon': 1e-10,
 'delta': 0.0005,
 'momentum': 0.9,
 'nesterov': True,
 'seed': 5}
{'batch_size': 128,
 'num_workers': 64,
 'dataset': 'mnist',
 'num_hessian_batches': 1}
{'sampler': 'sgld',
 'num_chains': 1,
 'num_draws': 1000,
 'localization': 100.0,
 'sampler_lr': 0.0001}
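
The constraints stated in this notebook can be checked before training starts. The following is a sketch only: `check_args` is a hypothetical helper that is not part of the notebook. It encodes just the documented rules: matching list lengths, `optimizer = both` for `swap`, and `cut_off_epochs` equal to `num_epochs` for `standard`.

```python
def check_args(hp):
    """Return a list of problems found in the hyperparameter dictionary."""
    problems = []
    if hp["model_type"] in ("ffnn", "dlnn"):
        if len(hp["hidden_nodes"]) != len(hp["hidden_layers"]):
            problems.append("hidden_nodes and hidden_layers must have the same length")
    if hp["experiment_type"] == "swap" and hp["optimizer"] != "both":
        problems.append("optimizer must be 'both' for the swap experiment")
    if hp["experiment_type"] == "standard" and hp["cut_off_epochs"] != hp["num_epochs"]:
        problems.append("cut_off_epochs should equal num_epochs for the standard experiment")
    return problems


# Values copied from the configuration printed above.
example_hp = {
    "model_type": "ffnn",
    "experiment_type": "standard",
    "optimizer": "both",
    "hidden_nodes": [64, 128, 256, 64, 128, 256],
    "hidden_layers": [1, 1, 1, 2, 2, 2],
    "cut_off_epochs": 20,
    "num_epochs": 20,
}
print(check_args(example_hp))
```

An empty list means none of the documented constraints is violated.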


In [24]:
# Set random seed for weight initialisation (for reproducibility of results and to ensure NGD/SGD start from same point in loss landscape)

t.manual_seed(hp["seed"])

<torch._C.Generator at 0x7fb4ea627fb0>

In [25]:
train_loader, test_loader = build_data_loaders(data_args)


This DataLoader will create 64 worker processes in total. Our suggested max number of worker in current system is 28, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



#### **2. Training models**

Choose to train either a DNN or CNN.

This code produces a dictionary `models` where each key describes the model itself, e.g. "FFNN 2 HL, 256 HN".

Each value in this dictionary, iterated over as `model`, is another dictionary keyed by optimizer. Each of its values is a list that initially contains only the model with its initial weights. During training, a copy of the model is appended to this list after every epoch, so the list records the model history.

```
models={
    "architecture1":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    },
    "architecture2":{
        "optim1":[epoch0,epoch1],
        "optim2":[epoch0,epoch1],
    }
}
```
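
Given that layout, a checkpoint is addressed by architecture title, optimizer name, and epoch index. The sketch below uses placeholder strings in place of real model objects, and the name `example_models` is chosen only to avoid clashing with the notebook's `models` variable.

```python
# Placeholder strings stand in for the deep-copied model objects.
example_models = {
    "FFNN 1 HL, 64 HN": {
        "sgd": ["sgd_epoch0", "sgd_epoch1", "sgd_epoch2"],
        "ngd": ["ngd_epoch0", "ngd_epoch1", "ngd_epoch2"],
    },
}

for title, history_by_optim in example_models.items():
    for optim, history in history_by_optim.items():
        # Index 0 is the untrained model, so len(history) - 1 epochs were trained.
        num_trained_epochs = len(history) - 1
        print(f"{title} | {optim}: {num_trained_epochs} epochs, final = {history[-1]}")

# Final-epoch checkpoint of the NGD run for one architecture.
final_ngd = example_models["FFNN 1 HL, 64 HN"]["ngd"][-1]
```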

In [26]:
# Initialise models dependent on arguments

models = {}
optimizers = ["sgd", "ngd"] if hp["optimizer"] == "both" else [hp["optimizer"]]

if hp["model_type"] == "ffnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"FFNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist":
            model = NeuralNet(relu=True,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=True,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "dlnn":
    hidden_nodes = hp["hidden_nodes"]
    hidden_layers = hp["hidden_layers"]
    for hidden_node, hidden_layer in zip(hidden_nodes, hidden_layers):
        title = f"DLNN {hidden_layer} HL, {hidden_node} HN"
        if data_args["dataset"] == "mnist":
            model = NeuralNet(relu=False,input_size=28*28,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        elif data_args["dataset"] == "cifar10":
            model = NeuralNet(relu=False,input_size=32*32*3,hidden_layers=hidden_layer, hidden_nodes=hidden_node).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}
elif hp["model_type"] == "cnn":
    hidden_conv_layers = hp["hidden_conv_layers"]
    for hidden_conv_layer in hidden_conv_layers:
        title = f"CNN {hidden_conv_layer} HCL"
        if data_args["dataset"] == "mnist":
            model = LeNet(dataset="mnist",extra_layers=hidden_conv_layer).to(device)
        elif data_args["dataset"] == "cifar10":
            model = LeNet(dataset="cifar10",extra_layers=hidden_conv_layer).to(device)
        models[title] = {optim : [copy.deepcopy(model)] for optim in optimizers}


For experiment 1 (`standard`), make sure `cut_off_epochs` is equal to `num_epochs`, because the training loop below runs for `cut_off_epochs` epochs.

In [27]:
# EXPERIMENT 1: Train models independently for num_epochs with SGD and NGD

if hp["experiment_type"] == "standard":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        # Store list for SGD losses, NGD losses, for train and val
        model_train_losses = {optim : [] for optim in optimizers}
        model_val_losses = {optim : [] for optim in optimizers}
        model_update_norms = {optim : [] for optim in optimizers}
        model_all_update_norms = {optim : [] for optim in optimizers}

        for optim in optimizers:
            state = copy.deepcopy(model[optim][0])
            if optim == "sgd":
                optimizer = t.optim.SGD(
                    params=state.parameters(),
                    lr=hp["sgd_lr"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            elif optim == "ngd":
                optimizer = NGD(
                    params=state.parameters(),
                    lr=hp["ngd_lr"],
                    alpha=hp["alpha"],
                    eta=hp["eta"],
                    epsilon=hp["epsilon"],
                    delta=hp["delta"],
                    momentum=hp["momentum"],
                    nesterov=hp["nesterov"],
                )
            print(f"TRAINING MODEL: {title} | OPTIMISER: {optim}")
            initial_train_loss = evaluate(state, train_loader, criterion, device)
            initial_val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses[optim].append(initial_train_loss)
            model_val_losses[optim].append(initial_val_loss)
            #only train till cutoff epochs
            for epoch in range(1, hp["cut_off_epochs"]+1):
                train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
                val_loss = evaluate(state, test_loader, criterion, device)
                model_train_losses[optim].append(train_loss)
                model_update_norms[optim].append(update_norm)
                model_val_losses[optim].append(val_loss)
                #note that epoch_update norms is a list with the update_norm of each minibatch, this will concat the list to previous list
                model_all_update_norms[optim] += epoch_update_norms
                model[optim].append(copy.deepcopy(state))
                print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


TRAINING MODEL: FFNN 1 HL, 64 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:02<00:00, 216.01it/s]


Epoch 1/20: train_loss=0.3746, val_loss=0.2044


100%|██████████| 469/469 [00:02<00:00, 224.36it/s]


Epoch 2/20: train_loss=0.1752, val_loss=0.1405


100%|██████████| 469/469 [00:02<00:00, 225.83it/s]


Epoch 3/20: train_loss=0.1285, val_loss=0.1173


100%|██████████| 469/469 [00:02<00:00, 230.03it/s]


Epoch 4/20: train_loss=0.1036, val_loss=0.1039


100%|██████████| 469/469 [00:02<00:00, 222.83it/s]


Epoch 5/20: train_loss=0.0876, val_loss=0.0967


100%|██████████| 469/469 [00:02<00:00, 222.46it/s]


Epoch 6/20: train_loss=0.0760, val_loss=0.0914


100%|██████████| 469/469 [00:02<00:00, 223.54it/s]


Epoch 7/20: train_loss=0.0666, val_loss=0.0849


100%|██████████| 469/469 [00:02<00:00, 226.44it/s]


Epoch 8/20: train_loss=0.0601, val_loss=0.0827


100%|██████████| 469/469 [00:02<00:00, 226.81it/s]


Epoch 9/20: train_loss=0.0532, val_loss=0.0803


100%|██████████| 469/469 [00:02<00:00, 234.46it/s]


Epoch 10/20: train_loss=0.0481, val_loss=0.0809


100%|██████████| 469/469 [00:02<00:00, 226.93it/s]


Epoch 11/20: train_loss=0.0444, val_loss=0.0789


100%|██████████| 469/469 [00:02<00:00, 227.47it/s]


Epoch 12/20: train_loss=0.0402, val_loss=0.0766


100%|██████████| 469/469 [00:02<00:00, 227.84it/s]


Epoch 13/20: train_loss=0.0370, val_loss=0.0791


100%|██████████| 469/469 [00:02<00:00, 229.11it/s]


Epoch 14/20: train_loss=0.0337, val_loss=0.0826


100%|██████████| 469/469 [00:02<00:00, 233.62it/s]


Epoch 15/20: train_loss=0.0311, val_loss=0.0800


100%|██████████| 469/469 [00:02<00:00, 231.35it/s]


Epoch 16/20: train_loss=0.0282, val_loss=0.0804


100%|██████████| 469/469 [00:02<00:00, 223.67it/s]


Epoch 17/20: train_loss=0.0261, val_loss=0.0751


100%|██████████| 469/469 [00:02<00:00, 225.87it/s]


Epoch 18/20: train_loss=0.0241, val_loss=0.0770


100%|██████████| 469/469 [00:02<00:00, 206.28it/s]


Epoch 19/20: train_loss=0.0216, val_loss=0.0779


100%|██████████| 469/469 [00:02<00:00, 198.45it/s]


Epoch 20/20: train_loss=0.0203, val_loss=0.0760
TRAINING MODEL: FFNN 1 HL, 64 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:04<00:00, 109.10it/s]


Epoch 1/20: train_loss=0.7574, val_loss=0.2148


100%|██████████| 469/469 [00:04<00:00, 109.10it/s]


Epoch 2/20: train_loss=0.1848, val_loss=0.1654


100%|██████████| 469/469 [00:04<00:00, 116.84it/s]


Epoch 3/20: train_loss=0.1488, val_loss=0.1485


100%|██████████| 469/469 [00:04<00:00, 114.87it/s]


Epoch 4/20: train_loss=0.1256, val_loss=0.1360


100%|██████████| 469/469 [00:04<00:00, 115.15it/s]


Epoch 5/20: train_loss=0.1074, val_loss=0.1259


100%|██████████| 469/469 [00:04<00:00, 111.10it/s]


Epoch 6/20: train_loss=0.0932, val_loss=0.1181


100%|██████████| 469/469 [00:04<00:00, 113.12it/s]


Epoch 7/20: train_loss=0.0803, val_loss=0.1128


100%|██████████| 469/469 [00:04<00:00, 109.33it/s]


Epoch 8/20: train_loss=0.0727, val_loss=0.1116


100%|██████████| 469/469 [00:04<00:00, 115.18it/s]


Epoch 9/20: train_loss=0.0690, val_loss=0.1103


100%|██████████| 469/469 [00:04<00:00, 115.86it/s]


Epoch 10/20: train_loss=0.0645, val_loss=0.1079


100%|██████████| 469/469 [00:04<00:00, 110.79it/s]


Epoch 11/20: train_loss=0.0614, val_loss=0.1162


100%|██████████| 469/469 [00:04<00:00, 112.34it/s]


Epoch 12/20: train_loss=0.0586, val_loss=0.1066


100%|██████████| 469/469 [00:04<00:00, 110.69it/s]


Epoch 13/20: train_loss=0.0564, val_loss=0.1154


100%|██████████| 469/469 [00:04<00:00, 113.35it/s]


Epoch 14/20: train_loss=0.0555, val_loss=0.1233


100%|██████████| 469/469 [00:04<00:00, 111.09it/s]


Epoch 15/20: train_loss=0.0549, val_loss=0.1232


100%|██████████| 469/469 [00:04<00:00, 111.95it/s]


Epoch 16/20: train_loss=0.0511, val_loss=0.1319


100%|██████████| 469/469 [00:04<00:00, 116.16it/s]


Epoch 17/20: train_loss=0.0506, val_loss=0.1305


100%|██████████| 469/469 [00:04<00:00, 114.43it/s]


Epoch 18/20: train_loss=0.0499, val_loss=0.1339


100%|██████████| 469/469 [00:04<00:00, 109.00it/s]


Epoch 19/20: train_loss=0.0514, val_loss=0.1318


100%|██████████| 469/469 [00:04<00:00, 115.70it/s]


Epoch 20/20: train_loss=0.0511, val_loss=0.1408
TRAINING MODEL: FFNN 1 HL, 128 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:02<00:00, 221.35it/s]


Epoch 1/20: train_loss=0.3537, val_loss=0.2003


100%|██████████| 469/469 [00:02<00:00, 227.67it/s]


Epoch 2/20: train_loss=0.1653, val_loss=0.1368


100%|██████████| 469/469 [00:02<00:00, 226.35it/s]


Epoch 3/20: train_loss=0.1185, val_loss=0.1078


100%|██████████| 469/469 [00:02<00:00, 227.92it/s]


Epoch 4/20: train_loss=0.0932, val_loss=0.0930


100%|██████████| 469/469 [00:02<00:00, 220.20it/s]


Epoch 5/20: train_loss=0.0765, val_loss=0.0841


100%|██████████| 469/469 [00:02<00:00, 225.64it/s]


Epoch 6/20: train_loss=0.0645, val_loss=0.0793


100%|██████████| 469/469 [00:02<00:00, 233.97it/s]


Epoch 7/20: train_loss=0.0551, val_loss=0.0768


100%|██████████| 469/469 [00:02<00:00, 231.11it/s]


Epoch 8/20: train_loss=0.0477, val_loss=0.0741


100%|██████████| 469/469 [00:02<00:00, 226.81it/s]


Epoch 9/20: train_loss=0.0417, val_loss=0.0694


100%|██████████| 469/469 [00:02<00:00, 220.38it/s]


Epoch 10/20: train_loss=0.0369, val_loss=0.0667


100%|██████████| 469/469 [00:02<00:00, 226.30it/s]


Epoch 11/20: train_loss=0.0323, val_loss=0.0661


100%|██████████| 469/469 [00:02<00:00, 226.33it/s]


Epoch 12/20: train_loss=0.0286, val_loss=0.0664


100%|██████████| 469/469 [00:02<00:00, 223.25it/s]


Epoch 13/20: train_loss=0.0254, val_loss=0.0640


100%|██████████| 469/469 [00:02<00:00, 228.32it/s]


Epoch 14/20: train_loss=0.0226, val_loss=0.0658


100%|██████████| 469/469 [00:02<00:00, 218.75it/s]


Epoch 15/20: train_loss=0.0200, val_loss=0.0645


100%|██████████| 469/469 [00:02<00:00, 222.36it/s]


Epoch 16/20: train_loss=0.0177, val_loss=0.0635


100%|██████████| 469/469 [00:02<00:00, 227.66it/s]


Epoch 17/20: train_loss=0.0159, val_loss=0.0647


100%|██████████| 469/469 [00:02<00:00, 231.88it/s]


Epoch 18/20: train_loss=0.0143, val_loss=0.0635


100%|██████████| 469/469 [00:02<00:00, 223.06it/s]


Epoch 19/20: train_loss=0.0129, val_loss=0.0631


100%|██████████| 469/469 [00:02<00:00, 219.28it/s]


Epoch 20/20: train_loss=0.0115, val_loss=0.0625
TRAINING MODEL: FFNN 1 HL, 128 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:04<00:00, 109.35it/s]


Epoch 1/20: train_loss=0.7152, val_loss=0.2002


100%|██████████| 469/469 [00:04<00:00, 109.77it/s]


Epoch 2/20: train_loss=0.1565, val_loss=0.1377


100%|██████████| 469/469 [00:04<00:00, 109.82it/s]


Epoch 3/20: train_loss=0.1136, val_loss=0.1161


100%|██████████| 469/469 [00:04<00:00, 106.91it/s]


Epoch 4/20: train_loss=0.0930, val_loss=0.1053


100%|██████████| 469/469 [00:04<00:00, 111.37it/s]


Epoch 5/20: train_loss=0.0768, val_loss=0.0961


100%|██████████| 469/469 [00:04<00:00, 111.32it/s]


Epoch 6/20: train_loss=0.0650, val_loss=0.0890


100%|██████████| 469/469 [00:04<00:00, 114.13it/s]


Epoch 7/20: train_loss=0.0560, val_loss=0.0885


100%|██████████| 469/469 [00:04<00:00, 109.94it/s]


Epoch 8/20: train_loss=0.0503, val_loss=0.0836


100%|██████████| 469/469 [00:04<00:00, 109.96it/s]


Epoch 9/20: train_loss=0.0448, val_loss=0.0834


100%|██████████| 469/469 [00:04<00:00, 112.16it/s]


Epoch 10/20: train_loss=0.0402, val_loss=0.0855


100%|██████████| 469/469 [00:04<00:00, 110.30it/s]


Epoch 11/20: train_loss=0.0373, val_loss=0.0848


100%|██████████| 469/469 [00:04<00:00, 107.30it/s]


Epoch 12/20: train_loss=0.0348, val_loss=0.0840


100%|██████████| 469/469 [00:04<00:00, 109.13it/s]


Epoch 13/20: train_loss=0.0325, val_loss=0.0819


100%|██████████| 469/469 [00:04<00:00, 111.45it/s]


Epoch 14/20: train_loss=0.0294, val_loss=0.0846


100%|██████████| 469/469 [00:04<00:00, 115.79it/s]


Epoch 15/20: train_loss=0.0284, val_loss=0.0826


100%|██████████| 469/469 [00:04<00:00, 110.92it/s]


Epoch 16/20: train_loss=0.0259, val_loss=0.0809


100%|██████████| 469/469 [00:04<00:00, 106.78it/s]


Epoch 17/20: train_loss=0.0225, val_loss=0.0827


100%|██████████| 469/469 [00:04<00:00, 114.64it/s]


Epoch 18/20: train_loss=0.0209, val_loss=0.0814


100%|██████████| 469/469 [00:04<00:00, 110.73it/s]


Epoch 19/20: train_loss=0.0192, val_loss=0.0825


100%|██████████| 469/469 [00:04<00:00, 112.52it/s]


Epoch 20/20: train_loss=0.0179, val_loss=0.0802
TRAINING MODEL: FFNN 1 HL, 256 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:01<00:00, 238.55it/s]


Epoch 1/20: train_loss=0.3416, val_loss=0.1761


100%|██████████| 469/469 [00:02<00:00, 229.03it/s]


Epoch 2/20: train_loss=0.1540, val_loss=0.1253


100%|██████████| 469/469 [00:02<00:00, 224.15it/s]


Epoch 3/20: train_loss=0.1090, val_loss=0.0981


100%|██████████| 469/469 [00:02<00:00, 230.65it/s]


Epoch 4/20: train_loss=0.0838, val_loss=0.0874


100%|██████████| 469/469 [00:02<00:00, 228.86it/s]


Epoch 5/20: train_loss=0.0681, val_loss=0.0786


100%|██████████| 469/469 [00:02<00:00, 228.50it/s]


Epoch 6/20: train_loss=0.0562, val_loss=0.0718


100%|██████████| 469/469 [00:02<00:00, 216.92it/s]


Epoch 7/20: train_loss=0.0473, val_loss=0.0712


100%|██████████| 469/469 [00:02<00:00, 225.24it/s]


Epoch 8/20: train_loss=0.0408, val_loss=0.0642


100%|██████████| 469/469 [00:02<00:00, 226.19it/s]


Epoch 9/20: train_loss=0.0348, val_loss=0.0647


100%|██████████| 469/469 [00:02<00:00, 220.97it/s]


Epoch 10/20: train_loss=0.0298, val_loss=0.0610


100%|██████████| 469/469 [00:02<00:00, 226.71it/s]


Epoch 11/20: train_loss=0.0256, val_loss=0.0604


100%|██████████| 469/469 [00:02<00:00, 227.63it/s]


Epoch 12/20: train_loss=0.0223, val_loss=0.0604


100%|██████████| 469/469 [00:02<00:00, 223.44it/s]


Epoch 13/20: train_loss=0.0196, val_loss=0.0587


100%|██████████| 469/469 [00:02<00:00, 220.05it/s]


Epoch 14/20: train_loss=0.0172, val_loss=0.0578


100%|██████████| 469/469 [00:02<00:00, 226.78it/s]


Epoch 15/20: train_loss=0.0153, val_loss=0.0581


100%|██████████| 469/469 [00:02<00:00, 222.80it/s]


Epoch 16/20: train_loss=0.0135, val_loss=0.0576


100%|██████████| 469/469 [00:02<00:00, 231.85it/s]


Epoch 17/20: train_loss=0.0120, val_loss=0.0588


100%|██████████| 469/469 [00:02<00:00, 230.22it/s]


Epoch 18/20: train_loss=0.0107, val_loss=0.0592


100%|██████████| 469/469 [00:02<00:00, 227.45it/s]


Epoch 19/20: train_loss=0.0096, val_loss=0.0579


100%|██████████| 469/469 [00:02<00:00, 232.03it/s]


Epoch 20/20: train_loss=0.0086, val_loss=0.0590
TRAINING MODEL: FFNN 1 HL, 256 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:04<00:00, 107.84it/s]


Epoch 1/20: train_loss=0.8522, val_loss=0.2145


100%|██████████| 469/469 [00:04<00:00, 109.83it/s]


Epoch 2/20: train_loss=0.1681, val_loss=0.1419


100%|██████████| 469/469 [00:04<00:00, 108.56it/s]


Epoch 3/20: train_loss=0.1173, val_loss=0.1175


100%|██████████| 469/469 [00:04<00:00, 115.01it/s]


Epoch 4/20: train_loss=0.0933, val_loss=0.1048


100%|██████████| 469/469 [00:04<00:00, 107.64it/s]


Epoch 5/20: train_loss=0.0783, val_loss=0.0928


100%|██████████| 469/469 [00:04<00:00, 109.85it/s]


Epoch 6/20: train_loss=0.0666, val_loss=0.0865


100%|██████████| 469/469 [00:04<00:00, 111.38it/s]


Epoch 7/20: train_loss=0.0572, val_loss=0.0821


100%|██████████| 469/469 [00:04<00:00, 107.09it/s]


Epoch 8/20: train_loss=0.0494, val_loss=0.0770


100%|██████████| 469/469 [00:04<00:00, 112.34it/s]


Epoch 9/20: train_loss=0.0442, val_loss=0.0748


100%|██████████| 469/469 [00:04<00:00, 107.94it/s]


Epoch 10/20: train_loss=0.0394, val_loss=0.0725


100%|██████████| 469/469 [00:04<00:00, 106.78it/s]


Epoch 11/20: train_loss=0.0360, val_loss=0.0713


100%|██████████| 469/469 [00:04<00:00, 113.54it/s]


Epoch 12/20: train_loss=0.0324, val_loss=0.0680


100%|██████████| 469/469 [00:04<00:00, 109.49it/s]


Epoch 13/20: train_loss=0.0290, val_loss=0.0669


100%|██████████| 469/469 [00:04<00:00, 108.37it/s]


Epoch 14/20: train_loss=0.0269, val_loss=0.0654


100%|██████████| 469/469 [00:04<00:00, 108.98it/s]


Epoch 15/20: train_loss=0.0247, val_loss=0.0652


100%|██████████| 469/469 [00:04<00:00, 109.66it/s]


Epoch 16/20: train_loss=0.0231, val_loss=0.0629


100%|██████████| 469/469 [00:04<00:00, 109.76it/s]


Epoch 17/20: train_loss=0.0206, val_loss=0.0626


100%|██████████| 469/469 [00:04<00:00, 111.07it/s]


Epoch 18/20: train_loss=0.0194, val_loss=0.0623


100%|██████████| 469/469 [00:04<00:00, 112.10it/s]


Epoch 19/20: train_loss=0.0183, val_loss=0.0611


100%|██████████| 469/469 [00:04<00:00, 111.74it/s]


Epoch 20/20: train_loss=0.0168, val_loss=0.0611
TRAINING MODEL: FFNN 2 HL, 64 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:02<00:00, 196.56it/s]


Epoch 1/20: train_loss=0.4375, val_loss=0.1994


100%|██████████| 469/469 [00:02<00:00, 190.21it/s]


Epoch 2/20: train_loss=0.1675, val_loss=0.1345


100%|██████████| 469/469 [00:02<00:00, 194.33it/s]


Epoch 3/20: train_loss=0.1230, val_loss=0.1149


100%|██████████| 469/469 [00:02<00:00, 192.44it/s]


Epoch 4/20: train_loss=0.0960, val_loss=0.0996


100%|██████████| 469/469 [00:02<00:00, 199.60it/s]


Epoch 5/20: train_loss=0.0791, val_loss=0.1070


100%|██████████| 469/469 [00:02<00:00, 195.96it/s]


Epoch 6/20: train_loss=0.0662, val_loss=0.0859


100%|██████████| 469/469 [00:02<00:00, 188.40it/s]


Epoch 7/20: train_loss=0.0560, val_loss=0.0822


100%|██████████| 469/469 [00:02<00:00, 192.44it/s]


Epoch 8/20: train_loss=0.0480, val_loss=0.0859


100%|██████████| 469/469 [00:02<00:00, 195.08it/s]


Epoch 9/20: train_loss=0.0416, val_loss=0.0800


100%|██████████| 469/469 [00:02<00:00, 193.88it/s]


Epoch 10/20: train_loss=0.0374, val_loss=0.0893


100%|██████████| 469/469 [00:02<00:00, 196.30it/s]


Epoch 11/20: train_loss=0.0318, val_loss=0.0823


100%|██████████| 469/469 [00:02<00:00, 192.72it/s]


Epoch 12/20: train_loss=0.0270, val_loss=0.0805


100%|██████████| 469/469 [00:02<00:00, 193.31it/s]


Epoch 13/20: train_loss=0.0227, val_loss=0.0796


100%|██████████| 469/469 [00:02<00:00, 194.20it/s]


Epoch 14/20: train_loss=0.0209, val_loss=0.0880


100%|██████████| 469/469 [00:02<00:00, 192.55it/s]


Epoch 15/20: train_loss=0.0176, val_loss=0.0900


100%|██████████| 469/469 [00:02<00:00, 197.36it/s]


Epoch 16/20: train_loss=0.0151, val_loss=0.0834


100%|██████████| 469/469 [00:02<00:00, 193.51it/s]


Epoch 17/20: train_loss=0.0141, val_loss=0.0894


100%|██████████| 469/469 [00:02<00:00, 194.68it/s]


Epoch 18/20: train_loss=0.0112, val_loss=0.0950


100%|██████████| 469/469 [00:02<00:00, 193.45it/s]


Epoch 19/20: train_loss=0.0093, val_loss=0.0914


100%|██████████| 469/469 [00:02<00:00, 192.57it/s]


Epoch 20/20: train_loss=0.0072, val_loss=0.0915
TRAINING MODEL: FFNN 2 HL, 64 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:05<00:00, 83.70it/s]


Epoch 1/20: train_loss=1.3747, val_loss=0.3236


100%|██████████| 469/469 [00:05<00:00, 87.76it/s] 


Epoch 2/20: train_loss=0.2093, val_loss=0.1553


100%|██████████| 469/469 [00:05<00:00, 84.95it/s]


Epoch 3/20: train_loss=0.1360, val_loss=0.1351


100%|██████████| 469/469 [00:05<00:00, 85.06it/s]


Epoch 4/20: train_loss=0.1046, val_loss=0.1127


100%|██████████| 469/469 [00:05<00:00, 84.23it/s]


Epoch 5/20: train_loss=0.0840, val_loss=0.1045


100%|██████████| 469/469 [00:05<00:00, 84.28it/s]


Epoch 6/20: train_loss=0.0693, val_loss=0.1017


100%|██████████| 469/469 [00:05<00:00, 85.05it/s]


Epoch 7/20: train_loss=0.0604, val_loss=0.0946


100%|██████████| 469/469 [00:05<00:00, 84.93it/s]


Epoch 8/20: train_loss=0.0546, val_loss=0.0995


100%|██████████| 469/469 [00:05<00:00, 83.67it/s]


Epoch 9/20: train_loss=0.0501, val_loss=0.1036


100%|██████████| 469/469 [00:05<00:00, 84.18it/s]


Epoch 10/20: train_loss=0.0476, val_loss=0.1022


100%|██████████| 469/469 [00:05<00:00, 83.47it/s]


Epoch 11/20: train_loss=0.0446, val_loss=0.1056


100%|██████████| 469/469 [00:05<00:00, 85.49it/s]


Epoch 12/20: train_loss=0.0434, val_loss=0.1121


100%|██████████| 469/469 [00:05<00:00, 84.86it/s]


Epoch 13/20: train_loss=0.0440, val_loss=0.1176


100%|██████████| 469/469 [00:05<00:00, 87.06it/s]


Epoch 14/20: train_loss=0.0480, val_loss=0.1218


100%|██████████| 469/469 [00:05<00:00, 85.87it/s]


Epoch 15/20: train_loss=0.0475, val_loss=0.1245


100%|██████████| 469/469 [00:05<00:00, 84.64it/s]


Epoch 16/20: train_loss=0.0453, val_loss=0.1295


100%|██████████| 469/469 [00:05<00:00, 85.39it/s]


Epoch 17/20: train_loss=0.0500, val_loss=0.1364


100%|██████████| 469/469 [00:05<00:00, 86.86it/s]


Epoch 18/20: train_loss=0.0492, val_loss=0.1347


100%|██████████| 469/469 [00:05<00:00, 87.38it/s]


Epoch 19/20: train_loss=0.0477, val_loss=0.1300


100%|██████████| 469/469 [00:05<00:00, 85.44it/s]


Epoch 20/20: train_loss=0.0437, val_loss=0.1276
TRAINING MODEL: FFNN 2 HL, 128 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:02<00:00, 195.42it/s]


Epoch 1/20: train_loss=0.4230, val_loss=0.1868


100%|██████████| 469/469 [00:02<00:00, 195.88it/s]


Epoch 2/20: train_loss=0.1576, val_loss=0.1233


100%|██████████| 469/469 [00:02<00:00, 205.64it/s]


Epoch 3/20: train_loss=0.1073, val_loss=0.1048


100%|██████████| 469/469 [00:02<00:00, 193.91it/s]


Epoch 4/20: train_loss=0.0822, val_loss=0.0887


100%|██████████| 469/469 [00:02<00:00, 200.43it/s]


Epoch 5/20: train_loss=0.0650, val_loss=0.0824


100%|██████████| 469/469 [00:02<00:00, 191.79it/s]


Epoch 6/20: train_loss=0.0530, val_loss=0.0746


100%|██████████| 469/469 [00:02<00:00, 200.18it/s]


Epoch 7/20: train_loss=0.0429, val_loss=0.0729


100%|██████████| 469/469 [00:02<00:00, 198.30it/s]


Epoch 8/20: train_loss=0.0352, val_loss=0.0761


100%|██████████| 469/469 [00:02<00:00, 192.65it/s]


Epoch 9/20: train_loss=0.0284, val_loss=0.0762


100%|██████████| 469/469 [00:02<00:00, 192.09it/s]


Epoch 10/20: train_loss=0.0245, val_loss=0.0723


100%|██████████| 469/469 [00:02<00:00, 194.16it/s]


Epoch 11/20: train_loss=0.0198, val_loss=0.0688


100%|██████████| 469/469 [00:02<00:00, 196.09it/s]


Epoch 12/20: train_loss=0.0156, val_loss=0.0718


100%|██████████| 469/469 [00:02<00:00, 196.86it/s]


Epoch 13/20: train_loss=0.0135, val_loss=0.0782


100%|██████████| 469/469 [00:02<00:00, 194.74it/s]


Epoch 14/20: train_loss=0.0111, val_loss=0.0787


100%|██████████| 469/469 [00:02<00:00, 197.91it/s]


Epoch 15/20: train_loss=0.0086, val_loss=0.0742


100%|██████████| 469/469 [00:02<00:00, 195.71it/s]


Epoch 16/20: train_loss=0.0067, val_loss=0.0758


100%|██████████| 469/469 [00:02<00:00, 193.43it/s]


Epoch 17/20: train_loss=0.0049, val_loss=0.0742


100%|██████████| 469/469 [00:02<00:00, 200.86it/s]


Epoch 18/20: train_loss=0.0044, val_loss=0.0781


100%|██████████| 469/469 [00:02<00:00, 194.50it/s]


Epoch 19/20: train_loss=0.0035, val_loss=0.0814


100%|██████████| 469/469 [00:02<00:00, 196.65it/s]


Epoch 20/20: train_loss=0.0030, val_loss=0.0794
TRAINING MODEL: FFNN 2 HL, 128 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:05<00:00, 83.14it/s]


Epoch 1/20: train_loss=1.6393, val_loss=0.6933


100%|██████████| 469/469 [00:05<00:00, 85.83it/s]


Epoch 2/20: train_loss=0.3164, val_loss=0.1760


100%|██████████| 469/469 [00:05<00:00, 83.75it/s]


Epoch 3/20: train_loss=0.1397, val_loss=0.1294


100%|██████████| 469/469 [00:05<00:00, 86.13it/s]


Epoch 4/20: train_loss=0.1069, val_loss=0.1127


100%|██████████| 469/469 [00:05<00:00, 82.04it/s]


Epoch 5/20: train_loss=0.0862, val_loss=0.1002


100%|██████████| 469/469 [00:05<00:00, 85.60it/s]


Epoch 6/20: train_loss=0.0710, val_loss=0.0915


100%|██████████| 469/469 [00:05<00:00, 85.45it/s]


Epoch 7/20: train_loss=0.0601, val_loss=0.0852


100%|██████████| 469/469 [00:05<00:00, 85.52it/s]


Epoch 8/20: train_loss=0.0530, val_loss=0.0823


100%|██████████| 469/469 [00:05<00:00, 84.54it/s]


Epoch 9/20: train_loss=0.0472, val_loss=0.0782


100%|██████████| 469/469 [00:05<00:00, 86.00it/s]


Epoch 10/20: train_loss=0.0409, val_loss=0.0779


100%|██████████| 469/469 [00:05<00:00, 86.49it/s]


Epoch 11/20: train_loss=0.0368, val_loss=0.0754


100%|██████████| 469/469 [00:05<00:00, 84.79it/s]


Epoch 12/20: train_loss=0.0330, val_loss=0.0733


100%|██████████| 469/469 [00:05<00:00, 83.48it/s]


Epoch 13/20: train_loss=0.0296, val_loss=0.0738


100%|██████████| 469/469 [00:05<00:00, 84.92it/s]


Epoch 14/20: train_loss=0.0270, val_loss=0.0732


100%|██████████| 469/469 [00:05<00:00, 80.06it/s]


Epoch 15/20: train_loss=0.0242, val_loss=0.0739


100%|██████████| 469/469 [00:05<00:00, 84.23it/s]


Epoch 16/20: train_loss=0.0223, val_loss=0.0734


100%|██████████| 469/469 [00:05<00:00, 82.65it/s]


Epoch 17/20: train_loss=0.0206, val_loss=0.0732


100%|██████████| 469/469 [00:05<00:00, 86.74it/s]


Epoch 18/20: train_loss=0.0199, val_loss=0.0727


100%|██████████| 469/469 [00:05<00:00, 85.55it/s]


Epoch 19/20: train_loss=0.0179, val_loss=0.0723


100%|██████████| 469/469 [00:05<00:00, 83.29it/s]


Epoch 20/20: train_loss=0.0170, val_loss=0.0718
TRAINING MODEL: FFNN 2 HL, 256 HN | OPTIMISER: sgd


100%|██████████| 469/469 [00:02<00:00, 194.50it/s]


Epoch 1/20: train_loss=0.4074, val_loss=0.1863


100%|██████████| 469/469 [00:02<00:00, 194.24it/s]


Epoch 2/20: train_loss=0.1481, val_loss=0.1174


100%|██████████| 469/469 [00:02<00:00, 195.81it/s]


Epoch 3/20: train_loss=0.0994, val_loss=0.1009


100%|██████████| 469/469 [00:02<00:00, 196.35it/s]


Epoch 4/20: train_loss=0.0738, val_loss=0.0789


100%|██████████| 469/469 [00:02<00:00, 195.59it/s]


Epoch 5/20: train_loss=0.0565, val_loss=0.0719


100%|██████████| 469/469 [00:02<00:00, 196.10it/s]


Epoch 6/20: train_loss=0.0445, val_loss=0.0732


100%|██████████| 469/469 [00:02<00:00, 197.70it/s]


Epoch 7/20: train_loss=0.0352, val_loss=0.0694


100%|██████████| 469/469 [00:02<00:00, 200.10it/s]


Epoch 8/20: train_loss=0.0277, val_loss=0.0683


100%|██████████| 469/469 [00:02<00:00, 201.32it/s]


Epoch 9/20: train_loss=0.0222, val_loss=0.0621


100%|██████████| 469/469 [00:02<00:00, 198.18it/s]


Epoch 10/20: train_loss=0.0174, val_loss=0.0635


100%|██████████| 469/469 [00:02<00:00, 196.87it/s]


Epoch 11/20: train_loss=0.0138, val_loss=0.0690


100%|██████████| 469/469 [00:02<00:00, 198.41it/s]


Epoch 12/20: train_loss=0.0110, val_loss=0.0634


100%|██████████| 469/469 [00:02<00:00, 193.43it/s]


Epoch 13/20: train_loss=0.0084, val_loss=0.0634


100%|██████████| 469/469 [00:02<00:00, 189.86it/s]


Epoch 14/20: train_loss=0.0062, val_loss=0.0646


100%|██████████| 469/469 [00:02<00:00, 200.27it/s]


Epoch 15/20: train_loss=0.0050, val_loss=0.0624


100%|██████████| 469/469 [00:02<00:00, 197.42it/s]


Epoch 16/20: train_loss=0.0040, val_loss=0.0658


100%|██████████| 469/469 [00:02<00:00, 200.29it/s]


Epoch 17/20: train_loss=0.0032, val_loss=0.0655


100%|██████████| 469/469 [00:02<00:00, 194.34it/s]


Epoch 18/20: train_loss=0.0026, val_loss=0.0650


100%|██████████| 469/469 [00:02<00:00, 196.61it/s]


Epoch 19/20: train_loss=0.0023, val_loss=0.0661


100%|██████████| 469/469 [00:02<00:00, 192.42it/s]


Epoch 20/20: train_loss=0.0019, val_loss=0.0683
TRAINING MODEL: FFNN 2 HL, 256 HN | OPTIMISER: ngd


100%|██████████| 469/469 [00:05<00:00, 80.96it/s]


Epoch 1/20: train_loss=1.8561, val_loss=1.2382


100%|██████████| 469/469 [00:05<00:00, 81.52it/s]


Epoch 2/20: train_loss=0.7266, val_loss=0.3523


100%|██████████| 469/469 [00:05<00:00, 83.91it/s]


Epoch 3/20: train_loss=0.2295, val_loss=0.1725


100%|██████████| 469/469 [00:05<00:00, 83.90it/s]


Epoch 4/20: train_loss=0.1379, val_loss=0.1320


100%|██████████| 469/469 [00:05<00:00, 82.62it/s]


Epoch 5/20: train_loss=0.1046, val_loss=0.1139


100%|██████████| 469/469 [00:05<00:00, 81.08it/s]


Epoch 6/20: train_loss=0.0861, val_loss=0.1003


100%|██████████| 469/469 [00:05<00:00, 83.62it/s]


Epoch 7/20: train_loss=0.0712, val_loss=0.0907


100%|██████████| 469/469 [00:05<00:00, 82.73it/s]


Epoch 8/20: train_loss=0.0608, val_loss=0.0846


100%|██████████| 469/469 [00:05<00:00, 82.50it/s]


Epoch 9/20: train_loss=0.0529, val_loss=0.0808


100%|██████████| 469/469 [00:05<00:00, 84.01it/s]


Epoch 10/20: train_loss=0.0470, val_loss=0.0761


100%|██████████| 469/469 [00:05<00:00, 81.78it/s]


Epoch 11/20: train_loss=0.0414, val_loss=0.0731


100%|██████████| 469/469 [00:05<00:00, 84.21it/s]


Epoch 12/20: train_loss=0.0371, val_loss=0.0703


100%|██████████| 469/469 [00:05<00:00, 80.03it/s]


Epoch 13/20: train_loss=0.0338, val_loss=0.0675


100%|██████████| 469/469 [00:05<00:00, 80.98it/s]


Epoch 14/20: train_loss=0.0302, val_loss=0.0659


100%|██████████| 469/469 [00:05<00:00, 82.49it/s]


Epoch 15/20: train_loss=0.0271, val_loss=0.0648


100%|██████████| 469/469 [00:05<00:00, 81.65it/s]


Epoch 16/20: train_loss=0.0250, val_loss=0.0633


100%|██████████| 469/469 [00:05<00:00, 83.42it/s]


Epoch 17/20: train_loss=0.0235, val_loss=0.0628


100%|██████████| 469/469 [00:05<00:00, 83.19it/s]


Epoch 18/20: train_loss=0.0221, val_loss=0.0617


100%|██████████| 469/469 [00:05<00:00, 82.93it/s]


Epoch 19/20: train_loss=0.0202, val_loss=0.0606


100%|██████████| 469/469 [00:05<00:00, 82.63it/s]


Epoch 20/20: train_loss=0.0185, val_loss=0.0606


In [28]:
# EXPERIMENT 2: Train to convergence with SGD, then continue with SGD and NGD and compare

if hp["experiment_type"] == "swap":
    train_losses = {}
    val_losses = {}
    update_norms = {}
    all_update_norms = {}

    criterion = nn.CrossEntropyLoss()

    for title, model in models.items():

        del model["ngd"][0]

        # Pad the NGD histories with None for the shared SGD phase. The loss lists also hold
        # the initial (epoch 0) evaluation, so they need one more entry than the update norms.
        model_train_losses = {"sgd" : [], "ngd" : [None] * (hp["cut_off_epochs"]+1)}
        model_val_losses = {"sgd" : [], "ngd" : [None] * (hp["cut_off_epochs"]+1)}
        model_update_norms = {"sgd" : [], "ngd" : [None] * hp["cut_off_epochs"]}
        model_all_update_norms = {"sgd" : [], "ngd" : [None] * (len(train_loader)*hp["cut_off_epochs"])}

        state = copy.deepcopy(model["sgd"][0])
        optimizer = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
    
        print(f"TRAINING MODEL: {title} | OPTIMISER: SGD")
        initial_train_loss = evaluate(state, train_loader, criterion, device)
        initial_val_loss = evaluate(state, test_loader, criterion, device)
        model_train_losses["sgd"].append(initial_train_loss)
        model_val_losses["sgd"].append(initial_val_loss)
        for epoch in range(1, hp["cut_off_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optimizer, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")


        # Train with SGD
        print("Training from model checkpoint with SGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_sgd = t.optim.SGD(
            params=state.parameters(),
            lr=hp["sgd_lr"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_sgd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["sgd"].append(train_loss)
            model_update_norms["sgd"].append(update_norm)
            model_val_losses["sgd"].append(val_loss)
            model_all_update_norms["sgd"] += epoch_update_norms
            model["sgd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        
        # Train with NGD
        print("Training from model checkpoint with NGD.")
        state = copy.deepcopy(model["sgd"][hp["cut_off_epochs"]])
        optim_ngd = NGD(
            params=state.parameters(),
            lr=hp["ngd_lr"],
            alpha=hp["alpha"],
            eta=hp["eta"],
            epsilon=hp["epsilon"],
            delta=hp["delta"],
            momentum=hp["momentum"],
            nesterov=hp["nesterov"],
        )
        for epoch in range(hp["cut_off_epochs"]+1, hp["num_epochs"]+1):
            train_loss, update_norm, _, epoch_update_norms = train_one_epoch(state, train_loader, optim_ngd, criterion, device)
            val_loss = evaluate(state, test_loader, criterion, device)
            model_train_losses["ngd"].append(train_loss)
            model_update_norms["ngd"].append(update_norm)
            model_val_losses["ngd"].append(val_loss)
            model_all_update_norms["ngd"] += epoch_update_norms
            model["ngd"].append(copy.deepcopy(state))
            print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        # Save train/val loss dictionaries to the right model key
        train_losses[title] = model_train_losses
        val_losses[title] = model_val_losses
        update_norms[title] = model_update_norms
        all_update_norms[title] = model_all_update_norms


In [29]:
# If we are doing the swap experiment, pad the NGD checkpoint list with None for the initial model and the first hp["cut_off_epochs"] epochs

if hp["experiment_type"] == "swap":
    none_models = [None for i in range(hp["cut_off_epochs"]+1)]
    for title, model in models.items():
        model["ngd"] = none_models + model["ngd"]
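
The padding is `cut_off_epochs + 1` entries long because index 0 holds the initial model and indices 1 to `cut_off_epochs` cover the shared SGD phase, so the first NGD checkpoint lands at index `cut_off_epochs + 1`. A toy sketch of that alignment follows; the values and the `ngd_epoch_*` strings are hypothetical stand-ins for real checkpoints.

```python
# Hypothetical settings, chosen only to keep the example small
cut_off_epochs, num_epochs = 2, 5

# Stand-ins for the NGD checkpoints saved after each post-swap epoch
ngd_checkpoints = [f"ngd_epoch_{e}" for e in range(cut_off_epochs + 1, num_epochs + 1)]

# Same padding rule as the cell above: one None per pre-swap index (0..cut_off_epochs)
padded = [None] * (cut_off_epochs + 1) + ngd_checkpoints

# The padded list is indexed by epoch, with one entry per epoch from 0 to num_epochs
print(len(padded), padded[cut_off_epochs], padded[cut_off_epochs + 1])
```

With this layout, `padded[e]` is the NGD checkpoint after epoch `e`, or `None` if NGD had not yet taken over at that epoch.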

#### **3. Visualising training / validation loss results**

Check that all models converged.

The cell below plots training and validation (test-set) loss for each model separately, with four traces per graph: SGD training, SGD validation, NGD training and NGD validation.

In [30]:
# Display training / val data for all models for NGD / SGD

epochs = np.arange(1, hp["num_epochs"]+1)

loss_figures = {}

color_cycle = ['rgb(32, 102, 168)', 'rgb(200, 0, 0)']

for title in models.keys():
    loss_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=train_losses[title][optim],
            mode="lines+markers",
            line=dict(color=color),
            name=f"{optim} Train",
        ), secondary_y=False)
        loss_fig.add_trace(go.Scatter(
            x=np.arange(0, hp["num_epochs"]+1),
            y=val_losses[title][optim],
            mode="lines+markers",
            line=dict(color=color, dash='dot'),
            name=f"{optim} Validation",
        ), secondary_y=True)
    loss_fig.update_layout(
        title=f"{title} training / validation loss",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    loss_fig.update_yaxes(title_text="Training Loss", secondary_y=False)
    loss_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
    loss_figures[title] = loss_fig
    loss_fig.show()

In [31]:
# Update norms per epoch for NGD and SGD (the per-minibatch norms are kept in all_update_norms)

update_norm_figures = {}


for title in models.keys():
    update_norm_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        update_norm_fig.add_trace(go.Scatter(
            x=epochs,
            y=update_norms[title][optim],
            mode="lines",
            line=dict(color=color),
            name=f"{optim} update norm"))
    update_norm_fig.update_layout(
        title=f"{title} update norms over epochs",
        xaxis_title="Epochs",
        yaxis_title="Update Size",
        yaxis_type="linear",
    )
    update_norm_figures[title] = update_norm_fig
    update_norm_fig.show()

#### **4. Perform Hessian rank computation**

To cross-check the RLCT estimates, we compute an approximation of the Hessian for each model checkpoint, then estimate the rank of this matrix from its eigenspectrum. SLT predicts the following bound:

$\text{RLCT} \geq \frac{\text{rank}(\textbf{Hess})}{2}$ 

We will check whether this holds in our experiments. The Hessian computation uses helper functions from `utils_hessian_fim.py`, which wraps the `PyHessian` module (and the `nngeometry` module, for computations involving the Fisher Information Matrix).
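
As a rough illustration of this check, the sketch below counts the eigenvalues whose magnitude exceeds a relative threshold and compares half that count with an RLCT estimate. The function names, the threshold `rel_tol` and the eigenvalues are all hypothetical; the notebook's own rank estimate comes from `find_hessian_dimensionality`, whose internals may differ.

```python
def estimate_rank(eigenvalues, rel_tol=1e-3):
    """Count eigenvalues whose magnitude exceeds rel_tol times the largest magnitude."""
    if not eigenvalues:
        return 0
    largest = max(abs(ev) for ev in eigenvalues)
    if largest == 0.0:
        return 0
    return sum(1 for ev in eigenvalues if abs(ev) > rel_tol * largest)


def satisfies_rank_bound(rlct, rank):
    """Check the inequality RLCT >= rank(Hess) / 2."""
    return rlct >= rank / 2


# Hypothetical spectrum: three clearly non-zero eigenvalues, three numerically zero
toy_eigenvalues = [12.0, 3.5, 0.8, 1e-6, -2e-7, 0.0]
toy_rank = estimate_rank(toy_eigenvalues)
print(toy_rank, satisfies_rank_bound(rlct=2.0, rank=toy_rank))
```

The rank estimate is sensitive to the choice of threshold, so any comparison against the bound should be read with that in mind.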

In [32]:
# Create Hessians for each model - recall that model is a list containing all its past versions
if hp["hessian"]:
    hessians = {}
    for title, model in models.items():
        hessian = {}
        for optim in optimizers:
            hessians_list = produce_hessians(
                models=model[optim],
                data_loader=train_loader,
                num_batches=data_args["num_hessian_batches"],
                criterion=criterion,
                device=device,
            )
            hessian[optim] = hessians_list
        hessians[title] = hessian

In [33]:
# Get the eigenspectrum data from the Hessian objects
if hp["hessian"]:
    eigenspectra_data = {}
    eigenspectra_figs = {}
    for title, hessian in tqdm(hessians.items(), desc="Processing Hessians"):
        eigenspectrum_data = {}
        eigenspectrum_figs = {}
        for optim in optimizers:
            eigenspectrum_figs_list, eigenspectrum_data_list = produce_eigenspectra(
                hessians=hessian[optim],
                plot_type="log",
            )
            eigenspectrum_figs[optim] = eigenspectrum_figs_list
            eigenspectrum_data[optim] = eigenspectrum_data_list
        eigenspectra_data[title] = eigenspectrum_data
        eigenspectra_figs[title] = eigenspectrum_figs

In [34]:
# Produce the traces of Hessian dimensionality over epochs
if hp["hessian"]:
    hessian_ranks = {}
    for title, eigenspectrum_data in eigenspectra_data.items():
        hessian_rank = {}
        for optim in optimizers:
            hessian_rank_list = find_hessian_dimensionality(eigenspectrum_data[optim])
            hessian_rank[optim] = hessian_rank_list
        hessian_ranks[title] = hessian_rank

#### **5. Perform RLCT estimation**

Using the `devinterp` library, we estimate the RLCT (Real Log Canonical Threshold), also known as the LLC (Local Learning Coefficient).

`rlct_estimates` is a dictionary of dictionaries, keyed first by model title and then by optimiser. Each inner dictionary contains two lists: one of SGD RLCT values over epochs and one of NGD RLCT values over epochs. The same is true for `wbic_estimates`.
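
A small sketch of this nesting, using made-up numbers rather than results from this notebook. It assumes that epochs without a checkpoint (for example the pre-swap NGD epochs) are stored as `None`; the helper `final_estimate` is hypothetical.

```python
# Hypothetical values, only to show the nesting: title -> optimiser -> list over epochs
toy_rlct_estimates = {
    "FFNN 1 HL, 64 HN": {
        "sgd": [10.0, 14.0, 16.0],
        "ngd": [None, 30.0, 28.0],  # None marks an epoch with no NGD checkpoint
    },
}


def final_estimate(estimates, title, optim):
    """Return the last non-None RLCT estimate for one model and optimiser."""
    values = [v for v in estimates[title][optim] if v is not None]
    return values[-1] if values else None


print(final_estimate(toy_rlct_estimates, "FFNN 1 HL, 64 HN", "ngd"))
```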

In [35]:
rlct_estimates = {}
histories = {}

for title, model in models.items():
    history = {}
    rlct_estimate = {}
    for optim in optimizers:
        rlct_list, history_list = estimate_rlcts(
            model[optim], train_loader, criterion, device, devinterp_args,
        )
        rlct_estimate[optim] = rlct_list
        history[optim] = history_list
    rlct_estimates[title] = rlct_estimate 
    histories[title] = history

  0%|          | 0/21 [00:00<?, ?it/s]

Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 303.16it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 295.09it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 288.40it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 291.31it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 293.56it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 296.13it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 286.27it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 292.43it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 288.26it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 288.48it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 292.23it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 291.97it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 300.52it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 287.29it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 299.61it/s]
Chain 0: 100%|██████████| 1000/1000 [00:03<00:00, 299.17it/s]
Chain 0:

In [36]:
# Compute generalisation losses for each model, for SGD and NGD, so we can compare this to the testing loss

gen_losses = {}
for title in models.keys():
    gen_loss = {}
    for optim in optimizers:
        gen_loss_list = []
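        # Iterate over every checkpoint (epoch 0 included) so that [1:] later covers all epochs
        for i in range(len(histories[title][optim])):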
            if histories[title][optim][i] is None:
                gen_loss_list.append(None)
            else:
                #gen loss uses actual train loss here
                gen_loss_list.append(train_losses[title][optim][i] + rlct_estimates[title][optim][i]/data_args["batch_size"])
        gen_loss[optim] = gen_loss_list
    gen_losses[title] = gen_loss
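
The cell above adds a complexity penalty to the training loss: for each checkpoint it computes `train_loss + rlct / n`, with `n` taken from `data_args["batch_size"]`. The sketch below repeats that arithmetic in isolation. The helper name `generalisation_loss` and all numbers are hypothetical.

```python
def generalisation_loss(train_loss, rlct, n):
    """Return train_loss + rlct / n, or None if either input is missing."""
    if train_loss is None or rlct is None:
        return None
    return train_loss + rlct / n


# Hypothetical per-epoch values; None marks an epoch with no checkpoint
toy_train = [None, 0.10, 0.05]
toy_rlct = [None, 25.6, 12.8]

toy_gen = [generalisation_loss(loss, rlct, n=128) for loss, rlct in zip(toy_train, toy_rlct)]
print(toy_gen)
```

Because the penalty scales as `1 / n`, the choice of `n` strongly affects how far the estimate sits above the training loss.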

#### **6. Visualise RLCT / Hessian rank, Hessian eigenspectra, generalisation loss, RLCT convergence**

We display the following final figures:
- RLCT and Hessian rank evolution for each model, for NGD and SGD
- Hessian eigenspectra for SGD and NGD overlaid on the same plot
- Evolution of generalisation loss for each model, compared to validation loss
- RLCT moving average evolution for each model, for NGD and SGD, to check for convergence

In [37]:
# Visualise RLCT and Hessian rank data

exp_figures = {}

for title in models.keys():
    exp_fig = make_subplots(specs=[[{"secondary_y" : True}]])
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        exp_fig.add_trace(go.Scatter(
            x=epochs,
            y=rlct_estimates[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} RLCT",
            line=dict(color=color),
        ), secondary_y=False)
        if hp["hessian"]:
            exp_fig.add_trace(go.Scatter(
                x=epochs,
                y=hessian_ranks[title][optim][1:],
                mode="lines+markers",
                name=f"{optim} Hessian Rank",
                line=dict(color=color, dash="dot"),
            ), secondary_y=True)
    exp_fig.update_layout(
        title=f"{title} RLCT / Hessian Rank (optional) evolution during training",
        xaxis_title="Epoch",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    exp_fig.update_yaxes(title_text="RLCT", secondary_y=False)
    exp_fig.update_yaxes(title_text="Hessian rank", secondary_y=True)
    exp_figures[title] = exp_fig
    exp_fig.show()   

In [38]:
# Visualise converged eigenspectra for SGD and NGD for each model

if hp["hessian"]:
    combined_eigenspectra = {}

    for title in models.keys():

        combined_eigenspectrum = go.Figure()
        final_eigenspectra = {}
        traces = {}

        for optim in optimizers:
            final_eigenspectra[optim] = eigenspectra_figs[title][optim][-2]
            final_eigenspectra[optim].data[0].name = optim
            combined_eigenspectrum.add_trace(final_eigenspectra[optim].data[0])

        combined_eigenspectrum.update_layout(
            title=f"{title} Hessian eigenspectra at convergence",
            xaxis_title="Eigenvalue",
            yaxis_title="Probability density (log scale)",
            yaxis_type="log",
        )

        combined_eigenspectrum.show()
        combined_eigenspectra[title] = combined_eigenspectrum

In [39]:
# Visualise generalisation loss vs. testing loss for each model

train_gen_figs = {}
for title in models.keys():
    train_gen_fig = go.Figure()
    color_index = 0
    for optim in optimizers:
        color = color_cycle[color_index % len(color_cycle)]
        color_index += 1
        train_gen_fig.add_trace(go.Scatter(
            x=epochs,
            y=val_losses[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} Validation",
            line=dict(color=color),
        ))
        train_gen_fig.add_trace(go.Scatter(
            x=epochs,
            y=gen_losses[title][optim][1:],
            mode="lines+markers",
            name=f"{optim} Generalisation",
            line=dict(color=color, dash="dot"),
        ))
    train_gen_fig.update_layout(
        title=f"{title} validation / generalisation loss",
        xaxis_title="Epoch",
        yaxis_title="Loss",
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
        yaxis_type="linear",
    )
    train_gen_figs[title] = train_gen_fig
    train_gen_fig.show()

In [40]:
# Check the LLC chains converged for each model

rlct_converge_plots = {}
for title in models.keys():
    rlct_converge_plot = go.Figure()
    for optim in optimizers:
        if histories[title][optim][-1] is not None:
            rlct_converge_plot.add_trace(go.Scatter(
                #average over chains
                y=np.mean(histories[title][optim][-1]["llc/moving_avg"],axis=0),
                name=f"{optim} Last Epoch",
            ))
    rlct_converge_plot.update_layout(
        title=f"LLC moving average over sampling draws at the final epoch for {title}",
        xaxis_title="Draws",
        yaxis_title="RLCT",
        yaxis_type="linear",
        legend_title="Optimiser"
    )
    rlct_converge_plots[title] = rlct_converge_plot
    rlct_converge_plot.show()

In [41]:
# Save the results to an HTML file.

figures = []

combined_args = {**hp, **data_args, **devinterp_args}

#summary = pprint.pformat(combined_args, sort_dicts=False)

for fig in loss_figures.values():
    figures.append(fig)
for fig in update_norm_figures.values():
    figures.append(fig)
for fig in exp_figures.values():
    figures.append(fig)
if hp["hessian"]:
    for fig in combined_eigenspectra.values():
        figures.append(fig)
for fig in train_gen_figs.values():
    figures.append(fig)
for fig in rlct_converge_plots.values():
    figures.append(fig)

curr_time = datetime.now().strftime("%Y-%m-%d-%H-%M")

write_figs_to_html(
    figs=figures,
    dest=f"./ngd_sgd/{hp['model_type']}_ngd_sgd_rlct_{curr_time}.html",
    title="Does NGD converge to minima that are 'more complex' i.e. have a higher RLCT?",
    summary=combined_args,
)

In [42]:
# Print the RLCT estimates for each model and optimiser, rounded to 1 decimal place

def round_to_1dp(value):
    return round(value, 1)

for title in models.keys():
    for optim in optimizers:
        print('---'+title+'---'+optim+'---')
        rounded_list = [round_to_1dp(value) for value in rlct_estimates[title][optim][1:]]
        print(rounded_list)

---FFNN 1 HL, 64 HN---sgd---
[14.7, 16.9, 17.9, 15.4, 14.9, 16.4, 15.2, 14.9, 15.6, 15.6, 15.5, 15.4, 17.7, 16.0, 17.5, 14.5, 16.3, 15.5, 16.3, 17.4]
---FFNN 1 HL, 64 HN---ngd---
[31.4, 29.2, 28.7, 27.9, 29.0, 26.1, 25.6, 26.4, 26.4, 25.7, 25.6, 25.7, 23.3, 24.3, 23.6, 23.4, 24.0, 25.0, 24.9, 25.4]
---FFNN 1 HL, 128 HN---sgd---
[19.2, 20.9, 18.0, 19.5, 20.5, 19.1, 19.7, 20.1, 18.0, 19.7, 17.8, 19.9, 19.2, 18.6, 19.5, 18.6, 20.0, 19.9, 19.3, 21.3]
---FFNN 1 HL, 128 HN---ngd---
[35.6, 35.7, 35.7, 36.5, 36.7, 32.7, 32.2, 34.7, 34.0, 33.3, 34.2, 34.2, 31.3, 32.9, 29.6, 31.8, 32.7, 32.5, 29.6, 29.9]
---FFNN 1 HL, 256 HN---sgd---
[26.0, 24.3, 25.0, 26.9, 25.8, 26.7, 23.8, 23.8, 25.4, 24.7, 24.1, 25.3, 23.0, 24.0, 24.3, 23.9, 22.6, 23.8, 22.5, 23.4]
---FFNN 1 HL, 256 HN---ngd---
[43.9, 42.4, 42.3, 40.5, 41.8, 39.9, 41.3, 38.5, 40.7, 39.8, 38.2, 37.8, 39.0, 37.0, 36.4, 36.7, 37.1, 36.4, 38.3, 35.4]
---FFNN 2 HL, 64 HN---sgd---
[20.6, 21.6, 18.3, 22.2, 24.0, 20.6, 25.0, 23.1, 24.2, 25.7, 24.3, 