# Hessian Analysis of DLNs

This notebook compares the Hessian of the MSE loss with the Hessian of the energy, evaluated at the origin (all weights zero), for deep linear networks (DLNs) across different datasets and architectures.



## Setup

In [None]:
#@title Installations


%%capture
!sudo apt install nvidia-utils-515
!pip install -U kaleido
!pip install gif==3.0.0

!pip install git+https://github.com/greydanus/mnist1d.git@master

In [None]:
#@title Imports


import os
import random
import subprocess
import numpy as np
from typing import Tuple, List, Dict, Optional

from mnist1d.data import get_dataset, get_dataset_args

import torch
import torch.nn as nn
from torch.nn.utils import vector_to_parameters
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

from jax import jacfwd, jacrev
from jax.numpy.linalg import eigh

import gif
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import plotly.graph_objs as go
import plotly.express as px
from plotly.colors import hex_to_rgb
from plotly.express.colors import sample_colorscale


In [None]:
#@title Config

# Identifiers of the datasets handled by `get_dataset_sample`.
DATASETS = ["toy_gaussian", "MNIST", "MNIST-1D"]
# Per-dataset list of two-element entries.
# NOTE(review): each entry looks like [n_hidden, width] — confirm against the
# cell that consumes this dict (not visible in this part of the notebook).
N_HIDDEN_WIDTHS = {
    "toy_gaussian": [
        [1, 10],
        [2, 5],
        [3, 4],
        [4, 4]
    ],
    "MNIST": [
        [1, 10],
        [2, 10],
        [3, 5],
        [4, 5]
    ],
    "MNIST-1D": [
        [1, 100],
        [2, 50],
        [3, 10],
        [4, 5],
    ]
}
# Weight initialisations to analyse, keyed by an integer.
# NOTE(review): the keys are presumably the number of hidden layers — confirm.
# `init_weights` treats "origin" as all-zero weights and any other value as
# "zero every layer except the penultimate one".
INIT_TYPES = {
    1: ["origin"],
    2: ["origin"],
    3: ["origin", "other_saddle"],
    4: ["origin", "other_saddle"],
}
# NOTE(review): presumably the number of random seeds per configuration — confirm.
N_SEEDS = 1
# Root directory under which `setup_experiment` builds per-experiment paths.
RESULTS_DIR = "results"

# toy dataset: mean/std of the Gaussian inputs drawn in `get_dataset_sample`
DATA_MEAN, DATA_STD = 1., 0.1
# Number of samples per batch for every dataset; also the divisor used to
# normalise the energy in `energy_fun` and the covariances in
# `compute_theoretical_energy_eigenvals`.
BATCH_SIZE = 64

# PC hyperparameters: number of inference iterations and inference step size
# passed to `energy_fun` by `compute_and_plot_hessian_metrics`.
N_ITERS = 50
DT = 0.1

# landscape plotting
# NOTE(review): DOMAINS and SAMPLING_RESOLUTION are not used in the visible
# cells; presumably the plotted weight ranges and grid resolution — confirm.
DOMAINS = [2, 1, 5e-1, 1e-1, 5e-2]
SAMPLING_RESOLUTION = 30
# Plotly colorscale used for the objective surface plot.
COLORSCALE = "RdBu_r"
# Inference iterations at which Hessian plots are produced.
# `compute_and_plot_hessian_metrics` only visits t in range(N_ITERS + 1), so
# entries larger than N_ITERS are never reached by that function.
PLOT_ITERS = [0, 1, 2, 3, 4, 5, 10, 20, 50, 100, 200, 500, 1000]


In [None]:
#@title Utils


def setup_experiment(results_dir, dataset, data_dim, n_hidden, width, init_type, seed):
    """Print the experiment configuration and return its results directory path.

    The path is only assembled, not created on disk. For the "toy_gaussian"
    dataset an extra ``data_dim_<data_dim>`` level is inserted after the
    dataset name; every other dataset omits it.

    Args:
        results_dir: root directory for all results.
        dataset: dataset identifier.
        data_dim: data dimensionality (only used in the path for "toy_gaussian").
        n_hidden: number of hidden layers.
        width: hidden layer width.
        init_type: name of the weight initialisation.
        seed: random seed of the run.

    Returns:
        The experiment directory path as a string.
    """
    banner = f"""
Starting Hessian analysis with experiment configuration:

  Dataset: {dataset}
  Data dim: {data_dim}
  N hidden: {n_hidden}
  Width: {width}
  Init type: {init_type}
  Seed: {seed}
"""
    print(banner)

    path_parts = [results_dir, dataset]
    if dataset == "toy_gaussian":
        # only the toy dataset has a configurable data dimension
        path_parts.append(f"data_dim_{data_dim}")
    path_parts.extend(
        [
            f"n_hidden_{n_hidden}",
            f"width_{width}",
            f"{init_type}_init",
            str(seed),
        ]
    )
    return os.path.join(*path_parts)


def set_seed(seed):
    """Seed NumPy, Python's `random`, and PyTorch (CPU, plus all GPUs if available)."""
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def get_fc_network(input_dim, n_hidden, width, act_fn, output_dim, bias=False):
    """Build a fully connected network as an ``nn.Sequential`` of per-layer blocks.

    Every layer (hidden and output) is itself wrapped in an ``nn.Sequential``
    so that ``network[l]`` is one complete layer and ``network[l][0]`` is its
    ``nn.Linear``.

    Args:
        input_dim: size of the network input.
        n_hidden: number of hidden layers.
        width: number of units in each hidden layer.
        act_fn: hidden activation, one of "linear", "tanh" or "relu".
        output_dim: size of the network output.
        bias: whether the linear layers have a bias term.

    Returns:
        An ``nn.Sequential`` with ``n_hidden + 1`` layer blocks.

    Raises:
        ValueError: if ``act_fn`` is not a supported activation.
    """
    # factories (not instances) so that each hidden layer gets its own module
    activation_factories = {
        "linear": None,
        "tanh": nn.Tanh,
        "relu": lambda: nn.ReLU(inplace=True),
    }
    if act_fn not in activation_factories:
        raise ValueError(
            f"Invalid activation function '{act_fn}'. "
            f"Options are {sorted(activation_factories)}"
        )
    make_activation = activation_factories[act_fn]

    layers = []
    for n in range(n_hidden):
        n_input = input_dim if n == 0 else width
        modules = [nn.Linear(n_input, width, bias=bias)]
        if make_activation is not None:
            modules.append(make_activation())
        layers.append(nn.Sequential(*modules))

    # without hidden layers the output layer reads the input directly
    n_output_input = width if n_hidden > 0 else input_dim
    layers.append(nn.Sequential(nn.Linear(n_output_input, output_dim, bias=bias)))
    return nn.Sequential(*layers)


def zero_weights(module):
    """Set the weight matrix of an ``nn.Linear`` module to zero.

    Any other module type is left untouched, as is the bias of a linear
    module. Intended to be passed to ``nn.Module.apply``.
    """
    if not isinstance(module, nn.Linear):
        return
    # a normal distribution with zero mean and zero std yields exact zeros
    nn.init.normal_(module.weight, mean=0., std=0.)


def init_weights(model, init_type):
    """Initialise the weights of ``model.network`` in place.

    With ``init_type == "origin"`` every linear weight is zeroed. For any
    other ``init_type`` every layer is zeroed except the penultimate one
    (index ``len(model.network) - 2``), which keeps its current weights.
    """
    if init_type == "origin":
        model.network.apply(zero_weights)
        return

    untouched_idx = len(model.network) - 2
    for idx, layer in enumerate(model.network):
        # zero weights all layers except penultimate
        if idx != untouched_idx:
            layer.apply(zero_weights)


In [None]:
#@title Datasets


