### **Observing phase transitions in the FIM and Hessian spectra during training**

We show empirical evidence for phase transitions in the Hessian and Fisher Information Matrices during convergence for models with a high number of hidden layers.

#### **Methodology**

We train on the MNIST dataset using:
- Deep neural networks trained with a large number of hidden layers (on the order of 10 hidden layers), showing that the qualitative changes are induced irrespective of model width
- CNNs with a varying number of channels and CNNs with a varying number of convolution layers

#### **0. Import libraries**

`nngeometry` is used to produce the diagonal elements of the FIM.
`PyHessian` is used to estimate the eigenspectrum of the Hessian.

In [8]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD, SGNHT
from devinterp.slt import sample, OnlineLLCEstimator
from devinterp.slt.wbic import OnlineWBICEstimator
from devinterp.slt.mala import MalaAcceptanceRate
from devinterp.utils import plot_trace, optimal_temperature

from approxngd import KFAC
from PyHessian.pyhessian import *
from PyHessian.density_plot import *

from utils_general import *
from utils_hessian_fim import *
from networks import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline

Device in use: cuda
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### **1. Import data for training**

We import our dataset (CIFAR10 or MNIST) for use in training our models. We also define our hyperparameters.

In [4]:
# Experiment configuration
#
# The JSON file holds three sections, in order: network / optimiser
# hyperparameters, data-loading arguments, and devinterp sampler arguments.

with open("args_phase_transitions.json", "r") as config_file:
    args = json.load(config_file)

hp, data_args, devinterp_args = args
for section in (hp, data_args, devinterp_args):
    pprint.pprint(section)

# 1-based epoch numbers, used as the x-axis of the loss plots.
epochs = np.arange(start=1, stop=hp["num_epochs"] + 1)

{'cnn_hidden_layers': 7,
 'linear_hidden_layers': 5,
 'linear_hidden_nodes': 512,
 'lr': 0.009,
 'model_type': 'cnn',
 'momentum': 0.9,
 'nesterov': True,
 'ngd_weight_decay': 0.0001,
 'num_epochs': 10,
 'optimiser': 'sgd'}
{'batch_size': 128, 'dataset': 'mnist', 'num_workers': 6}
{'localization': 100.0,
 'lr': 0.0001,
 'num_chains': 1,
 'num_draws': 1000,
 'sampler': 'sgld'}


In [5]:
# Get training and test loader
# `build_data_loaders` is a project helper; it is driven by `data_args`
# (dataset name, batch size, worker count as loaded from the JSON config).
# The train loader feeds training and RLCT estimation below; the test loader
# feeds validation and the Hessian computation.

train_loader, test_loader = build_data_loaders(data_args)

#### **2. Training models**

For these experiments, we are interested in models that suffer from "unstable gradients". To artificially induce unstable gradients, we will use a large number of hidden layers.

This might require some trial and error to find the kind of loss characteristic we're interested in (a sharp drop at some epoch).

In [6]:
# Model initialisation
#
# Builds the network selected by data_args["dataset"] and hp["model_type"]:
# "dnn" selects the fully-connected model, any other value selects the CNN.
# `models` collects one snapshot per checkpoint; index 0 is the untrained network.

models = []

use_dnn = hp["model_type"] == "dnn"

if data_args["dataset"] == "mnist":
    if use_dnn:
        model = LinearMNIST(hidden_layers=hp["linear_hidden_layers"], hidden_nodes=hp["linear_hidden_nodes"]).to(device)
    else:
        model = CnnMNIST(hidden_conv_layers=hp["cnn_hidden_layers"]).to(device)
elif data_args["dataset"] == "cifar10":
    # Same rule as the MNIST branch: previously this branch tested for "cnn" and
    # therefore built the linear model for CNN runs (and vice versa).
    if use_dnn:
        model = LinearCIFAR10(hidden_layers=hp["linear_hidden_layers"], hidden_nodes=hp["linear_hidden_nodes"]).to(device)
    else:
        model = CnnCIFAR10(hidden_conv_layers=hp["cnn_hidden_layers"]).to(device)
