# Convergence Experiments on DNNs

This notebook analyses the convergence of stochastic gradient descent (SGD) dynamics in deep neural networks (DNNs) trained with backpropagation (BP) and predictive coding (PC) when the weights are initialised near the origin.



## Setup

In [1]:
#@title Installations


%%capture
!sudo apt install nvidia-utils-515
!pip install -U kaleido
!pip install gif==3.0.0


In [2]:
#@title Imports


import os
import random
import subprocess
import numpy as np
from typing import Tuple, List, Dict, Optional

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from torch.linalg import norm

from jax import jacfwd, jacrev
from jax.numpy.linalg import eigh

import gif
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.express as px
from plotly.colors import hex_to_rgb
from plotly.express.colors import sample_colorscale


In [3]:
#@title Config


# Experiment grid. NOTE(review): the loop that consumes these lists is outside
# this section; presumably one run is launched per combination -- confirm.
DATASETS = ["MNIST", "Fashion-MNIST", "CIFAR10"]
# Per-dataset architecture settings. Judging by the name, each entry is
# [number of hidden layers, hidden width] -- TODO confirm against the caller.
# The CIFAR10 width is None; get_conv_network() takes no width argument.
N_HIDDEN_WIDTHS = {
    #"toy_gaussian": [[3, 4]],
    "MNIST": [[4, 500]],
    "Fashion-MNIST": [[4, 500]],
    "CIFAR10": [[4, None]]
}
# Activation identifiers handled by get_fc_network() / get_conv_network().
ACT_FNS = ["linear", "tanh", "relu"]
INIT_TYPES = ["origin"]
# Optimiser identifiers; get_optimiser() understands "SGD" and "Adam".
OPTIMISERS = ["SGD"]
N_SEEDS = 1

# DATA_DIR is the default download location of the MNIST / FashionMNIST /
# CIFAR10 dataset classes below.
DATA_DIR = "data"
RESULTS_DIR = "results"

# Input / output sizes used by get_fc_network() for every dataset except
# "toy_gaussian".
FC_INPUT_DIM = 784
FC_OUTPUT_DIM = 10

# Toy Gaussian dataset (see make_gaussian_dataset()): sample mean, training
# sample standard deviation, and sample dimensionality. INPUT_DIM is also the
# input and output size of the fully connected network for "toy_gaussian".
DATA_MEAN, DATA_STD = 1., 0.1
INPUT_DIM = 3

# PC hyperparameters. NOTE(review): presumably passed to PCN(dt=...) and to
# PCN.infer_train(n_iters=...) by the training loop -- confirm.
N_ITERS = 50
DT = 0.1

# Optimisation hyperparameters. BATCH_SIZE is read by get_dataloaders(); the
# remaining values are consumed outside this section.
LR = 1e-3
BATCH_SIZE = 64
MAX_EPOCHS = 35
LOG_BATCH_EVERY = 100
LOSS_TOLERANCE = 0.001

# Landscape plotting settings (consumed outside this section).
DOMAINS = [2, 1, 5e-1, 1e-1, 5e-2]
SAMPLING_RESOLUTION = 30
COLORSCALE = "RdBu_r"
PLOT_ITERS = [0, 1, 2, 3, 4, 5, 10, 20, 50]


In [4]:
#@title Utils


def setup_experiment(
        results_dir,
        dataset,
        arch_type,
        n_hidden,
        width,
        act_fn,
        init_type,
        optimiser,
        lr
    ):
    """Print the experiment configuration and build its results path.

    The path nests one folder per setting below ``results_dir``. Only fully
    connected ("fc") runs get a ``width_*`` folder; convolutional ("conv")
    runs omit it. The directory itself is NOT created here.

    Args:
      results_dir: root folder for all results.
      dataset: dataset identifier.
      arch_type: "fc" or "conv".
      n_hidden: number of hidden layers.
      width: hidden width (only used in the path for "fc").
      act_fn: activation function identifier.
      init_type: initialisation identifier.
      optimiser: optimiser identifier.
      lr: learning rate.

    Returns:
      The experiment directory path as a string.

    Raises:
      ValueError: if ``arch_type`` is neither "fc" nor "conv" (previously this
        surfaced as an UnboundLocalError on the return statement).
    """
    if arch_type not in ("fc", "conv"):
        raise ValueError(
            "Invalid architecture type ID. Options are 'fc' and 'conv'"
        )

    print(
f"""
Starting experiment with configuration:

  Dataset: {dataset}
  Arch type: {arch_type}
  N hidden: {n_hidden}
  Width: {width}
  Act fn: {act_fn}
  Init type: {init_type}
  Optimiser: {optimiser}
  Learning rate: {lr}
"""
)
    path_parts = [results_dir, dataset, arch_type, f"n_hidden_{n_hidden}"]
    if arch_type == "fc":
        # the conv architectures have no tunable width, so only fc paths encode it
        path_parts.append(f"width_{width}")
    path_parts += [act_fn, f"{init_type}_init", optimiser, f"lr_{lr}"]
    return os.path.join(*path_parts)


def set_seed(seed):
    """Seed the torch (CPU and, when available, all CUDA devices), Python
    ``random`` and NumPy random number generators with the same ``seed``."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def get_parameter_scale(dataset, n_hidden, optimiser):
    """Return the standard deviation used to initialise weights near the origin.

    Args:
      dataset: dataset identifier; only "toy_gaussian" is special-cased.
      n_hidden: accepted for interface compatibility; currently unused.
      optimiser: "SGD" (scale 5e-3) or "Adam" (scale 1e-4).

    Returns:
      The parameter scale as a float.

    Raises:
      ValueError: if no scale is defined for the given combination (previously
        an UnboundLocalError).

    NOTE(review): as in the original statement order, the optimiser-specific
    scale takes precedence, so the "toy_gaussian" value (1e-1) is only returned
    for optimisers other than "SGD"/"Adam". Confirm whether the toy scale was
    meant to win instead.
    """
    if optimiser == "SGD":
        return 5e-3
    if optimiser == "Adam":
        return 1e-4
    # exact match: the former `dataset in "toy_gaussian"` was a substring test
    # and also accepted e.g. "" or "toy"
    if dataset == "toy_gaussian":
        return 1e-1
    raise ValueError(
        f"No parameter scale defined for dataset '{dataset}' "
        f"with optimiser '{optimiser}'"
    )


def init_weights(module, param_scale):
    """Re-draw the weights of linear and 2-D convolution layers from a
    zero-mean normal with standard deviation ``param_scale``.

    Any other module type is left untouched, as are biases. Takes a single
    module, so it can be handed to ``nn.Module.apply``.
    """
    if not isinstance(module, (nn.Linear, nn.Conv2d)):
        return
    nn.init.normal_(module.weight, mean=0., std=param_scale)


def get_architecture_type(dataset):
    """Return "fc" for the datasets trained with fully connected networks
    (toy Gaussian, MNIST, Fashion-MNIST) and "conv" for every other dataset."""
    fc_datasets = ("toy_gaussian", "MNIST", "Fashion-MNIST")
    return "fc" if dataset in fc_datasets else "conv"


def get_optimiser(id, model, lr):
    """Build the optimiser named by ``id`` over ``model.parameters()``.

    Args:
      id: "SGD" or "Adam".
      model: object exposing ``parameters()``.
      lr: learning rate.

    Raises:
      ValueError: for any other ``id``.
    """
    if id == "SGD":
        return optim.SGD(params=model.parameters(), lr=lr)
    if id == "Adam":
        return optim.Adam(params=model.parameters(), lr=lr)
    raise ValueError("Invalid optimiser ID")


def get_gradient_vector(model):
    """Concatenate the ``.grad`` of every parameter of ``model`` into one flat
    tensor, in ``model.parameters()`` order.

    Every parameter must already hold a gradient (``param.grad`` is read
    unconditionally).
    """
    flat_grads = [param.grad.view(-1) for param in model.parameters()]
    return torch.cat(flat_grads)


def get_min_iter(lists):
    """Return the length of the shortest sequence in ``lists``.

    The result is capped at 100000, which is also what an empty ``lists``
    yields.
    """
    return min([100000] + [len(sequence) for sequence in lists])


def get_min_iter_metrics(metrics):
    """Stack per-seed metric sequences into one array of equal-length rows.

    Each sequence is truncated to the length reported by ``get_min_iter``.

    Returns:
      ``np.ndarray`` of shape ``(len(metrics), min_iter)``.
    """
    min_iter = get_min_iter(lists=metrics)
    truncated = np.zeros((len(metrics), min_iter))
    for row, seed_metrics in enumerate(metrics):
        truncated[row, :] = seed_metrics[:min_iter]
    return truncated


def compute_metric_stats(metric):
    """Return the per-iteration mean and standard deviation across seeds.

    ``metric`` holds one sequence per seed; sequences are first aligned to a
    common length with ``get_min_iter_metrics``.
    """
    aligned = get_min_iter_metrics(metrics=metric)
    return aligned.mean(axis=0), aligned.std(axis=0)


In [5]:
#@title Datasets


def get_dataloaders(dataset_id):
    """Build the train and test ``DataLoader`` for ``dataset_id``.

    Both splits are requested with ``normalise=True``. Both loaders use
    ``BATCH_SIZE``, shuffle their data and drop the last incomplete batch
    (this applies to the test loader as well).

    Returns:
      ``(train_loader, test_loader)``.
    """
    train_data, test_data = [
        get_dataset(id=dataset_id, train=is_train, normalise=True)
        for is_train in (True, False)
    ]
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    train_loader = DataLoader(dataset=train_data, **loader_kwargs)
    test_loader = DataLoader(dataset=test_data, **loader_kwargs)
    return train_loader, test_loader


def get_dataset(id, train, normalise):
    """Instantiate the dataset named by ``id``.

    Args:
      id: one of "toy_gaussian", "MNIST", "Fashion-MNIST", "CIFAR10",
        "TinyImageNet".
      train: True for the training split, False for the test split.
      normalise: forwarded to the image dataset classes; ignored for
        "toy_gaussian".

    Returns:
      The dataset object.

    Raises:
      ValueError: if ``id`` is not recognised (previously an
        UnboundLocalError on the return statement).
    """
    if id == "toy_gaussian":
        return make_gaussian_dataset(train=train)
    if id == "MNIST":
        return MNIST(train=train, normalise=normalise)
    if id == "Fashion-MNIST":
        return FashionMNIST(train=train, normalise=normalise)
    if id == "CIFAR10":
        return CIFAR10(train=train, normalise=normalise)
    if id == "TinyImageNet":
        # Tiny ImageNet is read from a local folder, so fetch it first
        download_tiny_imagenet()
        return Tiny_ImageNet(train=train, normalise=normalise)
    raise ValueError(
        f"Invalid dataset ID '{id}'. Options are 'toy_gaussian', 'MNIST', "
        "'Fashion-MNIST', 'CIFAR10' and 'TinyImageNet'"
    )


def make_gaussian_dataset(train):
    """Build the toy regression dataset: 60000 samples of dimension
    ``INPUT_DIM`` with mean ``DATA_MEAN``, each paired with its negation as
    the target.

    Training samples have standard deviation ``DATA_STD``; the test split uses
    a standard deviation of 0.
    """
    noise_std = DATA_STD if train else 0
    samples = torch.normal(
        mean=DATA_MEAN,
        std=noise_std,
        size=(60000, INPUT_DIM)
    )
    return TensorDataset(samples, -samples)


class MNIST(datasets.MNIST):
    """torchvision MNIST whose items are flattened image tensors paired with
    one-hot label vectors."""

    def __init__(self, train, normalise=True, save_dir=DATA_DIR):
        """Download (if needed) into ``save_dir`` and set up the transform:
        ``ToTensor`` followed, when ``normalise`` is True, by ``Normalize``."""
        steps = [transforms.ToTensor()]
        if normalise:
            # scalar mean/std, applied to the single image channel
            steps.append(transforms.Normalize(mean=0.1307, std=0.3081))
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps)
        )

    def __getitem__(self, index):
        """Return ``(flattened image, one-hot label)`` for ``index``."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