def get_dataset_sample(dataset_id, data_dim, n_hidden):
    """Draw one batch of (input, target) data for the given dataset.

    For "toy_gaussian" the input is a ``(data_dim, BATCH_SIZE)`` Gaussian
    sample with mean ``DATA_MEAN`` and std ``DATA_STD`` and the target is its
    negation; when ``data_dim == 3`` and ``n_hidden == 4`` the second target
    row is replaced by an independent Gaussian sample. For "MNIST" and
    "MNIST-1D" the batch comes from the corresponding sampler.

    Args:
        dataset_id: "toy_gaussian", "MNIST" or "MNIST-1D".
        data_dim: input dimensionality (used by the toy dataset only).
        n_hidden: number of hidden layers (used by the toy dataset only).

    Returns:
        Tuple ``(input, target)``.

    Raises:
        ValueError: if ``dataset_id`` is not one of the supported datasets.
    """
    if dataset_id == "toy_gaussian":
        inputs = np.random.normal(
            loc=DATA_MEAN,
            scale=DATA_STD,
            size=(data_dim, BATCH_SIZE)
        )
        targets = -inputs
        if data_dim == 3 and n_hidden == 4:
            # decouple the second target dimension from the input
            targets[1, :] = np.random.normal(
                loc=DATA_MEAN,
                scale=DATA_STD,
                size=BATCH_SIZE
            )
    elif dataset_id == "MNIST":
        inputs, targets = get_MNIST_sample()
    elif dataset_id == "MNIST-1D":
        inputs, targets = get_MNIST_1d_sample()
    else:
        raise ValueError(
            "Invalid dataset ID. Options are 'toy_gaussian', 'MNIST' and 'MNIST-1D'"
        )

    return inputs, targets


def get_MNIST_1d_sample():
    """Return one shuffled batch of MNIST-1D as transposed NumPy arrays.

    The dataset is obtained through ``mnist1d.data.get_dataset`` using the
    local file ``./mnist1d_data.pkl`` (``download=False``, ``regenerate=False``).
    Labels are one-hot encoded. Both returned arrays are transposed, so the
    batch of ``BATCH_SIZE`` samples lies along the last axis.

    Returns:
        Tuple ``(inputs, targets)`` of NumPy arrays.
    """
    mnist1d = get_dataset(
        get_dataset_args(),
        path="./mnist1d_data.pkl",
        download=False,
        regenerate=False
    )
    inputs = torch.Tensor(mnist1d["x"])
    targets = torch.Tensor(one_hot(mnist1d["y"]))
    loader = DataLoader(
        dataset=TensorDataset(inputs, targets),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True
    )
    # only the first shuffled batch is needed
    img_batch, label_batch = next(iter(loader))
    return img_batch.numpy().T, label_batch.numpy().T


def get_data_dim(dataset, n_hidden):
    """Return the toy dataset's data dimension, or ``None`` for other datasets.

    For "toy_gaussian" the dimension is 3 when there are more than two hidden
    layers and 2 otherwise.
    """
    if dataset != "toy_gaussian":
        return None
    return 3 if n_hidden > 2 else 2


def get_MNIST_sample():
    """Return one shuffled batch of the MNIST training set as transposed NumPy arrays.

    Images are flattened and labels one-hot encoded by the ``MNIST`` dataset
    class. Both returned arrays are transposed, so the batch of
    ``BATCH_SIZE`` samples lies along the last axis.

    Returns:
        Tuple ``(inputs, targets)`` of NumPy arrays.
    """
    loader = DataLoader(
        dataset=MNIST(train=True),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True
    )
    # only the first shuffled batch is needed
    images, labels = next(iter(loader))
    return images.numpy().T, labels.numpy().T


class MNIST(datasets.MNIST):
    """torchvision MNIST that yields flattened images and one-hot labels.

    The dataset is downloaded to ``save_dir`` if it is not already there.
    """

    def __init__(self, train, normalise=True, save_dir="./data"):
        """
        Args:
            train: load the training split if True, the test split otherwise.
            normalise: if True, normalise images with mean 0.1307 and std 0.3081
                after conversion to tensors.
            save_dir: directory where the dataset is stored/downloaded.
        """
        steps = [transforms.ToTensor()]
        if normalise:
            steps.append(transforms.Normalize(mean=0.1307, std=0.3081))
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps)
        )

    def __getitem__(self, index):
        """Return the ``index``-th sample as (flattened image, one-hot label)."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


def one_hot(labels, n_classes=10):
    """One-hot encode ``labels`` by indexing the rows of an identity matrix.

    Args:
        labels: integer class label or an index array/tensor of labels.
        n_classes: total number of classes (length of each encoding).

    Returns:
        A float tensor holding one row of length ``n_classes`` per label.
    """
    return torch.eye(n_classes)[labels]


In [None]:
#@title Models


class BPN(nn.Module):
    """Thin ``nn.Module`` wrapper around a network trained with backpropagation."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def forward(self, x):
        """Run ``x`` through the wrapped network."""
        return self.network(x)

    def get_weights(self):
        """Return the network's weight matrices as NumPy arrays.

        Parameters with at most one dimension (e.g. biases) are skipped.
        """
        return [
            param.cpu().detach().numpy()
            for param in self.network.parameters()
            if param.dim() > 1
        ]



