### **Is SLT correct about the relationship between the RLCT and the Hessian rank for minimally singular models?**

A minimally singular convergence point has a loss function that can be locally written as a sum of squares of parameters, i.e.

$$ L(\textbf{w}) = \sum_{i=1}^{r} w_i^2, \text{ where } r < d  \text{  (num. of parameters)  } \tag{1} $$

SLT predicts the following results for the RLCT in relation to the rank of the Hessian, $\text{rank}(\textbf{Hess})$:

$$ \text{regular model: } \lambda = \frac{d}{2} \tag{2} $$ 
$$ \text{minimally singular model: } \lambda = \frac{r}{2} \text{ where } r = \text{rank}(\textbf{Hess}) \tag{3} $$ 
$$ \text{singular model: } \lambda \ge \frac{r}{2} \tag{4} $$ 

This notebook will look at a toy example of an artificially constructed model that is minimally singular, and verify that equation $(3)$ is indeed true.

#### **Methodology**
- Model artificially constructed such that $(1)$ is satisfied
- Hessian rank calculated using `PyHessian` library
- RLCT evaluated using `devinterp` library


#### **0. Import libraries**

Standard machine learning libraries are imported. `devinterp` is imported for LLC estimation.

In [1]:
from multiprocessing import freeze_support

import os
import sys
import copy
import pickle
import pprint
import json
from pathlib import Path
from datetime import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset, random_split

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD, SGNHT
from devinterp.slt import sample, OnlineLLCEstimator
from devinterp.slt.wbic import OnlineWBICEstimator
from devinterp.slt.mala import MalaAcceptanceRate
from devinterp.utils import plot_trace, optimal_temperature

from PyHessian.pyhessian import *
from PyHessian.density_plot import *
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

from utils_general import *
from utils_hessian_fim import *
from networks import *
from ngd import NGD

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

# Run on the GPU when torch can see one; otherwise stay on the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device in use: {device}")

%load_ext autoreload
%autoreload
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


Device in use: cuda


#### **1. Setup model and training data**

We create a model such that equation $(1)$ is satisfied. 

We also create training data such that the loss is minimised when all of the parameters are equal to 0. This is done by drawing the targets from a normal distribution with zero mean.

In [12]:
# Experiment args

with open("args_min_singular_models.json", "r") as file:
    args = json.load(file)

# The config file holds three groups, in order: model, data and devinterp settings.
model_args, data_args, devinterp_args = args
# Keep r_list as an array so later cells can do elementwise arithmetic on it.
model_args["r_list"] = np.asarray(model_args["r_list"])
# Echo each group so the settings used for this run are recorded in the output.
for arg_group in (model_args, data_args, devinterp_args):
    pprint.pprint(arg_group)

{'d': 100000, 'r_list': array([    10,    100,   1000,  10000, 100000])}
{'batch_size': 128, 'num_samples': 1000, 'num_workers': 6, 'sigma': 0.1}
{'localization': 100.0,
 'num_chains': 2,
 'num_draws': 400,
 'sampler': 'sgld',
 'sampler_lr': 0.0001}


In [3]:
# Data loading function