class FashionMNIST(datasets.FashionMNIST):
    """torchvision Fashion-MNIST whose items are flattened image tensors paired
    with one-hot label vectors."""

    def __init__(self, train, normalise=True, save_dir=DATA_DIR):
        """Download (if needed) into ``save_dir`` and set up the transform:
        ``ToTensor`` followed, when ``normalise`` is True, by ``Normalize``."""
        steps = [transforms.ToTensor()]
        if normalise:
            # scalar mean/std, applied to the single image channel
            steps.append(transforms.Normalize(mean=0.5, std=0.5))
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps)
        )

    def __getitem__(self, index):
        """Return ``(flattened image, one-hot label)`` for ``index``."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


class CIFAR10(datasets.CIFAR10):
    """torchvision CIFAR-10 whose items are image tensors (not flattened)
    paired with one-hot label vectors."""

    def __init__(self, train, normalise=True, save_dir=f"{DATA_DIR}/CIFAR10"):
        """Download (if needed) into ``save_dir`` and set up the transform:
        ``ToTensor`` followed, when ``normalise`` is True, by a per-channel
        ``Normalize``."""
        steps = [transforms.ToTensor()]
        if normalise:
            steps.append(
                transforms.Normalize(
                    mean=(0.4914, 0.4822, 0.4465),
                    std=(0.247, 0.243, 0.261)
                )
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps)
        )

    def __getitem__(self, index):
        """Return ``(image, one-hot label)`` for ``index``."""
        img, label = super().__getitem__(index)
        return img, one_hot(label)


class Tiny_ImageNet(datasets.ImageFolder):
    """``ImageFolder`` over the extracted Tiny ImageNet archive; items are
    image tensors paired with 200-way one-hot label vectors."""

    def __init__(self, train, normalise=True, save_dir="tiny-imagenet-200"):
        """Read images from ``<save_dir>/train`` or ``<save_dir>/test`` and set
        up the transform: ``ToTensor`` followed, when ``normalise`` is True, by
        a per-channel ``Normalize``."""
        split = "train" if train else "test"
        steps = [transforms.ToTensor()]
        if normalise:
            steps.append(
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225)
                )
            )
        super().__init__(
            os.path.join(save_dir, split),
            transform=transforms.Compose(steps)
        )

    def __getitem__(self, index):
        """Return ``(image, one-hot label over 200 classes)`` for ``index``."""
        img, label = super().__getitem__(index)
        return img, one_hot(label, n_classes=200)


def download_tiny_imagenet():
    """Download and unzip Tiny ImageNet into the working directory.

    Runs ``wget`` followed by ``unzip`` through the shell. The call is skipped
    when the ``tiny-imagenet-200`` folder (the default ``save_dir`` of
    ``Tiny_ImageNet``) already exists: ``get_dataset`` invokes this function on
    every call, so without the guard the archive would be fetched and
    extracted again for each split.

    The exit status of the shell command is not checked.
    """
    if os.path.isdir("tiny-imagenet-200"):
        return
    cmd = f"""
          wget http://cs231n.stanford.edu/tiny-imagenet-200.zip; \
          unzip tiny-imagenet-200.zip
    """
    subprocess.run(cmd, shell=True)


def one_hot(labels, n_classes=10):
    """Return the one-hot encoding of ``labels``: the rows of an
    ``n_classes`` x ``n_classes`` identity matrix selected by ``labels``."""
    return torch.eye(n_classes)[labels]


def accuracy(predictions, truths):
    """Fraction of rows whose arg-max prediction matches the arg-max target.

    Args:
      predictions: tensor with the batch on the first axis; each row is
        flattened before taking the arg-max.
      truths: tensor of targets laid out like ``predictions`` (e.g. one-hot).

    Returns:
      A Python float in [0, 1].

    Raises:
      ZeroDivisionError: if the batch is empty.
    """
    batch_size = predictions.size(0)
    # one vectorised comparison instead of one tensor comparison per sample
    predicted = predictions.flatten(start_dim=1).argmax(dim=1)
    target = truths[:batch_size].flatten(start_dim=1).argmax(dim=1)
    correct = (predicted == target).sum().item()
    return correct / batch_size


In [6]:
#@title Archs


def get_network(arch_type, dataset, act_fn, n_hidden=None, width=None):
    """Build the network for the requested architecture type.

    Args:
      arch_type: "fc" (uses ``n_hidden`` and ``width``) or "conv" (ignores
        them).
      dataset: dataset identifier forwarded to the builder.
      act_fn: activation function identifier forwarded to the builder.
      n_hidden: number of hidden layers of the fully connected network.
      width: hidden width of the fully connected network.

    Raises:
      ValueError: if ``arch_type`` is neither "fc" nor "conv".
    """
    if arch_type == "fc":
        return get_fc_network(
            dataset=dataset,
            n_hidden=n_hidden,
            width=width,
            act_fn=act_fn
        )
    if arch_type == "conv":
        return get_conv_network(dataset=dataset, act_fn=act_fn)
    raise ValueError(
        "Invalid architecture type ID. Options are 'fc' and 'conv'"
    )


def get_fc_network(dataset, n_hidden, width, act_fn):
    """Build a bias-free fully connected network as a Sequential of per-layer
    Sequentials.

    Each hidden layer is ``Sequential(Linear[, activation])`` and the output
    layer is ``Sequential(Linear)``, so ``network[l]`` is always one layer.

    Args:
      dataset: "toy_gaussian" selects ``INPUT_DIM`` for both input and output
        size; any other value selects ``FC_INPUT_DIM`` / ``FC_OUTPUT_DIM``.
      n_hidden: number of hidden layers (0 gives a single linear map from
        input to output).
      width: number of units of every hidden layer.
      act_fn: "linear" (no activation), "tanh" or "relu".

    Returns:
      ``nn.Sequential`` with ``n_hidden + 1`` entries.

    Raises:
      ValueError: if ``act_fn`` is not recognised (previously an
        UnboundLocalError inside the layer loop).
    """
    if act_fn not in ("linear", "tanh", "relu"):
        raise ValueError(
            "Invalid activation function ID. "
            "Options are 'linear', 'tanh' and 'relu'"
        )

    is_toy = dataset == "toy_gaussian"
    input_dim = INPUT_DIM if is_toy else FC_INPUT_DIM
    output_dim = INPUT_DIM if is_toy else FC_OUTPUT_DIM

    layers = []
    n_input = input_dim
    for _ in range(n_hidden):
        modules = [nn.Linear(n_input, width, bias=False)]
        if act_fn == "tanh":
            modules.append(nn.Tanh())
        elif act_fn == "relu":
            modules.append(nn.ReLU(inplace=True))
        layers.append(nn.Sequential(*modules))
        n_input = width

    # n_input is still input_dim when there are no hidden layers, so the output
    # layer always matches the size of whatever feeds it
    layers.append(nn.Sequential(nn.Linear(n_input, output_dim, bias=False)))
    return nn.Sequential(*layers)


def get_conv_network(dataset, act_fn):
    """Build the convolutional network for CIFAR10 or TinyImageNet.

    Layout (one Sequential entry per layer, so ``network[l]`` is one layer):
      * conv blocks ``Conv2d(3x3, stride 1, padding 1) [-> activation]
        -> MaxPool2d(2, 2)`` with 64, 128, 256 output channels (plus 512 for
        TinyImageNet); the last block ends with ``Flatten``;
      * a bias-free hidden ``Linear`` on the flattened 4x4 feature map
        (width 4096 for CIFAR10, 8192 for TinyImageNet) ``[-> activation]``;
      * a bias-free output ``Linear`` (10 or 200 classes).

    Args:
      dataset: "CIFAR10" or "TinyImageNet".
      act_fn: "linear" (no activation), "tanh" or "relu".

    Returns:
      ``nn.Sequential`` of the layers described above.

    Raises:
      ValueError: if ``act_fn`` or ``dataset`` is not supported. Previously an
        unknown ``act_fn`` silently produced an empty network and an unknown
        dataset a network without its last conv block and classifier head.
    """
    activations = {
        "linear": None,
        "tanh": nn.Tanh,
        "relu": lambda: nn.ReLU(inplace=True),
    }
    # dataset -> (conv output channels, hidden FC width, number of classes)
    layouts = {
        "CIFAR10": ((64, 128, 256), 4096, 10),
        "TinyImageNet": ((64, 128, 256, 512), 8192, 200),
    }
    if act_fn not in activations:
        raise ValueError(
            "Invalid activation function ID. "
            "Options are 'linear', 'tanh' and 'relu'"
        )
    if dataset not in layouts:
        raise ValueError(
            "Invalid dataset ID for the conv architecture. "
            "Options are 'CIFAR10' and 'TinyImageNet'"
        )

    make_act = activations[act_fn]
    channels, fc_width, n_classes = layouts[dataset]

    conv_network = []
    in_channels = 3
    for i, out_channels in enumerate(channels):
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        ]
        if make_act is not None:
            modules.append(make_act())
        # every pooling step halves the spatial size, ending at 4x4
        modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        if i == len(channels) - 1:
            modules.append(nn.Flatten())
        conv_network.append(nn.Sequential(*modules))
        in_channels = out_channels

    hidden = [nn.Linear(4*4 * channels[-1], fc_width, bias=False)]
    if make_act is not None:
        hidden.append(make_act())
    conv_network.append(nn.Sequential(*hidden))

    output = nn.Linear(fc_width, n_classes, bias=False)
    # Kept from the original: only the "linear" variant wraps the output layer
    # in a Sequential, which keeps state_dict keys of saved weights unchanged.
    conv_network.append(nn.Sequential(output) if make_act is None else output)

    return nn.Sequential(*conv_network)


In [7]:
#@title Models


class BPN(nn.Module):
    """Thin ``nn.Module`` wrapper around a network trained with backprop."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def forward(self, x):
        """Return the output of the wrapped network for input ``x``."""
        return self.network(x)

    def get_weights(self):
        """Return every parameter with more than one dimension as a NumPy
        array (one-dimensional parameters such as biases are skipped)."""
        return [
            param.cpu().detach().numpy()
            for param in self.network.parameters()
            if len(param.shape) > 1
        ]