class PCN(object):
    """Predictive coding network (PCN) built on a layered ``nn.Sequential``.

    ``network[l]`` is treated as the mapping from node ``l`` to node ``l + 1``.
    The network therefore has ``n_layers`` mappings and ``n_nodes = n_layers + 1``
    activity nodes ``xs[0..n_layers]``: ``xs[0]`` is the prior (input) and
    ``xs[-1]`` the observation (target). ``preds[l]`` is the prediction of node
    ``l`` from node ``l - 1`` and ``errs[l] = xs[l] - preds[l]``.
    """

    def __init__(self, network, dt, device="cpu"):
        """
        Args:
            network: ``nn.Sequential`` whose elements are the individual layers.
            dt: step size of the activity (inference) updates.
            device: device the network and activities live on.
        """
        self.network = network.to(device)
        self.n_layers = len(self.network)
        self.n_nodes = self.n_layers + 1
        self.dt = dt
        self.n_params = sum(
            p.numel() for p in network.parameters() if p.requires_grad
        )
        self.device = device

    def reset(self):
        """Zero the parameter gradients and clear predictions, errors and activities."""
        self.zero_grad()
        self.preds = [None] * self.n_nodes
        self.errs = [None] * self.n_nodes
        self.xs = [None] * self.n_nodes

    def reset_xs(self, prior, init_std):
        """Re-initialise all non-observation activities with Gaussian noise.

        A forward pass from ``prior`` is run first, only to obtain the shape of
        each node; nodes ``0..n_layers - 1`` are then overwritten with samples
        from ``N(0, init_std**2)``.
        """
        self.set_prior(prior)
        self.propagate_xs()
        for l in range(self.n_layers):
            self.xs[l] = torch.empty(self.xs[l].shape).normal_(
                mean=0,
                std=init_std
            ).to(self.device)

    def set_obs(self, obs):
        """Clamp the last node to a copy of ``obs``."""
        self.xs[-1] = obs.clone()

    def set_prior(self, prior):
        """Clamp the first node to a copy of ``prior``."""
        self.xs[0] = prior.clone()

    def forward(self, x):
        """Plain feedforward pass of ``x`` through the network."""
        return self.network(x)

    def propagate_xs(self):
        """Initialise the hidden nodes with a forward pass from ``xs[0]``."""
        for l in range(1, self.n_layers):
            self.xs[l] = self.network[l - 1](self.xs[l - 1])

    def _predict_output(self):
        """Compute the prediction and prediction error of the last node."""
        self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
        self.errs[-1] = self.xs[-1] - self.preds[-1]

    def _relax_hidden_nodes(self):
        """Do one activity update of every hidden node, from last to first.

        For node ``l`` the update is ``dt * (J_l^T errs[l + 1] - errs[l])``, where
        the vector-Jacobian product is taken through ``network[l]`` at ``xs[l]``.
        ``errs[l + 1]`` is the error computed before node ``l + 1`` was updated
        in this sweep.
        """
        for l in reversed(range(1, self.n_layers)):
            self.preds[l] = self.network[l - 1](self.xs[l - 1])
            self.errs[l] = self.xs[l] - self.preds[l]
            _, epsdfdx = torch.autograd.functional.vjp(
                self.network[l],
                self.xs[l],
                self.errs[l + 1]
            )
            with torch.no_grad():
                dx = epsdfdx - self.errs[l]
                self.xs[l] = self.xs[l] + self.dt * dx

    def infer_train(
            self,
            obs,
            prior,
            n_iters,
            record_grad_norms=False,
        ):
        """Run inference with clamped prior and observation, then set weight gradients.

        The hidden activities are initialised with a forward pass and relaxed
        for up to ``n_iters`` iterations. Whenever the total energy fails to
        decrease (from the second iteration on) and ``self.dt > 0.025``, the
        step size is halved; once it drops to 0.025 or below it is restored to
        its value at the start of the call and inference stops early. Note that
        a halved step size that stays above 0.025 is kept on ``self.dt`` after
        the call returns.

        Args:
            obs: target clamped to the last node.
            prior: input clamped to the first node.
            n_iters: maximum number of inference iterations.
            record_grad_norms: if True, also record the norm of each parameter
                gradient at the final inference iteration.

        Returns:
            The list of total energies per iteration, or the tuple
            ``(energies, grad_norms_iters)`` when ``record_grad_norms`` is True.
            ``grad_norms_iters`` holds ``2 * n_layers`` arrays of length
            ``n_iters``; only the entry of the last executed iteration is
            filled in, in parameter order.
        """
        self.reset()
        self.set_prior(prior)
        self.propagate_xs()
        self.set_obs(obs)

        if record_grad_norms:
            grad_norms_iters = [np.zeros(n_iters) for _ in range(len(self.network)*2)]

        dt = self.dt
        tot_energy_iters = []
        for t in range(n_iters):
            self.network.zero_grad()
            self._predict_output()
            self._relax_hidden_nodes()

            tot_energy_iters.append(self.compute_tot_energy())
            if t > 0 and tot_energy_iters[t] >= tot_energy_iters[t-1] and self.dt > 0.025:
                self.dt /= 2
                if self.dt <= 0.025:
                    self.dt = dt
                    break

            # keep the autograd graph of the final iteration for set_grads
            if (t+1) != n_iters:
                self.clear_grads()

        if record_grad_norms:
            self.set_grads(
                grad_norms_iters=grad_norms_iters,
                t=t
            )
            return tot_energy_iters, grad_norms_iters

        self.set_grads()
        return tot_energy_iters

    def infer_test(
            self,
            obs,
            prior,
            n_iters,
            update_prior=True,
            update_obs=False,
            init_std=0.05,
        ):
        """Run inference from randomly initialised activities.

        Args:
            obs: value the last node is set to.
            prior: value of the first node; only kept when ``update_prior`` is
                False, otherwise the first node starts from noise.
            n_iters: number of inference iterations.
            update_prior: if True, the first node is also relaxed.
            update_obs: if True, the last node is also relaxed.
            init_std: std of the Gaussian activity initialisation.

        Returns:
            ``xs[0]`` if ``update_prior``, else ``xs[-1]`` if ``update_obs``,
            else ``None``.
        """
        self.reset()
        self.reset_xs(prior, init_std)
        if not update_prior:
            self.set_prior(prior)
        self.set_obs(obs)

        for t in range(n_iters):
            self.network.zero_grad()
            self._predict_output()

            if update_obs:
                with torch.no_grad():
                    self.xs[-1] = self.xs[-1] + self.dt * (- self.errs[-1])

            self._relax_hidden_nodes()

            if update_prior:
                _, epsdfdx = torch.autograd.functional.vjp(
                    self.network[0],
                    self.xs[0],
                    self.errs[1]
                )
                with torch.no_grad():
                    self.xs[0] = self.xs[0] + self.dt * epsdfdx

            if (t+1) != n_iters:
                self.clear_grads()

        if update_prior:
            return self.xs[0]
        elif update_obs:
            return self.xs[-1]

    def set_grads(self, grad_norms_iters=None, t=None):
        """Write the predictive coding weight gradients into ``param.grad``.

        For each layer ``l`` the gradient is obtained by backpropagating
        ``-errs[l + 1]`` through ``preds[l + 1]`` to the layer's parameters.
        A parameter that does not contribute to the prediction gets
        ``grad = None``.

        Args:
            grad_norms_iters: optional list of per-parameter arrays; when given,
                the norm of each gradient is stored at index ``t`` of the
                array belonging to that parameter.
            t: iteration index to record the norms at.
        """
        n = 0
        for l in range(self.n_layers):
            for param in self.network[l].parameters():
                dparam = torch.autograd.grad(
                    self.preds[l + 1],
                    param,
                    - self.errs[l + 1],
                    allow_unused=True,
                    retain_graph=True
                )[0]
                # allow_unused=True yields None for parameters outside the graph
                param.grad = None if dparam is None else dparam.clone()
                if grad_norms_iters is not None:
                    if dparam is not None:
                        # matrix (Frobenius) norm for weight matrices,
                        # vector 2-norm for everything else (e.g. biases)
                        if dparam.dim() == 2:
                            dparam_norm = torch.linalg.matrix_norm(dparam)
                        else:
                            dparam_norm = torch.linalg.vector_norm(dparam)
                        grad_norms_iters[n][t] = dparam_norm.item()
                    n += 1

    def zero_grad(self):
        """Zero the gradients of all network parameters."""
        self.network.zero_grad()

    def clear_grads(self):
        """Detach predictions, errors and activities from the autograd graph."""
        with torch.no_grad():
            for l in range(1, self.n_nodes):
                self.preds[l] = self.preds[l].clone()
                self.errs[l] = self.errs[l].clone()
                self.xs[l] = self.xs[l].clone()

    def save_weights(self, path):
        """Save the network's state dict to ``path``."""
        torch.save(self.network.state_dict(), path)

    def load_weights(self, path):
        """Load a state dict from ``path``, mapping tensors to this model's device."""
        self.network.load_state_dict(torch.load(path, map_location=self.device))

    def get_weights(self):
        """Return the network's weight matrices as NumPy arrays.

        Parameters with at most one dimension (e.g. biases) are skipped.
        """
        weights = []
        for param in self.network.parameters():
            if len(param.shape) > 1:
                weights.append(param.cpu().detach().numpy())
        return weights

    def compute_tot_energy(self):
        """Return the sum of squared prediction errors over all nodes as a float.

        Nodes without an error (``None``) are skipped; with no errors at all
        the energy is 0.0.
        """
        energy = 0.
        for err in self.errs:
            if err is not None:
                energy += (err**2).sum()
        return float(energy)

    def parameters(self):
        """Return an iterator over the network's parameters."""
        return self.network.parameters()

    def __str__(self):
        return f"PCN(\n{self.network}\n)"


In [None]:
#@title Weight manipulation
# Adapted from https://github.com/tomgoldstein/loss-landscape/blob/master/net_plotter.py


def get_weights(net):
    """Collect the underlying data tensor of every parameter of ``net`` in a list."""
    tensors = []
    for param in net.parameters():
        tensors.append(param.data)
    return tensors


def get_random_weights(weights, device='cpu'):
    """
        Draw a random direction: one standard-normal tensor per entry of
        `weights`, each with the same shape as its weight and moved to `device`.
    """
    direction = []
    for w in weights:
        direction.append(torch.randn(w.size()).to(device))
    return direction


def normalize_direction(direction, weights, norm='filter'):
    """
        Rescale `direction` in place relative to `weights` for a single layer.

        Args:
          direction: tensor holding the random direction of one layer
          weights: tensor holding the model weights of the same layer
          norm: normalization method:
                'filter'  - each filter (first-axis slice) of the direction gets
                            the norm of the matching filter in `weights`
                'layer'   - the whole direction gets the norm of `weights`
                'weight'  - each entry is multiplied by the matching weight
                'dfilter' - each filter of the direction gets unit norm
                'dlayer'  - the whole direction gets unit norm
                Any other value leaves `direction` unchanged.
    """
    eps = 1e-10  # guards the per-filter divisions against zero norms

    if norm == 'filter':
        for d_filter, w_filter in zip(direction, weights):
            d_filter.mul_(w_filter.norm()/(d_filter.norm() + eps))
        return

    if norm == 'layer':
        direction.mul_(weights.norm()/direction.norm())
        return

    if norm == 'weight':
        direction.mul_(weights)
        return

    if norm == 'dfilter':
        for d_filter in direction:
            d_filter.div_(d_filter.norm() + eps)
        return

    if norm == 'dlayer':
        direction.div_(direction.norm())