else:
    # Fail loudly instead of hitting a NameError below, or silently reusing a
    # stale `model` left over from an earlier notebook run.
    raise ValueError(f"Unsupported dataset: {data_args['dataset']!r}")

# Snapshot the untrained network so later analysis can include initialisation.
models.append(copy.deepcopy(model))

In [7]:
# Training models (either DNN or CNN)
#
# Both optimisers are built from the same hyperparameters; hp["optimiser"]
# decides which one drives training ("ngd" selects NGD, anything else SGD).

sgd = t.optim.SGD(model.parameters(),
                  lr=hp["lr"],
                  momentum=hp["momentum"],
                  # Honour the configured flag (it used to be hard-coded to True)
                  # so SGD and NGD runs are set up consistently.
                  nesterov=hp["nesterov"])
ngd = NGD(params=model.parameters(),
          lr=hp["lr"],
          momentum=hp["momentum"],
          weight_decay=hp["ngd_weight_decay"],
          nesterov=hp["nesterov"])

train_losses = []
val_losses = []
optimiser = ngd if hp["optimiser"] == "ngd" else sgd
criterion = nn.CrossEntropyLoss()
print(f"========== TRAINING | model_type : {hp['model_type']}, dataset : {data_args['dataset']}, optimiser : {hp['optimiser']} ==========")
for epoch in range(1, hp["num_epochs"]+1):
    train_loss = train_one_epoch(model, train_loader, optimiser, criterion, device)
    val_loss = evaluate(model, test_loader, criterion, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    # Keep a frozen copy of the weights after every epoch; together with the
    # untrained snapshot already in `models`, models[k] is the state after epoch k.
    models.append(copy.deepcopy(model))
    print(f"Epoch {epoch}/{hp['num_epochs']}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")



  0%|          | 0/469 [00:00<?, ?it/s]

100%|██████████| 469/469 [00:20<00:00, 22.81it/s]


Epoch 1/10: train_loss=0.1211, val_loss=0.0378


100%|██████████| 469/469 [00:07<00:00, 63.32it/s]


Epoch 2/10: train_loss=0.0331, val_loss=0.0334


100%|██████████| 469/469 [00:07<00:00, 62.59it/s]


Epoch 3/10: train_loss=0.0231, val_loss=0.0307


100%|██████████| 469/469 [00:07<00:00, 62.71it/s]


Epoch 4/10: train_loss=0.0196, val_loss=0.0244


100%|██████████| 469/469 [00:07<00:00, 62.91it/s]


Epoch 5/10: train_loss=0.0139, val_loss=0.0231


100%|██████████| 469/469 [00:07<00:00, 62.09it/s]


Epoch 6/10: train_loss=0.0116, val_loss=0.0237


100%|██████████| 469/469 [00:07<00:00, 61.45it/s]


Epoch 7/10: train_loss=0.0097, val_loss=0.0218


100%|██████████| 469/469 [00:07<00:00, 60.86it/s]


Epoch 8/10: train_loss=0.0062, val_loss=0.0210


100%|██████████| 469/469 [00:07<00:00, 61.11it/s]


Epoch 9/10: train_loss=0.0057, val_loss=0.0270


100%|██████████| 469/469 [00:07<00:00, 60.98it/s]


Epoch 10/10: train_loss=0.0051, val_loss=0.0331


#### **3. Visualising the training / validation loss results**

At this point, it's essential that we check we got the loss shape we were expecting. We'll plot the training and validation loss of the trained model over epochs using Plotly.

In [9]:
# Training / validation loss curves for the trained model
#
# The two losses are drawn on separate y-axes (train on the left, validation on
# the right) so that each curve keeps its own scale.

epochs = np.arange(1, hp["num_epochs"] + 1)

train_val_fig = make_subplots(specs=[[{"secondary_y": True}]])

loss_series = (
    ("Train", train_losses, False),
    ("Validation", val_losses, True),
)
for series_name, series_values, on_secondary in loss_series:
    loss_trace = go.Scatter(x=epochs, y=series_values, mode="lines+markers", name=series_name)
    train_val_fig.add_trace(loss_trace, secondary_y=on_secondary)

# One tick per epoch on the x-axis.
epoch_axis = {"tickmode": "linear", "tick0": 0, "dtick": 1}
train_val_fig.update_layout(
    title="Training / validation loss",
    xaxis_title="Epoch",
    xaxis=epoch_axis,
)
train_val_fig.update_yaxes(title_text="Training Loss", secondary_y=False)
train_val_fig.update_yaxes(title_text="Validation Loss", secondary_y=True)
train_val_fig.show()

In [10]:
# Note down the epochs over which the phase transition occurred

pre_transition_epoch = 2
post_transition_epoch = 10

# Create a filtered set of models that only includes these epochs
#
# `models[0]` is the untrained network and `models[k]` is the checkpoint saved
# after epoch k, so a 1-based epoch number indexes `models` directly. Subtracting
# one (as was done before) selects the checkpoint from the previous epoch and
# can never reach the final model.

models_transition = [
    models[pre_transition_epoch],
    models[post_transition_epoch],
]

#### **4. Compute the Hessian and RLCT estimates**

We will now compute the Hessian eigenspectra before and after the phase transition. Furthermore, we will analyse the evolution of the real log canonical threshold (RLCT) throughout training of the model.

In [11]:
# Compute Hessians for the pre- and post-transition checkpoints
# `produce_hessians` is a project helper; it receives the two checkpoints in
# `models_transition` and is limited to a single batch (num_batches=1) drawn
# from the test loader, using the same criterion as training.

hessians = produce_hessians(models=models_transition,
                            data_loader=test_loader,
                            num_batches=1,
                            criterion=criterion,
                            device=device)


Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak. (Triggered internally at ..\torch\csrc\autograd\engine.cpp:1182.)



In [12]:
# Compute Hessian eigenspectra
# `produce_eigenspectra` is a project helper returning a pair; only the first
# element (the figures) is kept, the second is discarded.
# NOTE(review): presumably one figure per Hessian plus a combined one - confirm.

hessian_figs, _ = produce_eigenspectra(
    hessians=hessians,
    plot_type="log",
)


Casting complex values to real discards the imaginary part



In [15]:
# Compute RLCT estimates over all models
# `models` holds the untrained network followed by one checkpoint per epoch,
# i.e. hp["num_epochs"] + 1 entries in total.
# NOTE(review): presumably `rlct_estimates` and `history` contain one entry per
# model, in the same order - confirm against `estimate_rlcts`.

rlct_estimates, history = estimate_rlcts(
    models=models,
    data_loader=train_loader,
    criterion=criterion,
    devinterp_args=devinterp_args,
    device=device,
)


You are taking more draws than burn-in steps, your LLC estimates will likely be underestimates. Please check LLC chain convergence.


You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)