class PCN(object):
    """Predictive coding network built on a layered, indexable network.

    ``network[l]`` maps activity node ``xs[l]`` to a prediction of node
    ``xs[l + 1]``, so a network of ``n_layers`` layers has ``n_layers + 1``
    nodes: ``xs[0]`` holds the prior (network input) and ``xs[-1]`` the
    observation (target). Inference iteratively updates the activities with
    step size ``dt``; ``set_grads`` then stores weight gradients in
    ``param.grad`` so that an external optimiser can apply them.
    """

    def __init__(self, network, dt, device="cpu"):
        """
        Args:
          network: indexable stack of layers (e.g. ``nn.Sequential``); it is
            moved to ``device``.
          dt: step size of the activity updates.
          device: device for the network and for freshly sampled activities.
        """
        self.network = network.to(device)
        self.n_layers = len(self.network)
        self.n_nodes = self.n_layers + 1
        self.dt = dt
        self.n_params = sum(
            p.numel() for p in network.parameters() if p.requires_grad
        )
        self.device = device

    def reset(self):
        """Zero the weight gradients and clear predictions, errors and
        activities of all nodes."""
        self.zero_grad()
        self.preds = [None] * self.n_nodes
        self.errs = [None] * self.n_nodes
        self.xs = [None] * self.n_nodes

    def reset_xs(self, prior, init_std):
        """Re-initialise the activities of nodes ``0 .. n_layers - 1`` with
        zero-mean Gaussian noise of standard deviation ``init_std``.

        The forward pass is only used to obtain the shape of each node. Note
        that node 0 (the prior) is overwritten with noise as well.
        """
        self.set_prior(prior)
        self.propagate_xs()
        for l in range(self.n_layers):
            self.xs[l] = torch.empty(self.xs[l].shape).normal_(
                mean=0,
                std=init_std
            ).to(self.device)

    def set_obs(self, obs):
        """Store a copy of ``obs`` in the last node."""
        self.xs[-1] = obs.clone()

    def set_prior(self, prior):
        """Store a copy of ``prior`` in the first node."""
        self.xs[0] = prior.clone()

    def forward(self, x):
        """Plain feedforward pass of the wrapped network."""
        return self.network(x)

    def propagate_xs(self):
        """Initialise hidden nodes ``1 .. n_layers - 1`` by a forward sweep
        from ``xs[0]`` (the last node is left untouched)."""
        for l in range(1, self.n_layers):
            self.xs[l] = self.network[l - 1](self.xs[l - 1])

    def infer_train(
            self,
            obs,
            prior,
            n_iters,
            record_grad_norms=False,
        ):
        """Run inference with prior and observation fixed, then set weight
        gradients.

        Hidden activities start from a forward sweep. In every iteration each
        hidden node ``l`` (from the top down) moves by
        ``dt * (vjp(network[l], xs[l], errs[l + 1]) - errs[l])``. After each
        iteration the total energy is recorded; if it did not decrease,
        ``self.dt`` is halved, and once it reaches 0.025 or below the step
        size is restored to its value at call time and inference stops early.

        NOTE(review): when the loop ends without hitting that early stop, a
        halved ``self.dt`` persists into subsequent calls -- confirm this is
        intended.

        Args:
          obs: target written to the last node.
          prior: input written to the first node.
          n_iters: maximum number of inference iterations.
          record_grad_norms: if True, also return gradient norms (see below).

        Returns:
          The list of total energies, one per executed iteration. With
          ``record_grad_norms=True`` a tuple ``(energies, grad_norms_iters)``,
          where ``grad_norms_iters`` is a list of ``2 * n_layers`` arrays of
          length ``n_iters``; ``set_grads`` fills only the entry of the last
          executed iteration, all other entries stay zero.
        """
        self.reset()
        self.set_prior(prior)
        self.propagate_xs()
        self.set_obs(obs)

        if record_grad_norms:
            grad_norms_iters = [np.zeros(n_iters) for p in range(len(self.network)*2)]

        # remember the step size at call time so it can be restored on early stop
        dt = self.dt
        tot_energy_iters = []
        for t in range(n_iters):
            self.network.zero_grad()
            self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
            self.errs[-1] = self.xs[-1] - self.preds[-1]

            for l in reversed(range(1, self.n_layers)):
                self.preds[l] = self.network[l - 1](self.xs[l - 1])
                self.errs[l] = self.xs[l] - self.preds[l]
                # error of the layer above, pulled back through network[l]
                _, epsdfdx = torch.autograd.functional.vjp(
                    self.network[l],
                    self.xs[l],
                    self.errs[l + 1]
                )
                with torch.no_grad():
                    dx = epsdfdx - self.errs[l]
                    self.xs[l] = self.xs[l] + self.dt * dx

            tot_energy_iters.append(self.compute_tot_energy())
            if t > 0 and tot_energy_iters[t] >= tot_energy_iters[t-1] and self.dt > 0.025:
                self.dt /= 2
                if self.dt <= 0.025:
                    self.dt = dt
                    break

            # keep the autograd graph of the final iteration: set_grads needs it
            if (t+1) != n_iters:
                self.clear_grads()

        if record_grad_norms:
            self.set_grads(
                grad_norms_iters=grad_norms_iters,
                t=t
            )
        else:
            self.set_grads()

        if record_grad_norms:
            return tot_energy_iters, grad_norms_iters
        else:
            return tot_energy_iters

    def infer_test(
            self,
            obs,
            prior,
            n_iters,
            update_prior=True,
            update_obs=False,
            init_std=0.05,
        ):
        """Run inference from noise-initialised activities, optionally also
        updating the prior and/or the observation node.

        Args:
          obs: value written to the last node.
          prior: value used to size the nodes; it is kept in the first node
            only when ``update_prior`` is False.
          n_iters: number of inference iterations.
          update_prior: if True, node 0 starts from noise and is updated.
          update_obs: if True, the last node is updated as well.
          init_std: standard deviation of the initial activity noise.

        Returns:
          ``xs[0]`` if ``update_prior``; otherwise ``xs[-1]`` if
          ``update_obs``; otherwise ``None``.
        """

        self.reset()
        self.reset_xs(prior, init_std)
        if not update_prior:
            self.set_prior(prior)
        self.set_obs(obs)

        for t in range(n_iters):
            self.network.zero_grad()
            self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
            self.errs[-1] = self.xs[-1] - self.preds[-1]

            if update_obs:
                with torch.no_grad():
                    self.xs[-1] = self.xs[-1] + self.dt * (- self.errs[-1])

            for l in reversed(range(1, self.n_layers)):
                self.preds[l] = self.network[l - 1](self.xs[l - 1])
                self.errs[l] = self.xs[l] - self.preds[l]
                _, epsdfdx = torch.autograd.functional.vjp(
                    self.network[l],
                    self.xs[l],
                    self.errs[l + 1]
                )
                with torch.no_grad():
                    dx = epsdfdx - self.errs[l]
                    self.xs[l] = self.xs[l] + self.dt * dx

            if update_prior:
                # node 0 has no prediction error of its own, only the pulled-back one
                _, epsdfdx = torch.autograd.functional.vjp(
                    self.network[0],
                    self.xs[0],
                    self.errs[1]
                )
                with torch.no_grad():
                    self.xs[0] = self.xs[0] + self.dt * epsdfdx

            if (t+1) != n_iters:
                self.clear_grads()

        if update_prior:
            return self.xs[0]
        elif update_obs:
            return self.xs[-1]

    def set_grads(self, grad_norms_iters=None, t=None):
        """Write the weight gradients into ``param.grad``.

        For every parameter of layer ``l`` the gradient is that of
        ``preds[l + 1]`` with respect to the parameter, weighted by
        ``-errs[l + 1]``. When ``grad_norms_iters`` is given, the norm of the
        n-th parameter's gradient is stored in ``grad_norms_iters[n][t]``.
        """
        n = 0
        for l in range(self.n_layers):
            for i, param in enumerate(self.network[l].parameters()):
                dparam = torch.autograd.grad(
                    self.preds[l + 1],
                    param,
                    - self.errs[l + 1],
                    allow_unused=True,
                    retain_graph=True
                )[0]
                param.grad = dparam.clone()
                if grad_norms_iters is not None:
                    grad_norm = torch.linalg.norm(dparam)
                    grad_norms_iters[n][t] = grad_norm.item()
                    n += 1

    def zero_grad(self):
        """Zero the gradients of the wrapped network."""
        self.network.zero_grad()

    def clear_grads(self):
        """Replace predictions, errors and activities of nodes ``1 ..`` by
        copies made under ``no_grad``, cutting them off the autograd graph."""
        with torch.no_grad():
            for l in range(1, self.n_nodes):
                self.preds[l] = self.preds[l].clone()
                self.errs[l] = self.errs[l].clone()
                self.xs[l] = self.xs[l].clone()

    def save_weights(self, path):
        """Save the network's ``state_dict`` to ``path``."""
        torch.save(self.network.state_dict(), path)

    def load_weights(self, path):
        """Load a ``state_dict`` from ``path`` into the network.

        Tensors are mapped onto ``self.device`` so that weights saved on a GPU
        can also be restored on a CPU-only machine.
        """
        self.network.load_state_dict(
            torch.load(path, map_location=self.device)
        )

    def get_weights(self):
        """Return every parameter with more than one dimension as a NumPy
        array."""
        weights = []
        for param in self.network.parameters():
            if len(param.shape) > 1:
                weights.append(param.cpu().detach().numpy())
        return weights

    def compute_tot_energy(self):
        """Return the sum of squared prediction errors over all nodes that
        currently have an error, as a Python float."""
        energy = 0.
        for err in self.errs:
            if err is not None:
                energy += (err**2).sum()
        return energy.item()

    def parameters(self):
        """Return the parameters of the wrapped network."""
        return self.network.parameters()

    def __str__(self):
        # closing parenthesis added: the original string was left unbalanced
        return f"PCN(\n{self.network}\n)"