def normalize_directions_for_weights(direction, weights, norm='filter', ignore='biasbn'):
    """
        Normalize, in place, every entry of `direction` against the matching
        entry of `weights`.

        Entries with more than one dimension are rescaled with
        `normalize_direction`. Entries with at most one dimension are zeroed
        when `ignore == 'biasbn'` and otherwise overwritten with the weights.
    """
    assert(len(direction) == len(weights))
    for d, w in zip(direction, weights):
        if d.dim() > 1:
            normalize_direction(d, w, norm)
        elif ignore == 'biasbn':
            d.fill_(0) # ignore directions for weights with 1 dimension
        else:
            d.copy_(w) # keep directions for weights/bias that are only 1 per node


def create_random_weight_direction(net, device='cpu', ignore='biasbn', norm='filter'):
    """
        Sample a random direction in weight space and normalize it against
        the weights of `net`.
        Args:
          net: the model whose parameters define the shapes and scales
          device: device the random tensors are moved to
          ignore: 'biasbn' zeroes the direction for parameters with at most
                  one dimension; any other value copies those parameters
          norm: direction normalization method, one of
                'filter' | 'layer' | 'weight' | 'dlayer' | 'dfilter'
        Returns:
          a list with one direction tensor per parameter of `net`.
    """
    net_params = get_weights(net)
    random_direction = get_random_weights(net_params, device)
    # rescales the random tensors in place
    normalize_directions_for_weights(random_direction, net_params, norm, ignore)
    return random_direction


In [None]:
#@title Plotting


@gif.frame
def plot_hessian_matrix(hessian_matrix, save_path, log=False, title=None):
    """Plot a Hessian matrix as a heatmap, save it to ``save_path`` and return the figure.

    Args:
        hessian_matrix: square 2-D array to display.
        save_path: file the figure is written to.
        log: if True use a logarithmic colour normalisation, otherwise a
            linear scale clipped to [-1, 1].
        title: optional figure title.
    """
    fig, ax = plt.subplots()
    scale_kwargs = dict(norm=LogNorm()) if log else dict(vmin=-1, vmax=1)
    heatmap = ax.imshow(X=hessian_matrix, cmap="bwr", **scale_kwargs)

    cbar = fig.colorbar(heatmap, ax=ax, location="right", ticks=[-1, 0, 1])
    cbar.ax.tick_params(labelsize=25)

    n_rows = len(hessian_matrix)
    if n_rows <= 10:
        # small matrices: label every row/column, counting from 1
        ticks = np.arange(n_rows, dtype=int)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(ticks+1)
        ax.set_yticklabels(ticks+1)
    else:
        # large matrices: tick labels would be unreadable, hide both axes
        for axis in (ax.axes.get_xaxis(), ax.axes.get_yaxis()):
            axis.set_visible(False)

    if title is not None:
        plt.title(title, fontsize=20)
    fig.savefig(save_path)
    return fig


@gif.frame
def plot_hessian_eigenvalues(eigenvalues, save_path, title=None):
    """Plot a histogram of Hessian eigenvalues on a log y-axis, save and return it.

    Args:
        eigenvalues: eigenvalues to histogram (bin width 0.2, normalised to
            probabilities).
        save_path: file the figure is written to.
        title: optional figure title; an empty title is used when omitted.
    """
    histogram = go.Histogram(
        x=eigenvalues,
        histnorm="probability",
        xbins=dict(size=0.2),
        marker_color="#FF7F0E"
    )
    title_layout = dict(
        text="" if title is None else title,
        y=0.75,
        x=0.5,
        xanchor="center",
        yanchor="top"
    )
    y_axis = dict(
        title="Density (log)",
        type="log",
        exponentformat="power",
        dtick=1
    )

    fig = go.Figure(data=histogram)
    fig.update_layout(
        height=325,
        width=500,
        title=title_layout,
        xaxis=dict(title="Hessian eigenvalue"),
        yaxis=y_axis,
        font=dict(size=16)
    )
    fig.write_image(save_path)
    return fig


def plot_loss_and_energy_hessian_eigenvals(hessian_eigenvals: List, n_hidden: int, dataset: str, save_path: str) -> None:
    """Overlay histograms of the loss and energy Hessian eigenvalues and save the figure.

    Args:
        hessian_eigenvals: up to three eigenvalue collections, in the order
            loss, energy (numeric), energy (theory).
        n_hidden: number of hidden layers (an integer; it is compared with 1
            to pick the number of bins).
        dataset: dataset identifier; "toy_gaussian" always uses 10 bins.
        save_path: file the figure is written to.
    """
    fig = go.Figure()
    names = ["loss", "energy (numeric)", "energy (theory)"]
    # the theory histogram is drawn as a transparent bar with a black outline
    colors = ["#EF553B", "#636EFA", "rgba(0,0,0,0)"]

    # one-hidden-layer networks on the non-toy datasets get dedicated binning
    use_custom_bins = n_hidden == 1 and dataset != "toy_gaussian"
    n_loss_bins = 5 if use_custom_bins else 10
    n_energy_bins = 30 if use_custom_bins else 10
    for eigenval, name, color in zip(hessian_eigenvals, names, colors):
        fig.add_trace(
            go.Histogram(
                x=eigenval,
                histnorm="probability",
                nbinsx=n_loss_bins if name == "loss" else n_energy_bins,
                name=name,
                marker=dict(
                    color=color,
                    line=dict(
                        color="black",
                        width=2 if "theory" in name else 0
                    )
                ),
            )
        )

    fig.update_layout(
        barmode="overlay",
        height=360,
        width=525,
        xaxis=dict(title="Hessian eigenvalue"),
        yaxis=dict(
            title="Density (log)",
            type="log",
            exponentformat="power",
            dtick=1,
        ),
        font=dict(size=18),
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            font=dict(size=16)
        )
    )
    fig.update_traces(opacity=0.75)
    fig.write_image(save_path)


In [None]:
#@title Hessian utils


def compute_hessian(f):
    """Return a function computing the Hessian of ``f`` w.r.t. its first argument.

    The Hessian is built as the forward-mode Jacobian of the reverse-mode
    Jacobian of ``f``.
    """
    gradient_fn = jacrev(f)
    return jacfwd(gradient_fn)


def mse_loss_fun(Ws, X, Y):
    """Mean squared error loss of a deep linear network of any depth.

    The network output is ``Ws[-1] @ ... @ Ws[0] @ X`` and the loss is the mean
    over all entries of ``0.5 * (Y - output) ** 2``.

    Args:
        Ws: sequence of weight matrices, ordered from the first to the last layer.
        X: input matrix with samples along the last axis.
        Y: target matrix with the same shape as the network output.

    Returns:
        The scalar loss.

    Raises:
        ValueError: if ``Ws`` is empty.
    """
    if len(Ws) == 0:
        raise ValueError("Ws must contain at least one weight matrix")

    # Multiply the weights left to right, i.e. ((W_L @ W_{L-1}) @ ...) @ W_0,
    # which is the evaluation order of the chained `@` expression.
    end_to_end = Ws[-1]
    for W in reversed(Ws[:-1]):
        end_to_end = end_to_end @ W

    return ( 0.5*(Y - end_to_end@X)**2 ).mean()