You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)

Chain 0: 100%|██████████| 1000/1000 [00:18<00:00, 54.58it/s]

std(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel). (Triggered internally at ..\aten\src\ATen\native\ReduceOps.cpp:1760.)

#### **5. Plot final figures**

The following data is displayed:
- Hessian eigenspectra pre- and post- convergence
- RLCT evolution across epochs
- RLCT chain convergence analysis

In [13]:
### HESSIAN EIGENSPECTRA PRE- AND POST- CONVERGENCE

# Title for each eigenspectrum figure, keyed by its position in `hessian_figs`.
index_to_name = dict(enumerate((
    "Pre-convergence eigenspectrum",
    "Post-convergence eigenspectrum",
    "Combined eigenspectra",
)))

for fig_index, fig in enumerate(hessian_figs):
    fig.update_layout(title=index_to_name[fig_index])
    fig.show()

In [16]:
# Display RLCT estimate evolution and overlay with training / validation data graph
#
# The loss series has one point per epoch (1..num_epochs). The RLCT series is
# computed over `models`, which also contains the untrained network, so it can be
# one point longer. Its x-values are therefore derived from its own length and
# anchored so that the last estimate sits on the final epoch; the untrained
# network then lands on epoch 0 instead of shifting the whole curve by one.
# NOTE(review): assumes `rlct_estimates` is ordered like `models` - confirm.