def generate_dataset_for_seed(data_args, seed=0):
    """Build a reproducible synthetic dataset and a shuffling loader over it.

    Inputs ``x`` are drawn from a normal distribution with mean 0 and standard
    deviation 2. Targets ``y`` are zero-mean, unit-variance normal noise scaled
    by ``data_args["sigma"]``; they are sampled independently of ``x``.

    Args:
        data_args: Mapping with the keys ``"num_samples"``, ``"sigma"`` and
            ``"batch_size"``. ``"num_workers"`` is optional and defaults to 0
            (load in the main process).
        seed: Seed applied to both the torch and the numpy global generators
            before sampling.

    Returns:
        Tuple ``(train_loader, train_data, x, y)``: the shuffling ``DataLoader``,
        the underlying ``TensorDataset``, and the raw input / target tensors.
    """
    t.manual_seed(seed)
    np.random.seed(seed)

    num_samples = data_args["num_samples"]
    x = t.normal(0, 2, size=(num_samples,))
    y = data_args["sigma"] * t.normal(0, 1, size=(num_samples,))
    train_data = TensorDataset(x, y)

    num_workers = data_args.get("num_workers", 0)
    train_loader = DataLoader(
        train_data,
        batch_size=data_args["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        # DataLoader raises ValueError for persistent_workers=True when there
        # are no worker processes, so only request it when workers exist.
        persistent_workers=num_workers > 0,
    )
    return train_loader, train_data, x, y

In [4]:
# Generate data

# Build the training set once with a fixed seed (seed=0) so the sampled inputs
# and targets are the same each time this cell is run.
train_loader, train_data, x, y = generate_dataset_for_seed(data_args, seed=0)

In [5]:
# Initialise models, with d total parameters, and r used parameters.

# One model per entry in r_list; each is built from its own all-zero parameter
# vector of length d and moved to the selected device.
models = [
    MinimallySingularModel(params=t.zeros(model_args["d"]), r=r).to(device)
    for r in model_args["r_list"]
]

In [6]:
# Specify loss function

# NOTE(review): the targets produced by generate_dataset_for_seed are continuous
# floats (scaled Gaussian noise), whereas cross-entropy is normally used as a
# classification loss — confirm this is the intended criterion for
# MinimallySingularModel.
criterion = nn.CrossEntropyLoss()

#### **2. Perform RLCT estimation and Hessian rank calculation at most degenerate point**

The most degenerate singularity occurs when all the parameters are equal to zero. 

We'll set all the weights to be zero for our model, and then perform RLCT estimation at this point. We will also calculate the Hessian rank at each of these points.

In [13]:
# Estimate RLCT values for all models

# `rlct_values` is indexed per model by the plotting cells below, and each entry
# of `history` is read through its "llc/moving_avg" key for the convergence plot.
# NOTE(review): estimate_rlcts is a project helper not shown here; presumably it
# configures the sampler from `devinterp_args` — confirm against its definition.
rlct_values, history = estimate_rlcts(
    models=models,
    data_loader=train_loader,
    criterion=criterion,
    device=device,
    devinterp_args=devinterp_args
)

Chain 0: 100%|██████████| 400/400 [00:00<00:00, 437.52it/s]
Chain 1: 100%|██████████| 400/400 [00:00<00:00, 652.65it/s]
Chain 0: 100%|██████████| 400/400 [00:00<00:00, 649.17it/s]
Chain 1: 100%|██████████| 400/400 [00:00<00:00, 526.80it/s]
Chain 0: 100%|██████████| 400/400 [00:00<00:00, 629.62it/s]
Chain 1: 100%|██████████| 400/400 [00:00<00:00, 613.37it/s]
Chain 0: 100%|██████████| 400/400 [00:00<00:00, 565.26it/s]
Chain 1: 100%|██████████| 400/400 [00:00<00:00, 587.76it/s]
Chain 0: 100%|██████████| 400/400 [00:00<00:00, 610.63it/s]
Chain 1: 100%|██████████| 400/400 [00:00<00:00, 641.00it/s]
100%|██████████| 5/5 [00:13<00:00,  2.67s/it]


In [14]:
# Generate Hessian objects for each model

# `num_batches=1` is passed here; presumably the Hessian is then estimated from
# a single mini-batch of `train_loader` — TODO confirm against produce_hessians.
hessians = produce_hessians(
    models=models,
    data_loader=train_loader,
    num_batches=1,
    criterion=criterion,
    device=device,
)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [15]:
# Produce model eigenspectra

# `hessian_figs` is iterated by the display cell, which titles its final entry
# as the combined plot; `eigenspectrum_data` feeds the Hessian rank computation.
hessian_figs, eigenspectrum_data = produce_eigenspectra(
    hessians=hessians,
    plot_type="log",
)


Casting complex values to real discards the imaginary part



In [16]:
# Compute Hessian rank for each model

# Wrap the ranks reported by the helper in an array so they can be indexed and
# used in elementwise arithmetic, then echo them.
hessian_ranks = np.array(find_hessian_dimensionality(eigenspectrum_data))
print(hessian_ranks)

[    10     94    938   9509 100000]


#### **3. Display figures and results**

The following data is presented:
- Hessian eigenspectra of models of different dimensionalities
- Plot of $ \log(\text{RLCT}) $ against $ \log\left(\frac{\text{rank}(\text{Hess})}{2}\right) $.
- Evolution of RLCT moving average for the different models to check convergence.

In [17]:
# Display Hessian eigenspectra data

# Every figure except the last is titled as a single model's spectrum; the
# final figure is titled as the combined eigenspectra.
for i, hessian_fig in enumerate(hessian_figs):
    if i == len(hessian_figs) - 1:
        hessian_fig.update_layout(title = f"Combined eigenspectra, d = {model_args['d']}")
    else:
        # Read r from the experiment config instead of assuming that r_list is
        # exactly the powers of ten 10**0, 10**1, ...
        hessian_fig.update_layout(title = f"Hessian eigenspectrum, r = {model_args['r_list'][i]}, d = {model_args['d']}, Hessian rank = {hessian_ranks[i]}, RLCT = {rlct_values[i]}")
    hessian_fig.show()

In [18]:
# Plot of log(RLCT) vs. log(hessian rank / 2)

# Use the measured Hessian ranks for the x-axis so the plotted positions match
# the axis label and title, which both refer to rank(Hess) / 2.
log_half_rank = np.log10(np.asarray(hessian_ranks) / 2)

rlct_fig = go.Figure()
# Measured RLCT for each model. np.log10 yields NaN (with a RuntimeWarning) for
# any non-positive RLCT estimate.
rlct_fig.add_trace(go.Scatter(
    x=log_half_rank,
    y=np.log10(rlct_values),
    mode="markers",
    name="Experimental",
))
# SLT prediction for a minimally singular model: RLCT = rank(Hess) / 2, i.e. y = x.
rlct_fig.add_trace(go.Scatter(
    x=log_half_rank,
    y=log_half_rank,
    mode="lines",
    name="Theoretical",

))
rlct_fig.update_layout(
    title=f"RLCT vs. rank(Hess)/2, for d = {model_args['d']}",
    yaxis_title="log(RLCT)",
    xaxis_title="log(rank(Hess) / 2)"
)
rlct_fig.show()


invalid value encountered in log10



In [19]:
# Plot of LLC convergence data over models

rlct_converge_fig = go.Figure()
for i, results in enumerate(history):
    rlct_converge_fig.add_trace(go.Scatter(
        # NOTE(review): index 1 selects a single row of the moving-average data —
        # presumably one sampling chain; confirm this is the intended selection.
        y=np.log10(results["llc/moving_avg"][1]),
        mode="lines",
        # Label each trace with its configured r rather than assuming that
        # r_list is exactly the powers of ten.
        name=f"{model_args['r_list'][i]}"
    ))
rlct_converge_fig.update_layout(
    title=f"Evolution of log(RLCT) as moving average over draws for different models, d = {model_args['d']}",
    xaxis_title="Draws",
    yaxis_title="log(RLCT)",
    legend_title="r (non-free parameters)",
)


invalid value encountered in log10



In [62]:
# Compile figs into a list and save to HTML

# Gather every figure in display order: eigenspectra first, then the two RLCT plots.
figures = [*hessian_figs, rlct_fig, rlct_converge_fig]

# Merge the three argument groups into one mapping (later groups win on a key
# clash) and render it as text for the report summary.
combined_args = dict(model_args)
combined_args.update(data_args)
combined_args.update(devinterp_args)
summary = pprint.pformat(combined_args)

# Timestamp the output file name to the minute.
curr_time = f"{datetime.now():%Y-%m-%d-%H-%M}"
write_figs_to_html(
    figs=figures,
    dest=f"./min_singular_models/min_singular_models_{curr_time}.html",
    title="Does RLCT = rank(Hess) / 2 for a minimally singular model?",
    summary=summary,
)