In [8]:
#@title Weight manipulation
# Adapted from https://github.com/tomgoldstein/loss-landscape/blob/master/net_plotter.py


def get_weights(net):
    """Return the ``.data`` tensor of every parameter of ``net`` as a list.

    The tensors are not copied, so they share storage with the parameters.
    """
    weights = []
    for param in net.parameters():
        weights.append(param.data)
    return weights


def get_random_weights(weights, device='cpu'):
    """
        Draw one standard-normal tensor per entry of ``weights``, each with the
        shape of its counterpart, and move it to ``device``. The result is a
        random direction in weight space, returned as a list.
    """
    direction = []
    for w in weights:
        direction.append(torch.randn(w.size()).to(device))
    return direction


def normalize_direction(direction, weights, norm='filter'):
    """
        Rescale ``direction`` in place according to ``norm``.

        Args:
          direction: random direction tensor for one layer
          weights: weight tensor of the same layer
          norm: normalization method
            'filter'  - each slice along the first axis gets the norm of the
                        matching slice of ``weights``
            'layer'   - the whole tensor gets the norm of ``weights``
            'weight'  - each entry is multiplied by the matching weight
            'dfilter' - each slice along the first axis gets unit norm
            'dlayer'  - the whole tensor gets unit norm
          Any other value leaves ``direction`` unchanged.
    """
    if norm == 'filter':
        for d_filter, w_filter in zip(direction, weights):
            # the small constant guards against division by a zero norm
            scale = w_filter.norm()/(d_filter.norm() + 1e-10)
            d_filter.mul_(scale)
        return
    if norm == 'layer':
        direction.mul_(weights.norm()/direction.norm())
        return
    if norm == 'weight':
        direction.mul_(weights)
        return
    if norm == 'dfilter':
        for d_filter in direction:
            d_filter.div_(d_filter.norm() + 1e-10)
        return
    if norm == 'dlayer':
        direction.div_(direction.norm())


def normalize_directions_for_weights(direction, weights, norm='filter', ignore='biasbn'):
    """
        Normalize a list of direction tensors in place against the matching
        list of weight tensors.

        Tensors with more than one dimension go through normalize_direction.
        Tensors with at most one dimension are zeroed when ignore == 'biasbn'
        and overwritten with the weight values otherwise.
    """
    assert(len(direction) == len(weights))
    for d, w in zip(direction, weights):
        if d.dim() > 1:
            normalize_direction(d, w, norm)
        elif ignore == 'biasbn':
            d.fill_(0) # no perturbation along 1-D parameters
        else:
            d.copy_(w) # reuse the weight values as the direction


def create_random_weight_direction(net, device='cpu', ignore='biasbn', norm='filter'):
    """
        Sample a random direction in the weight space of ``net`` and normalize
        it against the network's current weights.

        Args:
          net: the model whose parameters define the direction's shapes
          device: device on which the random tensors are placed
          ignore: 'biasbn' zeroes the direction along parameters with at most
                  one dimension; any other value copies the weights there
          norm: direction normalization method, one of
                'filter' | 'layer' | 'weight' | 'dlayer' | 'dfilter'
        Returns:
          direction: list of tensors, one per parameter of ``net``.
    """
    net_weights = get_weights(net)
    random_direction = get_random_weights(net_weights, device)
    normalize_directions_for_weights(random_direction, net_weights, norm, ignore)
    return random_direction


In [9]:
#@title Plotting