def energy_fun(Ws, Xs, n_iters, dt):
    """Predictive coding energy of a deep linear network after inference.

    With ``H = len(Ws) - 1`` hidden layers, ``Xs[0]`` is the clamped input,
    ``Xs[H + 1]`` the clamped target and ``Xs[1..H]`` the initial hidden
    activities. Prediction errors are ``e_l = Z_l - Ws[l - 1] @ Z_{l - 1}``
    (with ``Z_0 = Xs[0]`` and ``Z_{H + 1} = Xs[H + 1]``). Each inference
    iteration updates all hidden activities simultaneously with

        Z_l <- Z_l - dt * (e_l - Ws[l].T @ e_{l + 1})

    using the errors from before the update. The returned energy is the sum
    of ``0.5 * e_l ** 2`` over all layers and entries, divided by
    ``BATCH_SIZE``.

    Works for any number of hidden layers. Activities are updated out of
    place, so the arrays in ``Xs`` are never modified.

    Args:
        Ws: sequence of weight matrices, first to last layer.
        Xs: sequence of ``len(Ws) + 1`` activity matrices (input, hidden
            activities, target).
        n_iters: number of inference iterations.
        dt: inference step size.

    Returns:
        The scalar energy.

    Raises:
        ValueError: if ``Ws`` is empty.
    """
    if len(Ws) == 0:
        raise ValueError("Ws must contain at least one weight matrix")

    n_hidden = len(Ws) - 1
    x_in, y_out = Xs[0], Xs[n_hidden + 1]

    def prediction_errors(hidden):
        # errs[l] is the error at node l + 1 (predicted from node l)
        nodes = [x_in, *hidden, y_out]
        return [nodes[l + 1] - Ws[l]@nodes[l] for l in range(n_hidden + 1)]

    # initialisation
    Zs = [Xs[l] for l in range(1, n_hidden + 1)]
    errs = prediction_errors(Zs)

    # iterative inference
    for _ in range(n_iters):
        dZs = [
            errs[l] - (errs[l + 1].T@Ws[l + 1]).T
            for l in range(n_hidden)
        ]
        Zs = [Z + (-dZ) * dt for Z, dZ in zip(Zs, dZs)]
        errs = prediction_errors(Zs)

    energy = (0.5*errs[0]**2).sum()
    for err in errs[1:]:
        energy = energy + (0.5*err**2).sum()

    return energy / (BATCH_SIZE)


def reshape_to_hessian_matrix(hessian, weights):
    """Assemble a nested, per-layer Hessian into one dense square matrix.

    ``hessian[i][j]`` holds the second derivatives w.r.t. ``weights[i]`` and
    ``weights[j]``. Each block is reshaped to
    ``(weights[i].size, weights[j].size)`` and placed at block position
    ``(i, j)`` of the result, with layers ordered as in ``weights``. Works
    for any number of weight matrices.

    Args:
        hessian: nested sequence such that ``hessian[i][j]`` is an array with
            ``weights[i].size * weights[j].size`` entries.
        weights: sequence of weight arrays.

    Returns:
        NumPy array of shape ``(n_params, n_params)``, where ``n_params`` is
        the total number of weight entries.
    """
    sizes = [int(np.prod(w.shape)) for w in weights]

    # offsets[i] is the first row/column of layer i's block
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)

    n_params = offsets[-1]
    hessian_matrix = np.zeros((n_params, n_params))
    for i, n_rows in enumerate(sizes):
        rows = slice(offsets[i], offsets[i + 1])
        for j, n_cols in enumerate(sizes):
            cols = slice(offsets[j], offsets[j + 1])
            hessian_matrix[rows, cols] = hessian[i][j].reshape(n_rows, n_cols)

    return hessian_matrix


def initialise_pc_activities(Ws, input, target):
    """Initialise predictive coding activities with a feedforward pass.

    Returns ``[input, Ws[0] @ input, Ws[1] @ Ws[0] @ input, ..., target]``:
    the input, one activity per hidden layer (``len(Ws) - 1`` of them) and the
    target in place of the network output. Works for any number of layers.

    Args:
        Ws: sequence of weight matrices, first to last layer.
        input: input matrix with samples along the last axis.
        target: target matrix, used as the last activity.

    Returns:
        List of ``len(Ws) + 1`` activity matrices.
    """
    activities = [input]
    for l in range(len(Ws) - 1):
        # Multiply the weights left to right, (W_l @ W_{l-1}) @ ... @ W_0,
        # i.e. the evaluation order of the chained `@` expression.
        weight_product = Ws[l]
        for k in range(l - 1, -1, -1):
            weight_product = weight_product @ Ws[k]
        activities.append(weight_product @ input)
    activities.append(target)
    return activities


def compute_and_plot_hessian_metrics(model, input, target, init_type, save_dir):
    """Compute, plot and animate the energy Hessian over inference iterations.

    For every iteration ``t`` in ``PLOT_ITERS`` that does not exceed
    ``N_ITERS``, the Hessian of the energy after ``t`` inference steps is
    computed w.r.t. the model weights, plotted (linear and log heatmaps plus
    an eigenvalue histogram) and saved under ``save_dir``. The frames are then
    combined into three GIFs.

    Args:
        model: model exposing ``get_weights()``; for ``init_type == "origin"``
            it must also satisfy the requirements of
            ``compute_theoretical_energy_eigenvals``.
        input: input batch.
        target: target batch.
        init_type: weight initialisation name; theoretical eigenvalues are
            only computed for "origin".
        save_dir: directory where figures and GIFs are written.

    Returns:
        Tuple ``(eigenvals, theory_eigenvals, eigenvecs)`` for the last
        plotted iteration; ``theory_eigenvals`` is ``None`` unless
        ``init_type == "origin"``.

    Raises:
        ValueError: if no entry of ``PLOT_ITERS`` lies in ``[0, N_ITERS]``.
    """
    weights = model.get_weights()
    activities = initialise_pc_activities(Ws=weights, input=input, target=target)

    hessian_matrix_frames, log_hessian_matrix_frames, hessian_eigenvals_frames = [], [], []
    for t in range(N_ITERS+1):
        if t not in PLOT_ITERS:
            continue

        energy_hessian = compute_hessian(energy_fun)(
            weights,
            activities,
            n_iters=t,
            dt=DT
        )
        hessian_matrix = reshape_to_hessian_matrix(
            hessian=energy_hessian,
            weights=weights
        )
        fig = plot_hessian_matrix(
            hessian_matrix=hessian_matrix,
            title=f"Inference iteration = {t}",
            save_path=f"{save_dir}/hessian_matrix_iter_{t}.pdf"
        )
        hessian_matrix_frames.append(fig)
        fig = plot_hessian_matrix(
            hessian_matrix=hessian_matrix,
            log=True,
            title=f"Inference iteration = {t}",
            save_path=f"{save_dir}/log_hessian_matrix_iter_{t}.pdf"
        )
        log_hessian_matrix_frames.append(fig)

        eigenvals, eigenvecs = eigh(hessian_matrix)
        fig = plot_hessian_eigenvalues(
            eigenvalues=eigenvals,
            title=f"Inference iteration = {t}",
            save_path=f"{save_dir}/hessian_eigenvalues_iter_{t}.pdf"
        )
        hessian_eigenvals_frames.append(fig)

    if not hessian_eigenvals_frames:
        raise ValueError(
            "No Hessian was computed: PLOT_ITERS has no iteration between 0 and N_ITERS"
        )

    # eigh returns the eigenvalues in ascending order, so the largest is last
    print(f"\tmax Hessian eigenvalue: {eigenvals[-1]}")
    print(f"\tmin Hessian eigenvalue: {eigenvals[0]}\n")

    gif.save(
        frames=hessian_matrix_frames,
        path=f"{save_dir}/hessian_matrix_infer_dynamics.gif",
        duration=1,
        unit="s"
    )
    gif.save(
        frames=log_hessian_matrix_frames,
        path=f"{save_dir}/log_hessian_matrix_infer_dynamics.gif",
        duration=1,
        unit="s"
    )
    gif.save(
        frames=hessian_eigenvals_frames,
        path=f"{save_dir}/hessian_eigenvalues_infer_dynamics.gif",
        duration=1,
        unit="s"
    )
    if init_type == "origin":
        theory_eigenvals = compute_theoretical_energy_eigenvals(
            model=model,
            X=input,
            Y=target
        )
    else:
        theory_eigenvals = None
    return eigenvals, theory_eigenvals, np.array(eigenvecs)