final_epoch = hp["num_epochs"]
rlct_epochs = np.arange(final_epoch + 1 - len(rlct_estimates), final_epoch + 1)

rlct_train_val_fig = make_subplots(specs=[[{"secondary_y" : True}]])
rlct_train_val_fig.add_trace(go.Scatter(x=epochs, y=train_losses, mode="lines+markers", name="Train"), secondary_y=False)
rlct_train_val_fig.add_trace(go.Scatter(x=epochs, y=val_losses, mode="lines+markers", name="Validation"), secondary_y=False)
rlct_train_val_fig.add_trace(go.Scatter(x=rlct_epochs, y=rlct_estimates, mode="lines+markers", name="RLCT"), secondary_y=True)
rlct_train_val_fig.update_layout(title="RLCT evolution",
                               xaxis_title="Epoch",
                               xaxis=dict(
                                   tickmode='linear',
                                   tick0=0,
                                   dtick=1,
                               ))
rlct_train_val_fig.update_yaxes(title_text="Loss", secondary_y=False)
rlct_train_val_fig.update_yaxes(title_text="RLCT", secondary_y=True)
rlct_train_val_fig.show()

In [17]:
# Check RLCT estimate convergence
#
# Draws the moving-average LLC trace of the first chain for every sampled model.
# `history` is produced from `models`, which starts with the untrained network,
# so labels are anchored on the final epoch: the last entry is labelled with
# hp["num_epochs"] and an extra leading entry becomes "Epoch 0". Iterating over
# len(history) also ensures the final checkpoint is not dropped.
# NOTE(review): assumes `history` is an integer-indexed sequence ordered like
# `models` - confirm.

num_runs = len(history)
first_epoch_label = hp["num_epochs"] + 1 - num_runs

rlct_converge_plot = go.Figure()
for run_index in range(num_runs):
    rlct_converge_plot.add_trace(go.Scatter(
        y=history[run_index]["llc/moving_avg"][0],
        name=f"Epoch {first_epoch_label + run_index}",
    ))
rlct_converge_plot.update_layout(
    title="Evolution of RLCT moving average for each model over epochs",
    xaxis_title="Draws",
    yaxis_title="RLCT",
    legend_title="Epoch"
)
rlct_converge_plot.show()

In [18]:
# Compile figures into a list, and export experiment figures and summary to HTML file

figures = [train_val_fig, *hessian_figs, rlct_train_val_fig, rlct_converge_plot]

# Keep the three config sections separate in the summary. Flattening them into
# one dict lets later sections overwrite earlier keys: both `hp` and
# `devinterp_args` define "lr", so the training learning rate was lost.
combined_args = {
    "hp": hp,
    "data_args": data_args,
    "devinterp_args": devinterp_args,
}
summary = pprint.pformat(combined_args)

# Make sure the output directory exists before writing the report.
# NOTE(review): harmless if `write_figs_to_html` already creates it.
Path("./phase_transitions").mkdir(parents=True, exist_ok=True)

curr_time = datetime.now().strftime("%Y-%m-%d-%H-%M")
write_figs_to_html(
    figs=figures,
    dest=f"./phase_transitions/hln_phase_transitions_{curr_time}.html",
    title="Phase transitions in high layer networks",
    summary=summary,
)