@gif.frame
def plot_hessian_matrix(hessian_matrix, save_path, title=None):
    """Draw ``hessian_matrix`` as a heatmap with a colour range fixed to
    [-1, 1], save it to ``save_path`` and return the figure.

    ``title``, when given, is wrapped in ``$...$`` before being set.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(X=hessian_matrix, cmap="RdBu_r", vmin=-1, vmax=1)
    fig.colorbar(image, ax=ax, location="right")
    if title is not None:
        plt.title(f"${title}$", fontsize=18)
    fig.savefig(save_path)
    return fig


@gif.frame
def plot_hessian_eigenvalues(eigenvalues, save_path, title=None):
    """Draw a probability-normalised histogram of ``eigenvalues`` with a
    log-scaled y axis, write it to ``save_path`` and return the figure.

    ``title``, when given, is wrapped in ``$...$``; otherwise the title is
    empty.
    """
    histogram = go.Histogram(
        x=eigenvalues,
        histnorm="probability",
        marker_color="#FF7F0E"
    )
    fig = go.Figure(data=histogram)

    title_text = "" if title is None else f"${title}$"
    fig.update_layout(
        height=300,
        width=500,
        title=dict(
            text=title_text,
            y=0.7,
            x=0.5,
            xanchor="center",
            yanchor="top"
        ),
        xaxis=dict(title="Hessian eigenvalue"),
        yaxis=dict(
            title="Density (Log Scale)",
            nticks=5,
            type="log"
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)
    return fig


def plot_loss(loss, mode, save_path):
    """Plot a loss curve against 1-based iteration numbers and write it to
    ``save_path``.

    Args:
      loss: non-empty sequence of loss values.
      mode: "train" labels the x axis "Batch"; anything else labels it
        "Epoch". The value is also used as subscript of the y-axis label.
      save_path: destination passed to ``fig.write_image``.
    """
    n_train_iters = len(loss)
    train_iters = [b+1 for b in range(n_train_iters)]
    # ticks at the first, middle and last iteration
    tick_positions = [1, int(train_iters[-1]/2), train_iters[-1]]

    loss_color = "#EF553B"
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=train_iters,
            y=loss,
            mode="lines",
            line=dict(width=2, color=loss_color),
            showlegend=False
        )
    )
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Batch" if mode == "train" else "Epoch",
            tickvals=tick_positions,
            ticktext=tick_positions
        ),
        yaxis=dict(
            # raw f-string: the LaTeX backslashes are not valid Python escapes
            title=rf"$\Large{{\mathcal{{L}}_{{{mode}}}}}$"
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_loss_and_accuracy(loss, accuracy, mode, save_path):
    """Plot loss (left axis) and accuracy (right axis) against 1-based
    iteration numbers and write the figure to ``save_path``.

    Args:
      loss: non-empty sequence of loss values.
      accuracy: sequence of accuracy values, plotted on a second y axis.
      mode: "train" labels the x axis "Batch"; anything else labels it
        "Epoch". The value also appears in both y-axis labels.
      save_path: destination passed to ``fig.write_image``.
    """
    n_train_iters = len(loss)
    train_iters = [b+1 for b in range(n_train_iters)]
    # ticks at the first, middle and last iteration
    tick_positions = [1, int(train_iters[-1]/2), train_iters[-1]]

    loss_color, accuracy_color = "#EF553B", "#636EFA"
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=train_iters,
            y=loss,
            mode="lines",
            line=dict(width=2, color=loss_color),
            showlegend=False
        )
    )
    fig.add_trace(
        go.Scatter(
            x=train_iters,
            y=accuracy,
            mode="lines",
            line=dict(width=2, color=accuracy_color),
            showlegend=False,
            yaxis="y2"
        )
    )
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Batch" if mode == "train" else "Epoch",
            tickvals=tick_positions,
            ticktext=tick_positions
        ),
        yaxis=dict(
            # title.font replaces the deprecated `titlefont` attribute; the raw
            # f-string keeps the LaTeX backslashes without invalid escapes
            title=dict(
                text=rf"$\Large{{\mathcal{{L}}_{{{mode}}}}}$",
                font=dict(color=loss_color)
            ),
            tickfont=dict(
                color=loss_color
            )
        ),
        yaxis2=dict(
            title=dict(
                text=f"{mode.capitalize()} accuracy (%)",
                font=dict(color=accuracy_color)
            ),
            side="right",
            overlaying="y",
            tickfont=dict(
                color=accuracy_color,
            )
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_norms(norms, norm_type, mode, save_path):
    """Plot one norm curve per parameter and write the figure to
    ``f"{save_path}_weight.pdf"``.

    Args:
      norms: non-empty sequence with one sequence of norms per parameter; the
        length of the first one defines the x axis.
      norm_type: "parameters" for weight norms; anything else is treated as
        gradient norms (different legend, y label and figure width).
      mode: "learning" labels the x axis "Training iteration"; anything else
        labels it "Inference iteration".
      save_path: path prefix of the output file.
    """
    n_params = len(norms)
    n_iterations = len(norms[0])
    iterations = [b+1 for b in range(n_iterations)]
    is_param_norm = norm_type == "parameters"

    fig = go.Figure()
    # braces around the index keep subscripts of 10 and above intact in LaTeX
    weights_id = [
        f"$W_{{{i+1}}}$" if is_param_norm else rf"$\partial W_{{{i+1}}}$"
        for i in range(n_params)
    ]
    colors = px.colors.qualitative.Plotly[2:]
    for i, (weight_norms, weight_id) in enumerate(zip(norms, weights_id)):
        fig.add_traces(
            go.Scatter(
                x=iterations,
                y=weight_norms,
                name=weight_id,
                mode="lines",
                # cycle through the palette so that no curve is dropped when
                # there are more parameters than colours
                line=dict(width=2, color=colors[i % len(colors)])
            )
        )

    fig_width = 300 if is_param_norm else 400
    xaxis_title = "Training iteration" if mode == "learning" else "Inference iteration"
    fig.update_layout(
        height=300,
        width=fig_width,
        xaxis=dict(
            title=xaxis_title,
            tickvals=[1, int(iterations[-1]/2), iterations[-1]],
            ticktext=[1, int(iterations[-1]/2), iterations[-1]],
        ),
        yaxis=dict(
            # raw strings: the LaTeX backslashes are not valid Python escapes
            title=r"$\Large{||W||_F}$" if is_param_norm else r"$\Large{||\partial W||_F}$",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(f"{save_path}_weight.pdf")


def plot_bp_and_pc_metric_stats(means, stds, dataset, optimiser, metric_title, save_path):
    """Plot the mean +/- std of a metric for BP and PC on one figure.

    Args:
        means: Two array-like series, ``means[0]`` for BP and ``means[1]`` for
            PC. They must support element-wise ``+``/``-`` with ``stds``.
        stds: Two series of standard deviations aligned with ``means``.
        dataset: Dataset identifier; "toy_gaussian" switches the x-axis to a
            log scale.
        optimiser: Optimiser identifier; "Adam" clamps the y-range to
            [0, 0.15].
        metric_title: y-axis title. If it contains "train" (case-insensitive)
            the x-axis is labelled as training iterations, otherwise as epochs.
        save_path: Destination of the rendered image.
    """
    max_train_iter = max(len(means[0]), len(means[1]))
    is_toy = dataset == "toy_gaussian"

    fig = go.Figure()
    for i, (label, color) in enumerate((("BP", "#EF553B"), ("PC", "#636EFA"))):
        train_iters = list(range(1, len(means[i]) + 1))
        y_upper, y_lower = means[i] + stds[i], means[i] - stds[i]

        # Shaded +/- std band: trace the upper bound forwards and the lower
        # bound backwards so that "toself" fills the enclosed region.
        fig.add_traces(
            go.Scatter(
                x=train_iters + train_iters[::-1],
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        fig.add_traces(
            go.Scatter(
                x=train_iters,
                y=means[i],
                name=label,
                mode="lines+markers",
                line=dict(width=3, color=color)
            )
        )

    # Case-insensitive match: titles such as "Train accuracy (%)" describe
    # per-batch series and must not fall through to the "Epoch" label.
    if "train" in metric_title.lower():
        xaxis_title = "Training iteration (log)" if is_toy else "Training iteration"
    else:
        xaxis_title = "Epoch"

    if is_toy:
        xaxis = dict(type="log", exponentformat="power", dtick=1)
    else:
        # Only the first, middle and last x positions are labelled.
        ticks = [1, int(max_train_iter/2), max_train_iter]
        xaxis = dict(tickvals=ticks, ticktext=ticks)
    xaxis["title"] = xaxis_title

    if optimiser == "Adam":
        fig.update_yaxes(range=[0, 0.15])

    fig.update_layout(
        height=300,
        width=350,
        xaxis=xaxis,
        yaxis=dict(title=metric_title),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_bp_vs_pc_grad_norm_stats(
        means: Tuple[np.ndarray, np.ndarray],
        stds: Tuple[np.ndarray, np.ndarray],
        dataset: str,
        save_path: str
    ) -> None:
    """Plot the mean +/- std gradient norm of BP and PC on one figure.

    Args:
        means: Pair of series, ``means[0]`` for BP and ``means[1]`` for PC.
            They must support element-wise ``+``/``-`` with ``stds``.
        stds: Pair of standard-deviation series aligned with ``means``.
        dataset: Dataset identifier; "toy_gaussian" switches the x-axis to a
            log scale and changes its title accordingly.
        save_path: Destination of the rendered image.
    """
    max_train_iter = max(len(means[0]), len(means[1]))
    is_toy = dataset == "toy_gaussian"

    fig = go.Figure()
    for i, (label, color) in enumerate((("BP", "#EF553B"), ("PC", "#636EFA"))):
        train_iters = list(range(1, len(means[i]) + 1))
        y_upper, y_lower = means[i] + stds[i], means[i] - stds[i]

        # Shaded +/- std band: upper bound forwards, lower bound backwards,
        # so that "toself" fills the region in between.
        fig.add_traces(
            go.Scatter(
                x=train_iters + train_iters[::-1],
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        fig.add_traces(
            go.Scatter(
                x=train_iters,
                y=means[i],
                name=label,
                mode="lines+markers",
                line=dict(width=3, color=color)
            )
        )

    if is_toy:
        xaxis = dict(
            type="log",
            exponentformat="power",
            dtick=1,
            title="Training iteration (log)"
        )
    else:
        # Only the first, middle and last iteration are labelled.
        ticks = [1, int(max_train_iter/2), max_train_iter]
        xaxis = dict(tickvals=ticks, ticktext=ticks, title="Training iteration")

    fig.update_layout(
        height=300,
        width=350,
        xaxis=xaxis,
        yaxis=dict(
            # Raw string: same LaTeX text, no invalid "\L" / "\p" escapes.
            title=r"$\Large{||\partial \theta||_2}$",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(save_path)


In [10]:
#@title Landscape plotting


@gif.frame
def plot_objective_surface(
        objective_mesh: np.ndarray,
        weights: Tuple[np.ndarray, np.ndarray],
        objective_name: str,
        save_path: str,
        inference_step: Optional[int] = None,
        show_background: bool = True
    ) -> go.Figure:
    """Render a 3D surface of an objective over a 2D grid and save it.

    The function is wrapped by ``gif.frame`` so that its calls can be
    collected and passed to ``gif.save`` as animation frames.

    Args:
        objective_mesh: 2D array of objective values used as the surface
            height (``z``).
        weights: Pair of 1D coordinate arrays used as the ``x`` and ``y``
            axes of the surface.
        objective_name: "loss" labels the colorbar with a calligraphic L;
            any other value uses a calligraphic F.
        save_path: Destination of the rendered image.
        inference_step: If given (and ``show_background`` is true), shown in
            the figure title as "Inference step: <n>".
        show_background: If true, keep axis ticks, the title and the
            colorbar; otherwise hide tick labels, axis backgrounds and the
            colorbar and use a slightly different camera.

    Returns:
        The constructed figure (before ``gif.frame`` post-processing).
    """
    objective_notation = "L" if objective_name == "loss" else "F"
    objective_max, objective_min = objective_mesh.max(), objective_mesh.min()
    fig = go.Figure(
        data=go.Surface(
            z=objective_mesh,
            x=weights[0],
            y=weights[1],
            colorscale=COLORSCALE,
            colorbar=dict(
                # Nested title dict instead of the deprecated flat
                # "titleside" attribute.
                title=dict(
                    text=rf"$\LARGE{{\mathcal{{{objective_notation}}}}}$",
                    side="right"
                ),
                x=0.85,
                y=0.57,
                len=0.3,
                tickfont=dict(size=16),
                # Only the extremes are labelled, qualitatively.
                tickvals=[objective_min, objective_max],
                ticktext=["Low", "High"]
            )
        )
    )
    # Project the level sets onto the z-plane.
    fig.update_traces(
        contours_z=dict(
            show=True,
            usecolormap=True,
            highlightcolor="limegreen",
            project_z=True,
        )
    )
    fig.update_layout(
        scene=dict(zaxis=dict(
            title="",
            range=[objective_min, objective_max],
            showticklabels=False
        )),
        font=dict(size=16),
        height=600,
        width=700,
        margin=dict(r=30, b=10, l=0, t=40),
        scene_aspectmode="cube"
    )
    if show_background:
        title = f"Inference step: {inference_step}" if inference_step is not None else ""
        fig.update_layout(
            scene=dict(
                xaxis=dict(title="", nticks=3, autorange="reversed"),
                yaxis=dict(title="", nticks=3, autorange="reversed")
            ),
            scene_camera=dict(
                center=dict(x=0.05, y=0.2, z=0),
                eye=dict(x=1.4, y=1.4, z=1.25)
            ),
            title=dict(
                text=title,
                y=0.95,
                x=0.42,
                xanchor="center",
                yanchor="top"
            )
        )
    else:
        fig.update_layout(
            scene=dict(
                xaxis=dict(
                    title="",
                    autorange="reversed",
                    showticklabels=False,
                    showbackground=False,
                ),
                yaxis=dict(
                    title="",
                    autorange="reversed",
                    showticklabels=False,
                    showbackground=False,
                )
            ),
            scene_camera=dict(
                center=dict(x=0.2, y=0.15, z=0),
                eye=dict(x=1.3, y=1.3, z=1)
            )
        )
        fig.update_traces(showscale=False)

    fig.write_image(save_path)
    return fig


@gif.frame
def plot_objective_contour(
        objective_mesh: np.ndarray,
        weights: Tuple[np.ndarray, np.ndarray],
        objective_name: str,
        save_path: str,
        inference_step: Optional[int] = None,
        show_origin: bool = True
    ):
    """Draw a heatmap-style contour plot of an objective and save it.

    Args:
        objective_mesh: 2D array of objective values (``z``).
        weights: Pair of 1D coordinate arrays for the ``x`` and ``y`` axes.
        objective_name: "loss" labels the colorbar with a calligraphic L;
            any other value uses a calligraphic F.
        save_path: Destination of the rendered image.
        inference_step: If given, shown in the title as
            "Inference step: <n>".
        show_origin: If true, mark the point (0, 0) with a labelled marker.

    Returns:
        The constructed figure (before ``gif.frame`` post-processing).
    """
    symbol = "F" if objective_name != "loss" else "L"
    # The colour range always starts at zero, not at the mesh minimum.
    color_low, color_high = 0, objective_mesh.max()

    heatmap = go.Contour(
        z=objective_mesh,
        x=weights[0],
        y=weights[1],
        colorscale=COLORSCALE,
        showscale=False,
        contours_coloring="heatmap",
    )
    fig = go.Figure(data=heatmap)

    # The contour's own colorbar is disabled above; an empty scatter trace
    # hosts a colorbar with qualitative "Low"/"High" labels instead.
    colorbar_spec = dict(
        title=rf"$\LARGE{{\mathcal{{{symbol}}}}}$",
        x=1.05,
        len=0.5,
        title_side="right",
        tickfont=dict(size=16),
        tickvals=[color_low, color_high],
        ticktext=["Low", "High"]
    )
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            showlegend=False,
            marker=dict(
                colorscale=COLORSCALE,
                showscale=True,
                cmin=color_low,
                cmax=color_high,
                colorbar=colorbar_spec
            ),
            hoverinfo="none"
        )
    )

    if show_origin:
        highlight = "rgb(255, 255, 51)"
        origin_marker = go.Scatter(
            x=[0],
            y=[0],
            mode="markers+text",
            marker=dict(size=10, color=highlight),
            showlegend=False,
            text=[r"$\LARGE{\theta^*}$"],
            textposition="top right",
            textfont=dict(color=highlight)
        )
        fig.add_traces(origin_marker)

    if inference_step is None:
        heading = ""
    else:
        heading = f"Inference step: {inference_step}"

    fig.update_layout(
        title=dict(
            text=heading,
            y=0.85,
            x=0.41,
            xanchor="center",
            yanchor="top"
        ),
        xaxis=dict(title=r"$\LARGE{α}$"),
        yaxis=dict(title=r"$\LARGE{β}$"),
        margin=dict(r=200, b=20, l=20, t=100),
        font=dict(size=16),
        plot_bgcolor="white",
        width=700,
        height=400
    )

    fig.write_image(save_path)
    return fig


In [11]:
#@title Projection visualisations


def visualise_2D_loss_random_projections(
        model,
        input,
        target,
        domain,
        device,
        save_dir
    ):
    """Plot the MSE loss on a 2D slice through weight space.

    Two random weight directions d0 and d1 are drawn, and the loss is
    evaluated at ``theta0 + a * d0 + b * d1`` for every (a, b) on a square
    grid of side ``2 * domain``, where ``theta0`` are the model's current
    weights. Only parameters with more than one dimension are perturbed.
    The model's weights are restored before returning.

    Args:
        model: Object exposing ``network``, ``parameters()`` and
            ``forward(input)``.
        input: Batch passed to ``model.forward``.
        target: Targets compared against the predictions with MSE.
        domain: Half-width of the grid along each direction.
        device: Device forwarded to ``create_random_weight_direction``.
        save_dir: Directory in which ``random_surface_{domain}.pdf`` is
            written.
    """
    n_directions = 2
    random_directions = []
    for _ in range(n_directions):
        random_direction = create_random_weight_direction(
            net=model.network,
            device=device
        )
        # Keep only the components matching multi-dimensional parameters.
        random_directions.append(
            [w for w in random_direction if len(w.shape) > 1]
        )

    scaling_factors = [
        np.linspace(-domain, domain, SAMPLING_RESOLUTION)
        for _ in range(n_directions)
    ]

    # Snapshot the unperturbed weights: every grid point must be an offset
    # from this base point, not from the previously visited grid point.
    weight_params = [p for p in model.parameters() if len(p.shape) > 1]
    base_weights = [p.data.clone() for p in weight_params]

    loss_fn = nn.MSELoss()
    loss_mesh = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION))
    try:
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for j, a in enumerate(scaling_factors[0]):
                for i, b in enumerate(scaling_factors[1]):
                    for n, (p, w0) in enumerate(zip(weight_params, base_weights)):
                        p.data = (
                            w0
                            + float(a) * random_directions[0][n]
                            + float(b) * random_directions[1][n]
                        )

                    preds = model.forward(input)
                    # Rows index b (y axis), columns index a (x axis).
                    loss_mesh[i, j] = loss_fn(preds, target).item()
    finally:
        # Leave the model exactly as it was handed in.
        for p, w0 in zip(weight_params, base_weights):
            p.data = w0

    plot_objective_surface(
        objective_mesh=loss_mesh,
        weights=scaling_factors,
        objective_name="loss",
        save_path=f"{save_dir}/random_surface_{domain}.pdf"
    )


def visualise_2D_energy_random_projections(
        model,
        input,
        target,
        domain,
        device,
        save_dir
    ):
    """Plot the energy during inference on a 2D slice through weight space.

    Two random weight directions d0 and d1 are drawn. For every (a, b) on a
    square grid of side ``2 * domain`` the weights are set to
    ``theta0 + a * d0 + b * d1`` (``theta0`` being the model's current
    weights), the MSE loss of the forward pass is stored as step 0, and the
    values returned by ``model.infer_train`` are stored as steps 1..N_ITERS.
    Surfaces are rendered for the steps listed in ``PLOT_ITERS`` and combined
    into a GIF. The model's weights are restored before plotting.

    Args:
        model: Object exposing ``network``, ``parameters()``,
            ``forward(input)`` and ``infer_train(obs, prior, n_iters)``.
        input: Batch used as the forward input and as ``prior``.
        target: Targets used for the loss and as ``obs``.
        domain: Half-width of the grid along each direction.
        device: Device forwarded to ``create_random_weight_direction``.
        save_dir: Directory in which the per-step PDFs and the GIF are
            written.
    """
    n_directions = 2
    random_directions = []
    for _ in range(n_directions):
        random_direction = create_random_weight_direction(
            net=model.network,
            device=device
        )
        # Keep only the components matching multi-dimensional parameters.
        random_directions.append(
            [w for w in random_direction if len(w.shape) > 1]
        )

    scaling_factors = [
        np.linspace(-domain, domain, SAMPLING_RESOLUTION)
        for _ in range(n_directions)
    ]

    # Snapshot the unperturbed weights: every grid point must be an offset
    # from this base point, not from the previously visited grid point.
    weight_params = [p for p in model.parameters() if len(p.shape) > 1]
    base_weights = [p.data.clone() for p in weight_params]

    loss_fn = nn.MSELoss()
    energy_mesh = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION, N_ITERS+1))
    try:
        for j, a in enumerate(scaling_factors[0]):
            for i, b in enumerate(scaling_factors[1]):
                for n, (p, w0) in enumerate(zip(weight_params, base_weights)):
                    p.data = (
                        w0
                        + float(a) * random_directions[0][n]
                        + float(b) * random_directions[1][n]
                    )

                preds = model.forward(input)
                # Step 0 of the trajectory is the plain forward-pass loss.
                energy_mesh[i, j, 0] = loss_fn(preds, target).item()

                tot_energy_iters = model.infer_train(
                    obs=target,
                    prior=input,
                    n_iters=N_ITERS
                )
                if len(tot_energy_iters) != N_ITERS:
                    # Fewer values than N_ITERS were returned: pad with the
                    # last one so every grid point has a full trajectory.
                    n_missing_iters = N_ITERS - len(tot_energy_iters)
                    tot_energy_iters.extend([tot_energy_iters[-1]]*n_missing_iters)

                energy_mesh[i, j, 1:] = tot_energy_iters
    finally:
        # Leave the model exactly as it was handed in.
        for p, w0 in zip(weight_params, base_weights):
            p.data = w0

    energy_surface_frames = []
    for t in range(N_ITERS+1):
        if t in PLOT_ITERS:
            fig = plot_objective_surface(
                objective_mesh=energy_mesh[:, :, t],
                weights=scaling_factors,
                objective_name="energy",
                inference_step=t,
                save_path=f"{save_dir}/random_surface_{domain}_iter_{t}.pdf"
            )
            energy_surface_frames.append(fig)

    # Write the animation once, after all frames have been collected.
    gif.save(
        energy_surface_frames,
        f"{save_dir}/random_surface_{domain}_dynamics.gif",
        duration=1,
        unit="s"
    )


## Scripts

In [12]:
#@title BP train script


def train_bp(dataset, arch_type, n_hidden, width, act_fn, init_type, optimiser, seed, max_epochs, save_dir):
    """Train a network with backprop and record losses, accuracies and norms.

    Args:
        dataset: Dataset identifier; "toy_gaussian" disables accuracy
            tracking and enables the loss-landscape plots.
        arch_type, n_hidden, width, act_fn: Forwarded to ``get_network``.
        init_type: "other_saddle" skips re-initialisation of one layer (see
            the loop below); "standard" enables early stopping on the epoch
            training loss.
        optimiser: Optimiser identifier forwarded to ``get_optimiser`` and
            ``get_parameter_scale``.
        seed: Seed passed to ``set_seed``.
        max_epochs: Maximum number of training epochs.
        save_dir: Directory for plots and ``.npy`` metric dumps (created if
            missing).

    Returns:
        Dict with per-batch train / per-epoch test "losses" and "accs", and
        the per-batch "grad_norms".
    """
    print("\nStarting training with BP...\n")
    set_seed(seed)
    device = get_device()
    os.makedirs(save_dir, exist_ok=True)

    network = get_network(
        arch_type=arch_type,
        dataset=dataset,
        n_hidden=n_hidden,
        width=width,
        act_fn=act_fn
    )
    print(f"network: {network}\n")
    model = BPN(network=network).to(device)
    parameter_scale = get_parameter_scale(
        dataset=dataset,
        n_hidden=n_hidden,
        optimiser=optimiser
    )
    for i, layer in enumerate(model.network):
        # For "other_saddle" the second-to-last layer keeps the
        # initialisation it received from get_network.
        if (i+1) == len(network)-1 and init_type == "other_saddle":
            continue
        layer.apply(lambda m: init_weights(m, parameter_scale))

    loss_fn = nn.MSELoss()
    optimizer = get_optimiser(id=optimiser, model=model, lr=LR)

    # metrics
    batch_train_losses, epoch_test_losses = [], []
    batch_train_accs, epoch_test_accs = [], []
    grad_norms = []
    # NOTE(review): assumes at most two parameter tensors per layer
    # (presumably weight and bias) - confirm against get_network.
    n_params = len(network)*2
    norms = {
        key: [[] for p in range(n_params)] for key in ["params", "grads"]
    }
    for i, param in enumerate(network.parameters()):
        norms["params"][i].append(norm(param).item())

    # record metrics at initialisation (batch 0)
    train_loader, test_loader = get_dataloaders(dataset_id=dataset)
    img_batch, label_batch = next(iter(train_loader))
    img_batch = img_batch.to(device)
    label_batch = label_batch.to(device)
    # Evaluation only: skip building an autograd graph.
    with torch.no_grad():
        label_preds = model(img_batch)
        train_loss = loss_fn(label_preds, label_batch).item()
    batch_train_losses.append(train_loss)
    if dataset != "toy_gaussian":
        train_acc = accuracy(label_preds, label_batch)
        batch_train_accs.append(train_acc)

    # plot landscape onto random directions
    if dataset == "toy_gaussian":
        for domain in DOMAINS:
            visualise_2D_loss_random_projections(
                model=model,
                input=img_batch,
                target=label_batch,
                domain=domain,
                device=device,
                save_dir=save_dir
            )

    img_batch, label_batch = next(iter(test_loader))
    img_batch, label_batch = img_batch.to(device), label_batch.to(device)
    with torch.no_grad():
        label_preds = model(img_batch)
        test_loss = loss_fn(label_preds, label_batch).item()
    epoch_test_losses.append(test_loss)
    if dataset != "toy_gaussian":
        test_acc = accuracy(label_preds, label_batch)
        epoch_test_accs.append(test_acc)

    global_batch_id = 0
    epoch_train_losses = []
    # Defined up front so the final message works even if no epoch runs.
    epoch = 0
    for epoch in range(1, max_epochs+1):
        print(f"\nEpoch {epoch}\n-------------------------------")

        epoch_train_loss = 0
        for batch_id, (img_batch, label_batch) in enumerate(train_loader):
            img_batch = img_batch.to(device)
            label_batch = label_batch.to(device)

            label_preds = model(img_batch)
            loss = loss_fn(label_preds, label_batch)
            epoch_train_loss += loss.item()
            if dataset != "toy_gaussian":
                train_acc = accuracy(label_preds, label_batch)

            loss.backward()
            optimizer.step()
            # Gradients are read after the step but before they are zeroed;
            # parameter norms therefore reflect the updated weights.
            grad_vec = get_gradient_vector(model=model)
            grad_norms.append(norm(grad_vec).item())
            for i, param in enumerate(network.parameters()):
                norms["params"][i].append(norm(param).item())
                norms["grads"][i].append(norm(param.grad).item())

            model.zero_grad()

            batch_train_losses.append(loss.item())
            if dataset != "toy_gaussian":
                batch_train_accs.append(train_acc)

            global_batch_id += 1

            if global_batch_id % LOG_BATCH_EVERY == 0:
                print(f"Train loss: {loss.item():.5f} [{batch_id*len(img_batch)}/{len(train_loader.dataset)}]")

        test_loss, test_acc = (0, 0)
        # Evaluation only: skip building an autograd graph.
        with torch.no_grad():
            for batch_id, (img_batch, label_batch) in enumerate(test_loader):
                img_batch = img_batch.to(device)
                label_batch = label_batch.to(device)

                label_preds = model(img_batch)
                test_loss += loss_fn(label_preds, label_batch).item()
                if dataset != "toy_gaussian":
                    test_acc += accuracy(label_preds, label_batch)

        epoch_train_losses.append(epoch_train_loss / len(train_loader))
        epoch_test_losses.append(test_loss / len(test_loader))
        if dataset != "toy_gaussian":
            epoch_test_accs.append(test_acc / len(test_loader))
            print(f"\nAvg test accuracy: {test_acc / len(test_loader):.4f}")

        # Early stopping on the epoch-average training loss is only applied
        # for the "standard" initialisation.
        if init_type == "standard":
            if epoch > 1 and (epoch_train_losses[-2] - epoch_train_losses[-1]) < LOSS_TOLERANCE:
                break

    print(f"\nTraining stopped at epoch {epoch}\n")

    # plot losses, accuracies and norms
    if dataset == "toy_gaussian":
        plot_loss(
            loss=batch_train_losses,
            mode="train",
            save_path=f"{save_dir}/train_losses.pdf"
        )
        plot_loss(
            loss=epoch_test_losses,
            mode="test",
            save_path=f"{save_dir}/test_losses.pdf"
        )
    else:
        plot_loss_and_accuracy(
            loss=batch_train_losses,
            accuracy=batch_train_accs,
            mode="train",
            save_path=f"{save_dir}/train_losses_and_accs.pdf"
        )
        plot_loss_and_accuracy(
            loss=epoch_test_losses,
            accuracy=epoch_test_accs,
            mode="test",
            save_path=f"{save_dir}/test_losses_and_accs.pdf"
        )
    plot_norms(
        norms=norms["params"],
        norm_type="parameters",
        mode="learning",
        save_path=f"{save_dir}/parameters_norm"
    )
    plot_norms(
        norms=norms["grads"],
        norm_type="gradient",
        mode="learning",
        save_path=f"{save_dir}/gradient_norm"
    )

    np.save(f"{save_dir}/batch_train_losses.npy", batch_train_losses)
    np.save(f"{save_dir}/epoch_test_losses.npy", epoch_test_losses)

    np.save(f"{save_dir}/batch_train_accs.npy", batch_train_accs)
    np.save(f"{save_dir}/epoch_test_accs.npy", epoch_test_accs)

    np.save(f"{save_dir}/grad_norms.npy", grad_norms)

    return {
        "losses": {
            "train": batch_train_losses,
            "test": epoch_test_losses
        },
        "accs": {
            "train": batch_train_accs,
            "test": epoch_test_accs
        },
        "grad_norms": grad_norms
    }


In [13]:
#@title PC train script


def train_pc(dataset, arch_type, n_hidden, width, act_fn, init_type, optimiser, seed, save_dir, max_epochs=None):
    """Train a network with predictive coding and record training metrics.

    Args:
        dataset: Dataset identifier; "toy_gaussian" disables accuracy
            tracking and enables the energy-landscape plots.
        arch_type, n_hidden, width, act_fn: Forwarded to ``get_network``.
        init_type: "other_saddle" skips re-initialisation of one layer (see
            the loop below).
        optimiser: Optimiser identifier forwarded to ``get_optimiser`` and
            ``get_parameter_scale``.
        seed: Seed passed to ``set_seed``.
        save_dir: Directory for plots and ``.npy`` metric dumps (created if
            missing).
        max_epochs: Maximum number of training epochs. Defaults to the
            module-level ``MAX_EPOCHS`` when omitted, mirroring the
            ``max_epochs`` argument of ``train_bp``.

    Returns:
        Dict with per-batch train / per-epoch test "losses" and "accs", the
        per-batch "grad_norms" and "n_epochs", the last epoch that was run.
    """
    if max_epochs is None:
        max_epochs = MAX_EPOCHS

    print("Starting training with PC...\n")
    set_seed(seed)
    device = get_device()
    os.makedirs(save_dir, exist_ok=True)

    network = get_network(
        arch_type=arch_type,
        dataset=dataset,
        n_hidden=n_hidden,
        width=width,
        act_fn=act_fn
    )
    model = PCN(network=network, dt=DT, device=device)
    parameter_scale = get_parameter_scale(
        dataset=dataset,
        n_hidden=n_hidden,
        optimiser=optimiser
    )
    for i, layer in enumerate(model.network):
        # For "other_saddle" the second-to-last layer keeps the
        # initialisation it received from get_network.
        if (i+1) == len(network)-1 and init_type == "other_saddle":
            continue
        layer.apply(lambda m: init_weights(m, parameter_scale))

    loss_fn = nn.MSELoss()
    optimizer = get_optimiser(id=optimiser, model=model, lr=LR)

    # metrics
    batch_train_losses, epoch_test_losses = [], []
    batch_train_accs, epoch_test_accs = [], []
    grad_norms = []
    # NOTE(review): assumes at most two parameter tensors per layer
    # (presumably weight and bias) - confirm against get_network.
    norms = {
        key: [[] for p in range(len(network)*2)] for key in ["params", "grads"]
    }
    for i, param in enumerate(network.parameters()):
        norms["params"][i].append(norm(param).item())

    # record train metrics at initialisation (batch 0)
    train_loader, test_loader = get_dataloaders(dataset_id=dataset)
    img_batch, label_batch = next(iter(train_loader))
    img_batch = img_batch.to(device)
    label_batch = label_batch.to(device)
    label_preds = model.forward(img_batch)
    train_loss = loss_fn(label_preds, label_batch).item()
    batch_train_losses.append(train_loss)
    if dataset != "toy_gaussian":
        train_acc = accuracy(label_preds, label_batch)
        batch_train_accs.append(train_acc)

    # plot gradient inference dynamics
    _, grad_norms_iters = model.infer_train(
        obs=label_batch,
        prior=img_batch,
        n_iters=N_ITERS,
        record_grad_norms=True
    )
    plot_norms(
        norms=grad_norms_iters,
        norm_type="gradient",
        mode="inference",
        save_path=f"{save_dir}/gradient_norm_inference"
    )

    # plot landscape onto random directions
    if dataset == "toy_gaussian":
        for domain in DOMAINS:
            visualise_2D_energy_random_projections(
                model=model,
                input=img_batch,
                target=label_batch,
                domain=domain,
                device=device,
                save_dir=save_dir
            )

    img_batch, label_batch = next(iter(test_loader))
    img_batch = img_batch.to(device)
    label_batch = label_batch.to(device)
    label_preds = model.forward(img_batch)
    test_loss = loss_fn(label_preds, label_batch).item()
    epoch_test_losses.append(test_loss)
    if dataset != "toy_gaussian":
        test_acc = accuracy(label_preds, label_batch)
        epoch_test_accs.append(test_acc)

    global_batch_id = 0
    epoch_train_losses = []
    # Defined up front so "n_epochs" is available even if no epoch runs.
    epoch = 0
    for epoch in range(1, max_epochs+1):
        print(f"\nEpoch {epoch}\n-------------------------------")

        epoch_train_loss = 0
        for batch_id, (img_batch, label_batch) in enumerate(train_loader):
            img_batch = img_batch.to(device)
            label_batch = label_batch.to(device)

            # The loss of the plain forward pass is logged; the weight
            # update itself comes from the inference phase below.
            label_preds = model.forward(img_batch)
            loss = loss_fn(label_preds, label_batch).item()
            epoch_train_loss += loss
            if dataset != "toy_gaussian":
                train_acc = accuracy(label_preds, label_batch)

            # NOTE(review): gradients are never zeroed explicitly in this
            # loop (train_bp calls model.zero_grad()); presumably
            # infer_train resets or overwrites them - confirm.
            model.infer_train(
                obs=label_batch,
                prior=img_batch,
                n_iters=N_ITERS
            )
            optimizer.step()
            grad_vec = get_gradient_vector(model=model)
            grad_norms.append(norm(grad_vec).item())
            for i, param in enumerate(network.parameters()):
                norms["params"][i].append(norm(param).item())
                norms["grads"][i].append(norm(param.grad).item())

            batch_train_losses.append(loss)
            if dataset != "toy_gaussian":
                batch_train_accs.append(train_acc)

            global_batch_id += 1

            if global_batch_id % LOG_BATCH_EVERY == 0:
                print(f"Train loss: {loss:.7f} [{batch_id * len(img_batch)}/{len(train_loader.dataset)}]")

        test_loss, test_acc = (0, 0)
        for batch_id, (img_batch, label_batch) in enumerate(test_loader):
            img_batch = img_batch.to(device)
            label_batch = label_batch.to(device)

            label_preds = model.forward(img_batch)
            test_loss += loss_fn(label_preds, label_batch).item()
            if dataset != "toy_gaussian":
                test_acc += accuracy(label_preds, label_batch)

        epoch_train_losses.append(epoch_train_loss / len(train_loader))
        epoch_test_losses.append(test_loss / len(test_loader))
        if dataset != "toy_gaussian":
            epoch_test_accs.append(test_acc / len(test_loader))
            print(f"\nAvg test accuracy: {test_acc / len(test_loader):.4f}")

        # Stop once the epoch-average training loss improves by less than
        # LOSS_TOLERANCE (this also triggers if the loss increased).
        if epoch > 1 and (epoch_train_losses[-2] - epoch_train_losses[-1]) < LOSS_TOLERANCE:
            print(f"\nTraining stopped at epoch {epoch}\n")
            break

    # plot losses, accuracies and norms
    if dataset == "toy_gaussian":
        plot_loss(
            loss=batch_train_losses,
            mode="train",
            save_path=f"{save_dir}/train_losses.pdf"
        )
        plot_loss(
            loss=epoch_test_losses,
            mode="test",
            save_path=f"{save_dir}/test_losses.pdf"
        )
    else:
        plot_loss_and_accuracy(
            loss=batch_train_losses,
            accuracy=batch_train_accs,
            mode="train",
            save_path=f"{save_dir}/train_losses_and_accs.pdf"
        )
        plot_loss_and_accuracy(
            loss=epoch_test_losses,
            accuracy=epoch_test_accs,
            mode="test",
            save_path=f"{save_dir}/test_losses_and_accs.pdf"
        )
    plot_norms(
        norms=norms["params"],
        norm_type="parameters",
        mode="learning",
        save_path=f"{save_dir}/parameters_norm"
    )
    plot_norms(
        norms=norms["grads"],
        norm_type="gradient",
        mode="learning",
        save_path=f"{save_dir}/gradient_norm"
    )

    np.save(f"{save_dir}/batch_train_losses.npy", batch_train_losses)
    np.save(f"{save_dir}/epoch_test_losses.npy", epoch_test_losses)

    np.save(f"{save_dir}/batch_train_accs.npy", batch_train_accs)
    np.save(f"{save_dir}/epoch_test_accs.npy", epoch_test_accs)

    np.save(f"{save_dir}/grad_norms.npy", grad_norms)

    return {
        "losses": {
            "train": batch_train_losses,
            "test": epoch_test_losses
        },
        "accs": {
            "train": batch_train_accs,
            "test": epoch_test_accs
        },
        "grad_norms": grad_norms,
        "n_epochs": epoch
    }


In [14]:
#@title Main script


def main():
    """Run the BP-versus-PC comparison for every configured setting.

    For each combination of dataset, architecture size, activation function,
    initialisation and optimiser, a PC run and then a BP run are trained for
    every seed. Per-seed metrics are aggregated with ``compute_metric_stats``
    and the BP/PC comparison plots are written into the experiment directory.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    algos = ("bp", "pc")
    splits = ("train", "test")

    for dataset in DATASETS:
        arch_type = get_architecture_type(dataset=dataset)
        tracks_accuracy = dataset != "toy_gaussian"
        for n_hidden, width in N_HIDDEN_WIDTHS[dataset]:
            for act_fn in ACT_FNS:
                for init_type in INIT_TYPES:
                    for optimiser in OPTIMISERS:
                        experiment_dir = setup_experiment(
                            results_dir=RESULTS_DIR,
                            dataset=dataset,
                            arch_type=arch_type,
                            n_hidden=n_hidden,
                            width=width,
                            act_fn=act_fn,
                            init_type=init_type,
                            optimiser=optimiser,
                            lr=LR
                        )
                        # Arguments shared by both training scripts.
                        run_config = dict(
                            dataset=dataset,
                            arch_type=arch_type,
                            n_hidden=n_hidden,
                            width=width,
                            act_fn=act_fn,
                            init_type=init_type,
                            optimiser=optimiser
                        )

                        # One metrics dict per seed and per algorithm.
                        seed_metrics = {algo: [] for algo in algos}
                        for seed in range(N_SEEDS):
                            print(f"\nSeed {seed+1}/{N_SEEDS}...")
                            seed_dir = f"{experiment_dir}/{seed}"

                            pc_metrics = train_pc(
                                seed=seed,
                                save_dir=f"{seed_dir}/pc",
                                **run_config
                            )
                            # Except on the toy dataset, BP gets the number
                            # of epochs that PC actually ran for.
                            if tracks_accuracy:
                                max_epochs = pc_metrics["n_epochs"]
                            else:
                                max_epochs = MAX_EPOCHS
                            bp_metrics = train_bp(
                                seed=seed,
                                max_epochs=max_epochs,
                                save_dir=f"{seed_dir}/bp",
                                **run_config
                            )
                            seed_metrics["bp"].append(bp_metrics)
                            seed_metrics["pc"].append(pc_metrics)

                        # (mean, std) over seeds for each algorithm and split.
                        loss_stats = {
                            (algo, split): compute_metric_stats(
                                metric=[m["losses"][split] for m in seed_metrics[algo]]
                            )
                            for algo in algos
                            for split in splits
                        }
                        for split in splits:
                            bp_mean, bp_std = loss_stats["bp", split]
                            pc_mean, pc_std = loss_stats["pc", split]
                            plot_bp_and_pc_metric_stats(
                                means=[bp_mean, pc_mean],
                                stds=[bp_std, pc_std],
                                dataset=dataset,
                                optimiser=optimiser,
                                metric_title=r"$\LARGE{\mathcal{L}_{\text{" + split + "}}}$",
                                save_path=f"{experiment_dir}/{split}_loss_stats.pdf"
                            )

                        bp_grad_norm_means, bp_grad_norm_stds = compute_metric_stats(
                            metric=[m["grad_norms"] for m in seed_metrics["bp"]]
                        )
                        pc_grad_norm_means, pc_grad_norm_stds = compute_metric_stats(
                            metric=[m["grad_norms"] for m in seed_metrics["pc"]]
                        )
                        plot_bp_vs_pc_grad_norm_stats(
                            means=[bp_grad_norm_means, pc_grad_norm_means],
                            stds=[bp_grad_norm_stds, pc_grad_norm_stds],
                            dataset=dataset,
                            save_path=f"{experiment_dir}/gradient_norm_stats.pdf"
                        )

                        if tracks_accuracy:
                            acc_stats = {
                                (algo, split): compute_metric_stats(
                                    metric=[m["accs"][split] for m in seed_metrics[algo]]
                                )
                                for algo in algos
                                for split in splits
                            }
                            for split in splits:
                                bp_mean, bp_std = acc_stats["bp", split]
                                pc_mean, pc_std = acc_stats["pc", split]
                                plot_bp_and_pc_metric_stats(
                                    means=[bp_mean, pc_mean],
                                    stds=[bp_std, pc_std],
                                    dataset=dataset,
                                    optimiser=optimiser,
                                    metric_title=f"{split.capitalize()} accuracy (%)",
                                    save_path=f"{experiment_dir}/{split}_acc_stats.pdf"
                                )


## Run analysis

In [None]:
# Run the full sweep over DATASETS, ACT_FNS, INIT_TYPES, OPTIMISERS and seeds.
main()

In [15]:
# Mount Google Drive at /content/drive (Colab only).
import shutil
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%capture
!zip -r /content/results.zip /content/results

In [None]:
# Copy the zipped results from the Colab filesystem into the mounted Drive.
colab_link = "/content/results.zip"
gdrive_link = "/content/drive/MyDrive/"
shutil.copy(colab_link, gdrive_link)

In [None]:
# NOTE(review): presumably releases the Colab runtime once everything above
# has finished - confirm against the google.colab documentation.
from google.colab import runtime
runtime.unassign()