def compute_theoretical_energy_eigenvals(model, X, Y):
    """Return the eigenvalues of a Hessian assembled from second moments of the data.

    A zero matrix of size ``model.n_params`` is filled block-wise:

    * the trailing ``(output_dim * width)`` square block is set to
      ``-kron(Y Y^T / BATCH_SIZE, I_width)``;
    * when the network has exactly one hidden layer, the two off-diagonal
      blocks are set to ``-kron(X Y^T / BATCH_SIZE, I_width)`` and its transpose.

    All remaining entries stay zero.

    Args:
        model: object exposing ``n_params`` and ``network``; the hidden width is
            read from ``model.network[0][0].out_features``.
        X: input matrix; its first axis is used as the input dimension.
        Y: target matrix; its first axis is used as the output dimension.

    Returns:
        The eigenvalues produced by ``eigh`` for the assembled matrix.
    """
    # NOTE(review): the moments are normalised by the global BATCH_SIZE rather than
    # by the number of columns of X / Y - confirm the sample always has that size.
    hidden_width = model.network[0][0].out_features
    eye_width = np.identity(hidden_width)
    hessian = np.zeros((model.n_params, model.n_params))

    # trailing diagonal block built from the output second moment
    out_block = Y.shape[0] * hidden_width
    hessian[-out_block:, -out_block:] = - np.kron((Y @ Y.T) / BATCH_SIZE, eye_width)

    # the cross blocks are only filled in for a single hidden layer
    if len(model.network) - 1 == 1:
        in_block = X.shape[0] * hidden_width
        cross_block = - np.kron((X @ Y.T) / BATCH_SIZE, eye_width)
        hessian[:in_block, in_block:] = cross_block
        hessian[in_block:, :in_block] = cross_block.T

    return eigh(hessian)[0]


In [None]:
#@title Landscape plotting


@gif.frame
def plot_objective_surface(
        objective_mesh: np.ndarray,
        weights: Tuple[np.ndarray, np.ndarray],
        objective_name: str,
        save_path: str,
        inference_iter: Optional[int] = None,
        tot_objective_max: Optional[float] = None,
        show_background: bool = True
    ) -> go.Figure:
    """Draw a 3-D surface of an objective over a 2-D grid and write it to disk.

    Args:
        objective_mesh: 2-D array of objective values used as the surface height.
        weights: pair of 1-D coordinate arrays; ``weights[0]`` is used for the
            x axis and ``weights[1]`` for the y axis.
        objective_name: ``"loss"`` labels the colorbar with a calligraphic L,
            anything else with a calligraphic F.
        save_path: file the figure is written to via ``fig.write_image``.
        inference_iter: if given (and ``show_background`` is true), shown in the
            figure title.
        tot_objective_max: if given, the z axis spans ``[0, 2 * tot_objective_max]``;
            otherwise twice the maximum of ``objective_mesh`` is used.
        show_background: when true, axis ticks, the z-projection of the contours
            and a title are shown; when false, backgrounds and tick labels are hidden.

    Returns:
        The plotly figure (what callers receive is whatever the ``gif.frame``
        decorator makes of it).
    """
    objective_notation = "L" if objective_name == "loss" else "F"
    objective_max, objective_min = objective_mesh.max(), objective_mesh.min()
    min_max_diff = objective_max - objective_min
    # only annotate the colorbar ends when the surface is not (almost) flat
    ticktext = ["Low", "High"] if min_max_diff > 0.1 else ["", ""]
    fig = go.Figure(
        data=go.Surface(
            z=objective_mesh,
            x=weights[0],
            y=weights[1],
            colorscale=COLORSCALE,
            colorbar=dict(
                # raw f-string: "\L" and "\m" are LaTeX, not Python escapes.
                # title.side replaces the deprecated `titleside` attribute.
                title=dict(
                    text=rf"$\LARGE{{\mathcal{{{objective_notation}}}}}$",
                    side="right"
                ),
                x=0.85,
                y=0.57,
                len=0.3,
                tickfont=dict(size=16),
                tickvals=[objective_min, objective_max],
                ticktext=ticktext
            )
        )
    )
    fig.update_traces(
        contours_z=dict(
            show=True,
            usecolormap=True,
            highlightcolor="limegreen",
            project_z=bool(show_background),
        )
    )
    z_upper = tot_objective_max*2 if tot_objective_max is not None else objective_max*2
    fig.update_layout(
        scene=dict(zaxis=dict(
            title="",
            range=[0, z_upper],
            showticklabels=False
        )),
        font=dict(size=16),
        height=600,
        width=700,
        margin=dict(r=30, b=10, l=0, t=40),
        scene_aspectmode="cube"
    )
    if show_background:
        title = f"Inference iteration: {inference_iter}" if inference_iter is not None else ""
        fig.update_layout(
            scene=dict(
                xaxis=dict(title="", nticks=3, autorange="reversed"),
                yaxis=dict(title="", nticks=3, autorange="reversed"),
            ),
            scene_camera=dict(
                center=dict(x=0.05, y=0.2, z=0),
                eye=dict(x=1.4, y=1.4, z=1.25)
            ),
            title=dict(
                text=title,
                y=0.95,
                x=0.42,
                xanchor="center",
                yanchor="top"
            )
        )
    else:
        fig.update_layout(
            scene=dict(
                xaxis=dict(
                    title="",
                    autorange="reversed",
                    showticklabels=False,
                    showbackground=False,
                ),
                yaxis=dict(
                    title="",
                    autorange="reversed",
                    showticklabels=False,
                    showbackground=False,
                ),
                zaxis=dict(showbackground=False)
            ),
            scene_camera=dict(
                center=dict(x=0.2, y=0.15, z=0),
                eye=dict(x=1.3, y=1.3, z=1)
            ),
        )

    fig.write_image(save_path)
    return fig


In [None]:
#@title Projection visualisations


def visualise_2D_loss_projections(
        model,
        input,
        target,
        domain,
        device,
        save_dir,
        hessian_eigenvecs = None
    ):
    """Plot the MSE loss on a 2-D slice of weight space around the current weights.

    The slice is spanned by two random directions when ``hessian_eigenvecs`` is
    None, and otherwise by the first and last eigenvectors of the supplied
    matrix. Each grid point is evaluated at ``theta_0 + a * d_0 + b * d_1``,
    where ``theta_0`` is a snapshot of the weight matrices taken on entry, so
    perturbations never accumulate from one grid point to the next. The
    snapshot is restored before returning.

    Args:
        model: network exposing ``parameters()``, ``forward`` and ``network``.
        input: array whose transpose is fed to ``model.forward``.
        target: array whose transpose is compared with the predictions.
        domain: half-width of the grid; both coefficients span ``[-domain, domain]``
            with ``SAMPLING_RESOLUTION`` points.
        device: torch device the tensors are moved to.
        save_dir: directory the surface plot is written into.
        hessian_eigenvecs: optional square matrix in the ``eigh`` layout, i.e.
            eigenvectors stored as COLUMNS ordered by ascending eigenvalue.
    """
    input = torch.Tensor(input.T).to(device)
    target = torch.Tensor(target.T).to(device)

    # Only weight matrices are perturbed; 1-D parameters are left untouched.
    weight_params = [p for p in model.parameters() if len(p.shape) > 1]
    base_weights = [p.data.clone() for p in weight_params]

    n_directions = 2
    if hessian_eigenvecs is None:
        random_directions = []
        for _ in range(n_directions):
            random_direction = create_random_weight_direction(
                net=model.network,
                device=device
            )
            random_direction = [w for w in random_direction if len(w.shape) > 1]
            random_directions.append(random_direction)
    else:
        hessian_eigenvecs = torch.Tensor(np.array(hessian_eigenvecs)).to(device)
        base_vec = torch.cat([torch.flatten(w) for w in base_weights]).to(device)
        # eigh returns eigenvectors as columns: v[:, k] belongs to eigenvalue w[k]
        first_eigenvec = hessian_eigenvecs[:, 0]
        last_eigenvec = hessian_eigenvecs[:, -1]

    scaling_factors = [
        np.linspace(
            -domain, domain, SAMPLING_RESOLUTION
        ) for _ in range(n_directions)
    ]

    loss_fn = nn.MSELoss()
    loss_mesh = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION))
    try:
        for j, a in enumerate(scaling_factors[0]):
            for i, b in enumerate(scaling_factors[1]):
                a, b = float(a), float(b)

                if hessian_eigenvecs is None:
                    perturbed = zip(
                        weight_params, base_weights, random_directions[0], random_directions[1]
                    )
                    for p, w0, d0, d1 in perturbed:
                        p.data = w0 + (a * d0) + (b * d1)
                else:
                    param_vec = base_vec + (a * first_eigenvec) + (b * last_eigenvec)
                    # the vector only holds weight matrices, so only those are written back
                    vector_to_parameters(param_vec, weight_params)

                preds = model.forward(input)
                # rows index the second coefficient (y axis), columns the first (x axis)
                loss_mesh[i, j] = loss_fn(preds, target).item()
    finally:
        # leave the model exactly as it was handed in
        for p, w0 in zip(weight_params, base_weights):
            p.data = w0

    directions_type = "random" if hessian_eigenvecs is None else "hessian"
    plot_objective_surface(
        objective_mesh=loss_mesh,
        weights=scaling_factors,
        objective_name="loss",
        save_path=f"{save_dir}/{directions_type}_surface_{domain}.pdf"
    )


def visualise_2D_energy_projections(
        model,
        input,
        target,
        domain,
        device,
        save_dir,
        hessian_eigenvecs = None
    ):
    """Plot the energy over inference iterations on a 2-D slice of weight space.

    The slice is spanned by two random directions when ``hessian_eigenvecs`` is
    None, and otherwise by the first and last eigenvectors of the supplied
    matrix. Each grid point is evaluated at ``theta_0 + a * d_0 + b * d_1``,
    where ``theta_0`` is a snapshot of the weight matrices taken on entry, so
    neither earlier perturbations nor anything ``model.infer_train`` may do to
    the weights carries over to the next grid point. The snapshot is restored
    before plotting.

    Slice 0 of the mesh holds the MSE of the forward pass; slices 1..N_ITERS
    hold the values returned by ``model.infer_train`` (padded with the last
    value when fewer than ``N_ITERS`` are returned). One surface is written per
    iteration listed in ``PLOT_ITERS`` and the frames are combined into a GIF.

    Args:
        model: network exposing ``parameters()``, ``forward``, ``infer_train``
            and ``network``.
        input: array whose transpose is used as the network input / prior.
        target: array whose transpose is used as the target / observation.
        domain: half-width of the grid; both coefficients span ``[-domain, domain]``
            with ``SAMPLING_RESOLUTION`` points.
        device: torch device the tensors are moved to.
        save_dir: directory the plots and the GIF are written into.
        hessian_eigenvecs: optional square matrix in the ``eigh`` layout, i.e.
            eigenvectors stored as COLUMNS ordered by ascending eigenvalue.
    """
    input = torch.Tensor(input.T).to(device)
    target = torch.Tensor(target.T).to(device)

    # Only weight matrices are perturbed; 1-D parameters are left untouched.
    weight_params = [p for p in model.parameters() if len(p.shape) > 1]
    base_weights = [p.data.clone() for p in weight_params]

    n_directions = 2
    if hessian_eigenvecs is None:
        # random directions are only drawn when they are actually used
        random_directions = []
        for _ in range(n_directions):
            random_direction = create_random_weight_direction(
                net=model.network,
                device=device
            )
            random_direction = [w for w in random_direction if len(w.shape) > 1]
            random_directions.append(random_direction)
    else:
        hessian_eigenvecs = torch.Tensor(np.array(hessian_eigenvecs)).to(device)
        base_vec = torch.cat([torch.flatten(w) for w in base_weights]).to(device)
        # eigh returns eigenvectors as columns: v[:, k] belongs to eigenvalue w[k]
        first_eigenvec = hessian_eigenvecs[:, 0]
        last_eigenvec = hessian_eigenvecs[:, -1]

    scaling_factors = [
        np.linspace(
            -domain, domain, SAMPLING_RESOLUTION
        ) for _ in range(n_directions)
    ]

    loss_fn = nn.MSELoss()
    energy_mesh = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION, N_ITERS+1))
    try:
        for j, a in enumerate(scaling_factors[0]):
            for i, b in enumerate(scaling_factors[1]):
                a, b = float(a), float(b)

                if hessian_eigenvecs is None:
                    perturbed = zip(
                        weight_params, base_weights, random_directions[0], random_directions[1]
                    )
                    for p, w0, d0, d1 in perturbed:
                        p.data = w0 + (a * d0) + (b * d1)
                else:
                    param_vec = base_vec + (a * first_eigenvec) + (b * last_eigenvec)
                    # the vector only holds weight matrices, so only those are written back
                    vector_to_parameters(param_vec, weight_params)

                # iteration 0: objective of the plain forward pass
                preds = model.forward(input)
                energy_mesh[i, j, 0] = loss_fn(preds, target).item()

                tot_energy_iters = model.infer_train(
                    obs=target,
                    prior=input,
                    n_iters=N_ITERS
                )
                # pad with the final value when fewer than N_ITERS energies come back
                if len(tot_energy_iters) != N_ITERS:
                    n_missing_iters = N_ITERS - len(tot_energy_iters)
                    tot_energy_iters.extend([tot_energy_iters[-1]]*n_missing_iters)

                energy_mesh[i, j, 1:] = tot_energy_iters
    finally:
        # leave the model exactly as it was handed in
        for p, w0 in zip(weight_params, base_weights):
            p.data = w0

    energy_surface_frames = []
    directions_type = "random" if hessian_eigenvecs is None else "hessian"
    # shared z-axis scale so that frames are comparable across iterations
    energy_iters_max = energy_mesh.max()
    for t in range(N_ITERS+1):
        if t in PLOT_ITERS:
            fig = plot_objective_surface(
                objective_mesh=energy_mesh[:, :, t],
                weights=scaling_factors,
                objective_name="energy",
                inference_iter=t,
                tot_objective_max=energy_iters_max,
                save_path=f"{save_dir}/{directions_type}_surface_{domain}_iter_{t}.pdf"
            )
            energy_surface_frames.append(fig)

    gif.save(
        energy_surface_frames,
        f"{save_dir}/{directions_type}_surface_{domain}_dynamics.gif",
        duration=1,
        unit="s"
    )


## Scripts

In [None]:
#@title Loss Hessian Analysis


def analyse_loss_hessian(dataset, data_dim, n_hidden, width, init_type, seed, save_dir):
    """Compute, plot and return the eigenvalues of the MSE-loss Hessian.

    A linear fully connected network is built for one data sample, initialised
    according to ``init_type``, and the Hessian of ``mse_loss_fun`` with respect
    to its weights is evaluated. The Hessian matrix and its eigenvalues are
    plotted into ``save_dir``; for the toy Gaussian dataset with 3 hidden layers
    of width 4 the loss landscape is additionally projected onto Hessian
    eigenvectors for every entry of ``DOMAINS``.

    Args:
        dataset: dataset identifier passed to ``get_dataset_sample``.
        data_dim: data dimensionality passed to ``get_dataset_sample``.
        n_hidden: number of hidden layers.
        width: hidden layer width.
        init_type: weight initialisation scheme passed to ``init_weights``.
        seed: random seed.
        save_dir: directory for the plots (created if missing).

    Returns:
        The Hessian eigenvalues in ascending order. For ``"MNIST-1D"`` with more
        than two hidden layers they are shifted by ``1e-3``.
    """
    print("\tAnalysing loss Hessian at the origin...")
    set_seed(seed)
    device = get_device()
    os.makedirs(save_dir, exist_ok=True)

    input, target = get_dataset_sample(
        dataset_id=dataset,
        data_dim=data_dim,
        n_hidden=n_hidden
    )
    network = get_fc_network(
        input_dim=input.shape[0],
        n_hidden=n_hidden,
        width=width,
        act_fn="linear",
        output_dim=target.shape[0]
    )
    model = BPN(network=network).to(device)
    init_weights(model=model, init_type=init_type)

    # compute and plot Hessian matrix
    weights = model.get_weights()
    loss_hessian = compute_hessian(mse_loss_fun)(
        weights,
        input,
        target
    )
    hessian_matrix = reshape_to_hessian_matrix(
        hessian=loss_hessian,
        weights=weights
    )
    plot_hessian_matrix(
        hessian_matrix=hessian_matrix,
        save_path=f"{save_dir}/hessian_matrix.pdf"
    )
    hessian_eigenvals, hessian_eigenvecs = eigh(hessian_matrix)
    # eigh sorts eigenvalues in ascending order: the last one is the maximum
    print(f"\tmax Hessian eigenvalue: {hessian_eigenvals[-1]}")
    print(f"\tmin Hessian eigenvalue: {hessian_eigenvals[0]}")
    # NOTE(review): the 1e-3 shift is also applied to the returned eigenvalues;
    # presumably it only exists to help plotting - confirm before reusing them.
    hessian_eigenvals = hessian_eigenvals+1e-3 if dataset == "MNIST-1D" and n_hidden > 2 else hessian_eigenvals
    plot_hessian_eigenvalues(
        eigenvalues=hessian_eigenvals,
        save_path=f"{save_dir}/hessian_eigenvalues.pdf"
    )
    if dataset == "toy_gaussian" and n_hidden == 3 and width == 4:
        # plot landscape projected onto Hessian min and max eigenvectors
        for domain in DOMAINS:
            visualise_2D_loss_projections(
                model=model,
                input=input,
                target=target,
                domain=domain,
                device=device,
                save_dir=save_dir,
                hessian_eigenvecs=hessian_eigenvecs
            )

    return hessian_eigenvals


In [None]:
#@title Energy Hessian Analysis


def analyse_energy_hessian(dataset, data_dim, n_hidden, width, init_type, seed, save_dir):
    """Run the energy-Hessian analysis for one experiment configuration.

    Builds a linear fully connected network wrapped in a ``PCN`` for one data
    sample, initialises it according to ``init_type`` and delegates the Hessian
    computation and plotting to ``compute_and_plot_hessian_metrics``. For the
    toy Gaussian dataset with 3 hidden layers of width 4, the energy landscape
    is additionally projected onto Hessian eigenvectors for every entry of
    ``DOMAINS``.

    Args:
        dataset: dataset identifier passed to ``get_dataset_sample``.
        data_dim: data dimensionality passed to ``get_dataset_sample``.
        n_hidden: number of hidden layers.
        width: hidden layer width.
        init_type: weight initialisation scheme passed to ``init_weights``.
        seed: random seed.
        save_dir: directory for the plots (created if missing).

    Returns:
        Tuple ``(numeric_eigenvalues, theoretical_eigenvalues)`` as returned by
        ``compute_and_plot_hessian_metrics``.
    """
    print("\n\tAnalysing energy Hessian at the origin...")
    set_seed(seed)
    device = get_device()
    os.makedirs(save_dir, exist_ok=True)

    sample_x, sample_y = get_dataset_sample(
        dataset_id=dataset, data_dim=data_dim, n_hidden=n_hidden
    )
    pc_model = PCN(
        network=get_fc_network(
            input_dim=sample_x.shape[0],
            n_hidden=n_hidden,
            width=width,
            act_fn="linear",
            output_dim=sample_y.shape[0]
        ),
        dt=DT,
        device=device
    )
    init_weights(model=pc_model, init_type=init_type)

    # Hessian matrix, eigen-decomposition and all related plots
    numeric_eigenvals, theory_eigenvals, eigenvecs = compute_and_plot_hessian_metrics(
        model=pc_model,
        input=sample_x,
        target=sample_y,
        init_type=init_type,
        save_dir=save_dir
    )

    # landscape projections are only produced for one small toy configuration
    wants_projections = dataset == "toy_gaussian" and n_hidden == 3 and width == 4
    if wants_projections:
        for domain in DOMAINS:
            visualise_2D_energy_projections(
                model=pc_model,
                input=sample_x,
                target=sample_y,
                domain=domain,
                device=device,
                save_dir=save_dir,
                hessian_eigenvecs=eigenvecs
            )

    return numeric_eigenvals, theory_eigenvals


In [None]:
#@title Main script


def run_hessian_analysis():
    """Run the loss and energy Hessian analyses for every configured experiment.

    Iterates over ``DATASETS``, the (n_hidden, width) pairs in
    ``N_HIDDEN_WIDTHS``, the initialisations in ``INIT_TYPES`` and
    ``N_SEEDS`` seeds. For each combination the eigenvalue arrays are saved
    inside that experiment's own directory (so runs do not overwrite each
    other and the files end up in the results tree) and the combined
    eigenspectrum is plotted.
    """
    for dataset in DATASETS:
        for n_hidden, width in N_HIDDEN_WIDTHS[dataset]:
            data_dim = get_data_dim(dataset=dataset, n_hidden=n_hidden)
            for init_type in INIT_TYPES[n_hidden]:
                for seed in range(N_SEEDS):
                    experiment_dir = setup_experiment(
                        results_dir=RESULTS_DIR,
                        dataset=dataset,
                        data_dim=data_dim,
                        n_hidden=n_hidden,
                        width=width,
                        init_type=init_type,
                        seed=seed,
                    )
                    loss_hessian_eigenvals = analyse_loss_hessian(
                        dataset=dataset,
                        data_dim=data_dim,
                        n_hidden=n_hidden,
                        width=width,
                        init_type=init_type,
                        seed=seed,
                        save_dir=f"{experiment_dir}/bp"
                    )
                    energy_hessian_eigenvals, theory_energy_hessian_eigenvals = analyse_energy_hessian(
                        dataset=dataset,
                        data_dim=data_dim,
                        n_hidden=n_hidden,
                        width=width,
                        init_type=init_type,
                        seed=seed,
                        save_dir=f"{experiment_dir}/pc"
                    )
                    # experiment_dir exists at this point: the analyses above
                    # create its "bp" / "pc" sub-directories.
                    np.save(f"{experiment_dir}/loss_eigens.npy", loss_hessian_eigenvals)
                    np.save(f"{experiment_dir}/energy_eigens.npy", energy_hessian_eigenvals)
                    # no theoretical spectrum is returned away from the origin
                    if theory_energy_hessian_eigenvals is not None:
                        np.save(
                            f"{experiment_dir}/theory_energy_eigens.npy",
                            theory_energy_hessian_eigenvals
                        )
                    plot_loss_and_energy_hessian_eigenvals(
                        hessian_eigenvals=[
                            loss_hessian_eigenvals,
                            energy_hessian_eigenvals,
                            theory_energy_hessian_eigenvals
                        ],
                        n_hidden=n_hidden,
                        dataset=dataset,
                        save_path=f"{experiment_dir}/hessian_eigenspectrum.pdf"
                    )


## Run analysis

In [None]:
run_hessian_analysis()
!zip -r DLNs_hessian_results.zip results

## Download results

In [None]:
# Colab-only cell: mount Google Drive at /content/drive so that results can be copied to it.
import shutil
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%capture
!zip -r /content/linear_fashion_depth_10.zip /content/results

In [None]:
# Copy the results archive from the Colab filesystem into the mounted Drive folder.
colab_link = "/content/DLNs_hessian_results.zip"
gdrive_link = "/content/drive/MyDrive/"
# shutil.copy returns the destination path, which the notebook displays as the cell output
shutil.copy(colab_link, gdrive_link)

'/content/drive/MyDrive/DLNs_hessian_results.zip'