# Linear Chains Analysis

This notebook compares the learning dynamics of backpropagation (BP) and predictive coding (PC) on linear chains, i.e. linear multi-layer perceptrons with a single unit per layer, focusing on chains with one and two hidden units (1- and 2-MLPs).

## Setup

In [69]:
#@title Installations


%%capture
#!sudo apt install nvidia-utils-515
!pip install plotly==5.11.0
!pip install -U kaleido
!pip install gif==3.0.0


In [70]:
#@title Imports


import os
import random
import numpy as np
from typing import Tuple, Dict, Optional, List, Callable, Union
from numpy.polynomial.polynomial import Polynomial

from jax import jacfwd, jacrev
from jax.numpy.linalg import eigh, norm

import gif
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import plotly.express as px
from plotly.express.colors import sample_colorscale
import plotly.graph_objs as go
import plotly.figure_factory as ff

In [71]:
#@title Utils


def set_seed(seed):
    """Seed both the stdlib ``random`` module and numpy's global RNG.

    Args:
        seed: value forwarded unchanged to ``random.seed`` and
            ``np.random.seed``.
    """
    for seeder in (np.random.seed, random.seed):
        seeder(seed)


def make_gaussian_dataset(mean, std, size):
    """Draw Gaussian inputs and pair them with negated targets.

    Args:
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
        size: output shape forwarded to ``np.random.normal``.

    Returns:
        Tuple ``(x, y)`` where ``x ~ N(mean, std)`` and ``y = -x``.
    """
    inputs = np.random.normal(loc=mean, scale=std, size=size)
    return inputs, -inputs


def hessian(f):
    """Return a function that computes the Hessian of ``f``.

    The gradient is taken in reverse mode (``jacrev``) and then
    differentiated again in forward mode (``jacfwd``).
    """
    gradient_fn = jacrev(f)
    return jacfwd(gradient_fn)


def compute_theoretical_energy_eigenvals(weights, X, Y):
    """Eigenvalues of the hard-coded theoretical energy Hessian at the origin.

    Only ``len(weights)`` is used: it sets the size of the square matrix. The
    matrix is all zeros except for these entries:
      * the last diagonal entry, ``-mean(Y**2)``;
      * for one hidden unit (two weights) only, the symmetric off-diagonal
        pair ``[0, 1]``/``[1, 0]``, ``-mean(X*Y)``.

    Returns:
        The eigenvalues from ``jax.numpy.linalg.eigh`` (ascending order).
    """
    size = len(weights)
    curvature = np.zeros((size, size))
    curvature[-1, -1] = -(Y**2).mean()

    if size - 1 == 1:
        cross_term = -(X*Y).mean()
        curvature[0, 1] = cross_term
        curvature[1, 0] = cross_term

    eigenvalues, _ = eigh(curvature)
    return eigenvalues


def get_min_iter(lists):
    """Return the length of the shortest sequence in ``lists``.

    Args:
        lists: iterable of sized sequences (e.g. per-seed metric histories).

    Returns:
        The minimum ``len`` over ``lists``, or 0 if ``lists`` is empty.
    """
    # Built-in ``min`` has no sentinel cap, so it is correct for sequences of
    # any length.
    return min((len(seq) for seq in lists), default=0)


def get_min_iter_metric(metric):
    """Truncate every per-seed trace to the shortest length and stack them.

    Args:
        metric: sequence of per-seed sequences (one metric history per seed).

    Returns:
        Float array of shape ``(len(metric), shortest_length)`` whose row ``s``
        holds the first ``shortest_length`` values of ``metric[s]``.
    """
    n_rows = len(metric)
    n_cols = get_min_iter(lists=metric)

    stacked = np.zeros((n_rows, n_cols))
    for row in range(n_rows):
        stacked[row] = metric[row][:n_cols]
    return stacked


def compute_metric_stats(metric):
    """Per-iteration mean and standard deviation across seeds.

    All traces are first truncated to the shortest one, so the statistics
    cover only the iterations every seed reached.

    Returns:
        Tuple ``(means, stds)`` of 1D arrays, one value per iteration.
    """
    aligned = get_min_iter_metric(metric=metric)
    return aligned.mean(axis=0), aligned.std(axis=0)


In [72]:
#@title Objective functions


def mse_loss_fun(ws, x, y):
    """Mean squared error of a linear neural chain (one neuron per layer).

    The chain's end-to-end map is the product of all its weights, so the
    prediction is ``prod(ws) * x``.

    Returns:
        ``mean(0.5 * (y - prod(ws) * x) ** 2)``.
    """
    residual = y - np.prod(ws) * x
    return (0.5 * residual**2).mean()


def energy_fun(ws, xs, n_iters, dt):
    """Run iterative inference on a linear chain and return its total energy.

    The chain has one neuron per layer and ``L = len(ws) - 1`` hidden units.
    ``xs`` holds the activities ``[x, z_1, ..., z_L, y]``:
      * ``xs[0]`` is the input;
      * ``xs[L + 1]`` is the target;
      * ``xs[1:L + 1]`` are the initial hidden activities.

    The prediction errors are

        e_1     = z_1 - w_0 * x
        e_l     = z_l - w_{l-1} * z_{l-1}        (1 < l <= L)
        e_{L+1} = y   - w_L * z_L

    The hidden activities take ``n_iters`` synchronous Euler steps
    ``z_l <- z_l - dt * (e_l - w_l * e_{l+1})``.

    Args:
        ws: sequence of ``L + 1`` scalar weights.
        xs: sequence of ``L + 2`` activity arrays; it is not modified.
        n_iters: number of inference steps.
        dt: inference step size.

    Returns:
        ``mean(sum_l 0.5 * e_l**2)`` after inference. The mean is taken over
        the elements of the activity arrays.

    Raises:
        ValueError: if ``ws`` has fewer than two weights (no hidden unit).
    """
    n_hidden = len(ws) - 1
    if n_hidden < 1:
        raise ValueError(
            f"energy_fun needs at least one hidden unit, got {len(ws)} weight(s)"
        )

    x_in, y_out = xs[0], xs[n_hidden + 1]
    # Fresh list of references. The updates below rebind entries instead of
    # using in-place `+=`, so the caller's arrays in `xs` are never mutated.
    zs = [xs[l] for l in range(1, n_hidden + 1)]

    def prediction_errors(hidden):
        # errors[l] is e_{l+1} in the docstring's 1-based notation.
        activities = [x_in, *hidden, y_out]
        return [
            activities[l + 1] - ws[l] * activities[l]
            for l in range(n_hidden + 1)
        ]

    errors = prediction_errors(zs)
    for _ in range(n_iters):
        # Compute every activity gradient from the current errors before
        # updating any activity (synchronous update).
        dzs = [errors[l] - ws[l + 1] * errors[l + 1] for l in range(n_hidden)]
        zs = [z - dz * dt for z, dz in zip(zs, dzs)]
        errors = prediction_errors(zs)

    return sum(0.5 * e**2 for e in errors).mean()


In [73]:
#@title Config


# Chain depths (number of hidden units) to experiment with.
HIDDEN_UNITS = [1, 2, 5, 10]
# Number of random seeds (presumably one run per seed — confirm at call site).
N_SEEDS = 5

RESULTS_DIR = "results"

# dataset (presumably forwarded to make_gaussian_dataset)
DATA_MEAN = 1.
DATA_STD = 0.1
BATCH_SIZE = 64

# PC hyperparameters: number of inference steps and inference step size
N_ITERS = 20
DT = 0.1

# optimization hyperparameters
WEIGHT_SCALE = 5e-2
LR = 0.4

# landscape plotting
SAMPLING_RESOLUTION = 30     # grid points per weight axis
COLORSCALE = "RdBu_r"        # plotly colorscale for surfaces/contours/volumes
GRADIENT_FIELD_SCALE = 0.1   # arrow scale of the quiver overlay
# Inference steps at which energy-landscape frames are drawn; entries larger
# than the number of inference iterations used there are never reached.
PLOT_INFER_ITERATIONS = [0, 1, 5, 10, 20, 30, 40, 50, 100]


In [74]:
#@title Landscape plotting


@gif.frame
def plot_objective_contour(
        objective_mesh: np.ndarray,
        weights: Tuple[np.ndarray, np.ndarray],
        vector_field: np.ndarray,
        vector_field_scale: float,
        objective_name: str,
        title: str,
        save_path: str,
        weight_updates: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        smooth_contours: bool = True
    ):
    """Plot the contours (level sets) of a 2D objective with its gradient field.

    The vector field is standardised and drawn as a quiver plot on top of the
    contours, for easier visualisation. Standardising means subtracting the
    global mean and dividing by the global standard deviation.

    Args:
        objective_mesh: 2D grid of objective values, passed as plotly ``z``.
            Rows follow ``weights[1]`` (y-axis) and columns follow
            ``weights[0]`` (x-axis).
        weights: the two 1D grids of weight values for the x and y axes.
        vector_field: field on the same grid. ``[:, :, 0]`` and ``[:, :, 1]``
            are its x and y components.
        vector_field_scale: arrow scale forwarded to ``ff.create_quiver``.
        objective_name: ``"loss"`` labels the colorbar with L; anything else
            labels it with F.
        title: figure title.
        save_path: output image path if ``weight_updates`` is None. Otherwise
            it is a prefix, and one PDF per trajectory step is written to
            ``f"{save_path}_train_iter_{t}.pdf"``.
        weight_updates: optional pair of per-step (w1, w2) values of a
            training trajectory. The points are added one step at a time.
        smooth_contours: use heatmap-style coloring instead of filled bands.

    Returns:
        The plotly figure.
    """
    contours_coloring = "heatmap" if smooth_contours else "fill"
    marker_color = "rgb(255, 255, 51)"

    # contour plot
    contour = go.Contour(
        z=objective_mesh,
        x=weights[0],
        y=weights[1],
        colorscale=COLORSCALE,
        showscale=False,
        contours_coloring=contours_coloring
    )

    # gradient vector field
    # NOTE(review): the mean is one scalar over both components, so subtracting
    # it shifts every arrow along the diagonal. Dividing by the std alone would
    # preserve arrow directions. Kept as-is for output compatibility.
    centred_field = vector_field - vector_field.mean()
    field_std = vector_field.std()
    # A constant field has zero std; dividing would yield NaN arrows.
    vector_field = centred_field / field_std if field_std > 0 else centred_field
    w1_mesh, w2_mesh = np.meshgrid(weights[0], weights[1])
    quiver = ff.create_quiver(
        x=w1_mesh,
        y=w2_mesh,
        u=vector_field[:, :, 0],
        v=vector_field[:, :, 1],
        marker_color=marker_color,
        opacity=0.9,
        scale=vector_field_scale,
        line_width=0.8,
        showlegend=False
    )
    fig = go.FigureWidget(data=[contour, quiver.data[0]])

    # colorbar, drawn through an invisible scatter trace
    # Raw strings: LaTeX backslashes are not valid Python escape sequences.
    objective = r"\mathcal{L}" if objective_name == "loss" else r"\mathcal{F}"
    max_objective, min_objective = objective_mesh.max(), objective_mesh.min()
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        showlegend=False,
        marker=dict(
            colorscale=COLORSCALE,
            showscale=True,
            cmin=min_objective,
            cmax=max_objective,
            colorbar=dict(
                title=rf"$\LARGE{{{objective}}}$",
                len=0.5,
                title_side="right",
                tickfont=dict(size=16),
                tickvals=[min_objective, max_objective],
                ticktext=["Low", "High"]
            )
        ),
        hoverinfo="none"
    )
    fig.add_trace(colorbar_trace)

    # The layout does not depend on the trajectory traces, so it is set once.
    fig.update_layout(
        title=dict(
            text=title,
            y=0.85,
            x=0.1,
            xanchor="left",
            yanchor="top"
        ),
        xaxis=dict(title=r"$\LARGE{w_1}$", showticklabels=False),
        yaxis=dict(title=r"$\LARGE{w_2}$", showticklabels=False),
        font=dict(size=16),
        plot_bgcolor="white",
        width=500,
        height=400,
        margin=dict(r=100, b=50, l=50, t=80)
    )

    if weight_updates is None:
        fig.write_image(save_path)
        return fig

    # example training trajectory: add one point per step, save one frame each
    n_updates = len(weight_updates[0])
    for t in range(n_updates):
        if t == 0:
            text, symbol, size = [r"$\huge{{w^0}}$"], "diamond", 10
        elif t == (n_updates-1):
            text, symbol, size = [r"$\huge{{w^*}}$"], "cross", 13
        else:
            text, symbol, size = [], "circle", 4

        fig.add_traces(
            go.Scatter(
                x=[weight_updates[0][t]],
                y=[weight_updates[1][t]],
                mode="markers+text",
                marker=dict(size=size, color=marker_color, symbol=symbol),
                showlegend=False,
                text=text,
                textposition="top right",
                textfont=dict(color=marker_color)
            )
        )
        fig.write_image(f"{save_path}_train_iter_{t}.pdf")

    return fig


@gif.frame
def plot_objective_surface(
        objective_mesh: np.ndarray,
        weights: Tuple[np.ndarray, np.ndarray],
        save_path: str,
        weight_updates: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        train_losses: Optional[List[float]] = None,
        marker_z_offset: float = 4e-2
    ) -> go.Figure:
    """Plot and save a 2D objective as a 3D surface with projected contours.

    Args:
        objective_mesh: 2D grid of objective values (plotly ``z``).
        weights: the two 1D weight grids for the x and y axes.
        save_path: output image path.
        weight_updates: optional per-step (w1, w2) trajectory to overlay.
        train_losses: per-step objective values that give the trajectory's
            height. Required when ``weight_updates`` is given.
        marker_z_offset: amount added to ``train_losses`` for the marker
            heights (presumably to keep them visible above the surface).

    Returns:
        The plotly figure.

    Raises:
        ValueError: if ``weight_updates`` is given without ``train_losses``.
    """
    if weight_updates is not None and train_losses is None:
        raise ValueError(
            "train_losses must be provided together with weight_updates"
        )

    fig = go.Figure(
        data=go.Surface(
            z=objective_mesh,
            x=weights[0],
            y=weights[1],
            colorscale=COLORSCALE,
        )
    )
    fig.update_traces(
        contours_z=dict(
            show=True,
            usecolormap=True,
            highlightcolor="limegreen",
            project_z=True
        ),
        showscale=False
    )

    # example weight trajectory
    if weight_updates is not None:
        fig.add_scatter3d(
            x=weight_updates[0],
            y=weight_updates[1],
            z=np.array(train_losses) + marker_z_offset,
            mode="markers",
            marker=dict(
                line=dict(width=2, color="black"),
                size=5,
                color="rgb(255, 255, 51)"
            )
        )

    fig.update_layout(
        scene=dict(
            xaxis=dict(
                title="",
                nticks=3,
                autorange="reversed"
            ),
            yaxis=dict(
                title="",
                nticks=3,
                autorange="reversed"
            ),
            zaxis=dict(
                title="",
                showticklabels=False
            )
        ),
        scene_camera=dict(
            center=dict(x=0.05, y=0.1, z=0),
            eye=dict(x=0.75, y=1.8, z=1.25)
        ),
        font=dict(size=16),
        height=600,
        width=700,
        scene_aspectmode="cube"
    )
    fig.write_image(save_path)
    return fig


@gif.frame
def plot_objective_volume(
        objective_mesh: np.ndarray,
        input_domain: Union[int, float],
        sampling_resolution: int,
        save_path: str,
        weight_updates: Optional[List[np.ndarray]] = None
    ) -> go.Figure:
    """Plot and save a 3D objective as a volume plot.

    The mesh is matched point by point to an ``np.mgrid`` over
    ``[-input_domain, input_domain]^3`` with ``sampling_resolution`` points
    per axis. ``objective_mesh[a, b, c]`` must therefore hold the value at
    ``(w1, w2, w3) = (grid[a], grid[b], grid[c])``.

    Args:
        objective_mesh: 3D grid of objective values, laid out as above.
        input_domain: half-width of the sampled cube.
        sampling_resolution: number of grid points per axis.
        save_path: output image path if ``weight_updates`` is None. Otherwise
            it is a prefix, and one PDF per trajectory step is written to
            ``f"{save_path}_train_iter_{t}.pdf"``.
        weight_updates: optional per-step (w1, w2, w3) trajectory. The points
            are added one step at a time.

    Returns:
        The plotly figure.
    """
    axis_spec = complex(sampling_resolution)  # complex step => number of points
    w1s, w2s, w3s = np.mgrid[
        -input_domain:input_domain:axis_spec,
        -input_domain:input_domain:axis_spec,
        -input_domain:input_domain:axis_spec
    ]

    fig = go.Figure()
    fig.add_traces(go.Volume(
        x=w1s.flatten(),
        y=w2s.flatten(),
        z=w3s.flatten(),
        value=objective_mesh.flatten(),
        opacity=0.5,
        surface_count=20,
        colorscale=COLORSCALE
    ))
    fig.update_traces(showscale=False)

    if weight_updates is None:
        fig.update_layout(
            scene=dict(
                xaxis=dict(
                    title="",
                    nticks=3,
                    autorange="reversed",
                    showticklabels=False,
                    showbackground=False
                ),
                yaxis=dict(
                    title="",
                    nticks=3,
                    autorange="reversed",
                    showticklabels=False,
                    showbackground=False
                ),
                zaxis=dict(
                    title="",
                    showticklabels=False,
                    showbackground=False
                )
            ),
            scene_camera=dict(eye=dict(x=2.5, y=1.25, z=1.25)),
            font=dict(size=16),
            height=800,
            width=800,
            scene_aspectmode="cube"
        )
        fig.write_image(save_path)
        return fig

    # The trajectory layout does not depend on the added points, so set it once.
    fig.update_layout(
        scene=dict(
            xaxis=dict(
                title="",
                nticks=3,
                autorange="reversed"
            ),
            yaxis=dict(
                title="",
                nticks=3,
                autorange="reversed"
            ),
            zaxis=dict(
                title="",
                showticklabels=False
            )
        ),
        scene_camera=dict(eye=dict(x=2.5, y=1.25, z=1.25)),
        font=dict(size=16),
        height=800,
        width=800,
        scene_aspectmode="data"
    )

    marker_color = "rgb(255, 255, 51)"
    marker_opacity = 0.2
    n_updates = len(weight_updates[0])
    for t in range(n_updates):
        if t == 0:
            symbol, size = "diamond", 8
        elif t == (n_updates-1):
            symbol, size = "cross", 12
        else:
            symbol, size = "circle", 6

        fig.add_traces(go.Scatter3d(
            x=[weight_updates[0][t]],
            y=[weight_updates[1][t]],
            z=[weight_updates[2][t]],
            mode="markers",
            marker=dict(
                size=size,
                color=marker_color,
                opacity=marker_opacity,
                symbol=symbol
            ),
            showlegend=False
        ))
        fig.write_image(f"{save_path}_train_iter_{t}.pdf")

    return fig



In [75]:
#@title Plotting


def plot_losses(losses: Dict, save_path: str) -> None:
    """Plot one line per entry of ``losses`` and save the figure.

    Args:
        losses: mapping from a curve name (used as the legend label) to its
            per-iteration values. The ``"train"`` entry sets the x-axis.
        save_path: output image path.
    """
    n_train_iters = len(losses["train"])
    train_iters = [b+1 for b in range(n_train_iters)]
    last_iter = train_iters[-1]

    fig = go.Figure()
    for loss_type, loss in losses.items():
        fig.add_traces(
            go.Scatter(
                x=train_iters,
                y=loss,
                name=loss_type,
                mode="lines",
                line=dict(width=3)
            )
        )

    fig.update_layout(
        height=300,
        width=300,
        xaxis=dict(
            title="Training iteration",
            tickvals=[1, last_iter // 2, last_iter],
            ticktext=[1, last_iter // 2, last_iter],
        ),
        yaxis=dict(
            # Raw string: LaTeX backslashes are not valid Python escapes.
            title=r"$\LARGE{\mathcal{L}}$",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(save_path)


def plot_energies(
        energies: List[np.ndarray],
        n_infer_iters: int,
        save_path: str,
        show_tot: bool = False
    ) -> None:
    """Plot and save per-layer energy traces.

    Args:
        energies: one trace per layer (``F_1``, ``F_2``, ...). All traces are
            assumed to have the length of ``energies[0]``, which sets the
            x-axis.
        n_infer_iters: number of inference iterations per training step. Not
            used by the plot; kept for interface compatibility.
        save_path: output image path.
        show_tot: also plot ``F_tot``, the element-wise sum of all traces.
    """
    time_points = [b+1 for b in range(len(energies[0]))]
    last_point = time_points[-1]

    # Cycled so that deep chains get one trace per layer. The first three
    # entries are the colors used for F_1..F_3; "#FFA15A" is kept for F_tot.
    palette = ["#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#19D3F3",
               "#FF6692", "#B6E880", "#FF97FF", "#FECB52"]

    fig = go.Figure()
    for i, energy in enumerate(energies):
        fig.add_traces(
            go.Scatter(
                x=time_points,
                y=energy,
                name=rf"$\mathcal{{F}}_{{{i+1}}}$",
                mode="lines",
                line=dict(width=2, color=palette[i % len(palette)])
            )
        )

    if show_tot:
        # Sum over every layer, so the total is also correct for chains with
        # more than two energy terms.
        total_energy = sum(np.asarray(energy) for energy in energies)
        fig.add_traces(
            go.Scatter(
                x=time_points,
                y=total_energy,
                name=r"$\mathcal{F}_{tot}$",
                line=dict(width=2, color="#FFA15A")
            )
        )

    fig.update_layout(
        height=300,
        width=400,
        margin=dict(
            l=50,
            r=50,
            b=100,
            t=100,
            pad=4
        ),
        xaxis=dict(
            title="Iteration",
            tickvals=[1, last_point // 2, last_point],
            ticktext=[1, last_point // 2, last_point],
        ),
        yaxis=dict(
            title=r"$\LARGE{\mathcal{F}}$",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(save_path)


def plot_updates(updates: List[np.ndarray], update_type: str, save_path: str) -> None:
    """Plot and save one trace per weight across training iterations.

    Args:
        updates: one per-iteration trace per weight; ``updates[0]`` sets the
            x-axis length.
        update_type: ``"weights"`` labels traces ``w_i``. Any other value
            labels them ``\\partial w_i``.
        save_path: output image path.
    """
    n_train_iters = len(updates[0])
    train_iters = [b+1 for b in range(n_train_iters)]
    last_iter = train_iters[-1]

    # Cycled so that chains with more than three weights keep every trace.
    # The first three entries are the colors used for w_1..w_3.
    palette = ["#AB36FA", "#FFA15A", "#19D3F3", "#FF6692", "#B6E880",
               "#FF97FF", "#FECB52", "#636EFA", "#EF553B", "#00CC96"]

    fig = go.Figure()
    for i, update in enumerate(updates):
        name = f"$w_{i+1}$" if update_type == "weights" else rf"$\partial w_{i+1}$"
        fig.add_traces(
            go.Scatter(
                x=train_iters,
                y=update,
                name=name,
                mode="lines",
                line=dict(width=3, color=palette[i % len(palette)])
            )
        )

    fig.update_layout(
        height=300,
        width=300,
        xaxis=dict(
            title="Training iteration",
            tickvals=[1, last_iter // 2, last_iter],
            ticktext=[1, last_iter // 2, last_iter],
        ),
        yaxis=dict(
            title="Value",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(save_path)


@gif.frame
def plot_predictions(
        targets: np.ndarray,
        predictions: np.ndarray,
        title: Optional[str] = None,
        save_path: Optional[str] = None
    ) -> go.Figure:
    """Plot predictions against targets, one x position per example.

    Args:
        targets: true values, drawn as the "true" line.
        predictions: model outputs, drawn as the "prediction" line.
        title: optional figure title.
        save_path: if given, the figure is also written to this path.

    Returns:
        The plotly figure.
    """
    n_example = len(targets)
    examples = [i+1 for i in range(n_example)]
    last_example = examples[-1]

    fig = go.Figure()
    fig.add_traces(
        go.Scatter(
            x=examples,
            y=targets,
            name="true",
            mode="lines",
            line=dict(width=3)
        )
    )
    fig.add_traces(
        go.Scatter(
            x=examples,
            y=predictions,
            mode="lines",
            name="prediction",
            line=dict(width=3)
        )
    )

    fig.update_layout(
        height=300,
        width=600,
        xaxis=dict(
            title="Data",
            tickvals=[1, last_example // 2, last_example],
            ticktext=[1, last_example // 2, last_example],
        ),
        yaxis=dict(
            # Raw string: "\L" is not a valid Python escape sequence.
            title=r"$\LARGE{y}$",
            nticks=3
        ),
        font=dict(size=16),
    )
    if title is not None:
        fig.update_layout(
            title=dict(
                text=title,
                y=0.82,
                x=0.45,
                xanchor="center",
                yanchor="top"
            )
        )
    if save_path is not None:
        fig.write_image(save_path)

    return fig


@gif.frame
def plot_hessian_matrix(hessian_matrix, save_path, title=None):
    """Draw a Hessian as a heatmap and save it to ``save_path``.

    Colors use the "bwr" map with limits fixed to [-1, 1]; the colorbar shows
    ticks at -1, 0 and 1. Axis tick labels count from 1 (parameter index + 1).

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots()
    image = ax.imshow(X=hessian_matrix, cmap="bwr", vmin=-1, vmax=1)

    colorbar = fig.colorbar(image, ax=ax, location="right", ticks=[-1, 0, 1])
    colorbar.ax.tick_params(labelsize=25)

    positions = np.arange(len(hessian_matrix), dtype=int)
    labels = positions + 1
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    if title is not None:
        ax.set_title(title, fontsize=20)

    fig.savefig(save_path)
    return fig


def plot_loss_and_energy_hessian_eigenvals(hessian_eigenvals: List, save_path: str) -> None:
    """Overlay histograms of Hessian eigenvalues and save the figure.

    The entries of ``hessian_eigenvals`` are labelled, in order, "loss",
    "energy (numeric)" and "energy (theory)"; any extra entries are ignored.
    The theory histogram is drawn as a transparent bar with a black outline.
    Bars use probability normalisation on a log-scaled y-axis.
    """
    series_styles = [
        ("loss", "#EF553B"),
        ("energy (numeric)", "#636EFA"),
        ("energy (theory)", "rgba(0,0,0,0)"),
    ]

    fig = go.Figure()
    for values, (label, fill) in zip(hessian_eigenvals, series_styles):
        outline_width = 2 if "theory" in label else 0
        fig.add_trace(
            go.Histogram(
                x=values,
                histnorm="probability",
                nbinsx=10,
                name=label,
                marker=dict(
                    color=fill,
                    line=dict(color="black", width=outline_width)
                ),
            )
        )

    fig.update_layout(
        barmode="overlay",
        height=360,
        width=525,
        title=dict(y=0.75, x=0.5, xanchor="center", yanchor="top"),
        xaxis=dict(title="Hessian eigenvalue"),
        yaxis=dict(
            title="Density (log)",
            type="log",
            exponentformat="power",
            dtick=1
        ),
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            font=dict(size=16)
        ),
        font=dict(size=18)
    )
    fig.update_traces(opacity=0.75)
    fig.write_image(save_path)


def plot_bp_and_pc_loss_stats(
        means: Tuple[np.ndarray],
        stds: Tuple[np.ndarray],
        loss_title: str,
        save_path: str
    ) -> None:
    """Plot mean +/- one std curves for BP (index 0) and PC (index 1).

    Each method is drawn as a line with markers over its own number of
    iterations, plus a shaded band between ``mean - std`` and ``mean + std``.
    The x-axis ticks sit at 1, the midpoint and the end of the longer run.
    """
    longest_run = max(len(means[0]), len(means[1]))

    fig = go.Figure()
    series = zip(("BP", "PC"), ("#EF553B", "#636EFA"), means, stds)
    for label, color, mean, std in series:
        steps = [k + 1 for k in range(len(mean))]
        upper, lower = mean + std, mean - std

        # Closed polygon: upper edge left-to-right, then lower edge back.
        fig.add_traces(
            go.Scatter(
                x=steps + steps[::-1],
                y=list(upper) + list(lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        fig.add_traces(
            go.Scatter(
                x=steps,
                y=mean,
                name=label,
                mode="lines+markers",
                line=dict(width=3, color=color)
            )
        )

    midpoint = int(longest_run / 2)
    fig.update_layout(
        height=300,
        width=350,
        xaxis=dict(
            title="Training iteration",
            tickvals=[1, midpoint, longest_run],
            ticktext=[1, midpoint, longest_run],
        ),
        yaxis=dict(
            title=loss_title,
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(save_path)


def plot_bp_vs_pc_grad_norm_stats(
        means: Tuple[np.ndarray],
        stds: Tuple[np.ndarray],
        save_path: str
    ) -> None:
    """Plot mean +/- std gradient-norm curves for BP (index 0) and PC (index 1).

    The figure is identical to the loss-statistics plot except for the y-axis
    title, so this delegates to ``plot_bp_and_pc_loss_stats`` instead of
    duplicating its figure code.
    """
    plot_bp_and_pc_loss_stats(
        means=means,
        stds=stds,
        # Raw string: "\L" and "\p" are not valid Python escape sequences.
        loss_title=r"$\Large{||\partial \theta||_2}$",
        save_path=save_path
    )


## Scripts

In [76]:
#@title Loss landscape visualisation script


def visualise_loss_1mlp_landscape(
        domain: int,
        x: np.ndarray,
        y: np.ndarray,
        weight_updates: List[np.ndarray],
        train_losses: List[float],
        save_dir: str
    ) -> None:
    """Sample the MSE loss of a 1-hidden-unit chain on a 2D weight grid and plot it.

    The grid spans ``[-domain, domain]`` with ``SAMPLING_RESOLUTION`` points
    per axis. Mesh rows index ``w2`` and columns index ``w1``. Three outputs
    are written to ``save_dir``:
      * a surface plot with the training trajectory;
      * a contour plot with the negative-gradient field;
      * the same contour plot with the trajectory, one PDF per step.
    """
    w1s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    w2s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    n_points = SAMPLING_RESOLUTION

    loss_mesh = np.zeros((n_points, n_points))
    gradient_field = np.zeros((n_points, n_points, 2))
    for col, w1 in enumerate(w1s):
        for row, w2 in enumerate(w2s):
            residual = y - w2*w1*x
            loss_mesh[row, col] = (0.5 * residual**2).mean()

            grad_w1 = (-residual * w2*x).mean()
            grad_w2 = (-residual * w1*x).mean()
            # Store the descent direction (negative gradient).
            gradient_field[row, col, 0] = -grad_w1
            gradient_field[row, col, 1] = -grad_w2

    plot_objective_surface(
        objective_mesh=loss_mesh,
        weights=[w1s, w2s],
        save_path=f"{save_dir}/loss_landscape_surface_{domain}.pdf",
        weight_updates=weight_updates,
        train_losses=train_losses
    )

    contour_args = dict(
        objective_mesh=loss_mesh,
        weights=[w1s, w2s],
        vector_field=gradient_field,
        vector_field_scale=GRADIENT_FIELD_SCALE,
        objective_name="loss",
        title="BP",
    )
    plot_objective_contour(
        save_path=f"{save_dir}/loss_landscape_contour_{domain}.pdf",
        **contour_args
    )
    plot_objective_contour(
        save_path=f"{save_dir}/loss_landscape_contour_{domain}",
        weight_updates=weight_updates,
        **contour_args
    )


def visualise_loss_2mlp_landscape(
        domain: int,
        x: np.ndarray,
        y: np.ndarray,
        weight_updates: List[np.ndarray],
        save_dir: str
    ) -> None:
    """Sample the MSE loss of a 2-hidden-unit chain on a 3D weight grid and plot it.

    The grid spans ``[-domain, domain]`` with ``SAMPLING_RESOLUTION`` points
    per axis. The volume plot, with the training trajectory, is written under
    the prefix ``f"{save_dir}/loss_landscape_volume_{domain}"``.
    """
    w1s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    w2s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    w3s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)

    # plot_objective_volume pairs the flattened mesh with an ij-indexed
    # np.mgrid, so axis 0 must index w1, axis 1 w2 and axis 2 w3.
    loss_mesh = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION, SAMPLING_RESOLUTION))
    for a, w1 in enumerate(w1s):
        for b, w2 in enumerate(w2s):
            for c, w3 in enumerate(w3s):
                loss_mesh[a, b, c] = ( 0.5 * (y - w3*w2*w1*x )**2 ).mean()

    plot_objective_volume(
        objective_mesh=loss_mesh,
        input_domain=domain,
        sampling_resolution=SAMPLING_RESOLUTION,
        weight_updates=weight_updates,
        save_path=f"{save_dir}/loss_landscape_volume_{domain}"
    )


def visualise_loss_landscape(
        domain: int,
        x: np.ndarray,
        y: np.ndarray,
        weight_updates: List[np.ndarray],
        train_losses: List[float],
        save_dir: str
    ):
    """Dispatch to the landscape plot that matches the number of weights.

    Two weight trajectories (1 hidden unit) get 2D surface/contour plots.
    Three (2 hidden units) get a 3D volume plot. Any other depth is skipped
    and nothing is drawn.
    """
    n_weights = len(weight_updates)
    shared = dict(
        domain=domain,
        x=x,
        y=y,
        weight_updates=weight_updates,
        save_dir=save_dir
    )

    if n_weights == 2:
        visualise_loss_1mlp_landscape(train_losses=train_losses, **shared)
    elif n_weights == 3:
        visualise_loss_2mlp_landscape(**shared)


In [77]:
#@title Energy landscape visualisation script


def visualise_energy_1mlp_landscape(
        domain: int,
        x: np.ndarray,
        y: np.ndarray,
        n_iters: int,
        dt: float,
        weight_updates: List[np.ndarray],
        train_losses: List[float],
        save_dir: str
    ) -> None:
    """Plot the PC energy landscape of a scalar linear 1-MLP and its inference dynamics.

    For every (w1, w2) on a ``SAMPLING_RESOLUTION`` x ``SAMPLING_RESOLUTION`` grid over
    ``[-domain, domain]^2``:
    - the hidden activity starts at its feedforward value ``z = w1*x``;
    - it is relaxed for ``n_iters`` Euler steps of size ``dt`` on the energy
      ``F = mean(0.5*e1**2 + 0.5*e2**2)``, with ``e1 = z - w1*x`` and ``e2 = y - w2*z``;
    - ``F`` and the negative weight gradient ``-(dF/dw1, dF/dw2)`` are recorded at every
      inference step.

    Saved outputs:
    - the equilibrated energy contour with the training trajectory;
    - surface and contour plots at each step in ``PLOT_INFER_ITERATIONS``, plus GIFs of them;
    - the equilibrated surface with ``weight_updates``/``train_losses`` overlaid.
    """
    w1s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    w2s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)

    # indexed [w2 index, w1 index, inference step(, weight index)]
    energy_mesh = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION, n_iters+1))
    gradient_field = np.zeros((SAMPLING_RESOLUTION, SAMPLING_RESOLUTION, n_iters+1, 2))
    for k, w1 in enumerate(w1s):
        for j, w2 in enumerate(w2s):
            z = w1*x
            for t in range(n_iters+1):
                if t > 0:
                    # Euler step on z along -dF/dz, where dF/dz = e1 - w2*e2
                    dz = e1 - w2*e2
                    z += - dz * dt

                e2 = y - w2*z
                e1 = z - w1*x
                energy_mesh[j, k, t] = (0.5 * e1**2 + 0.5 * e2**2).mean()

                dw1 = ( - e1*x ).mean()
                dw2 = ( - e2*z ).mean()
                gradient_field[j, k, t, 0] = - dw1
                gradient_field[j, k, t, 1] = - dw2

    # equilibrated energy landscape, w/ updates
    plot_objective_contour(
        objective_mesh=energy_mesh[:, :, -1],
        weights=[w1s, w2s],
        vector_field=gradient_field[:, :, -1],
        vector_field_scale=GRADIENT_FIELD_SCALE,
        objective_name="energy",
        title="PC",
        save_path=f"{save_dir}/equilib_energy_landscape_contour_{domain}",
        weight_updates=weight_updates
    )

    # energy landscape inference dynamics, w/o updates
    surface_frames, contour_frames = [], []
    for t in range(n_iters+1):
        if t in PLOT_INFER_ITERATIONS:

            fig = plot_objective_surface(
                objective_mesh=energy_mesh[:, :, t],
                weights=[w1s, w2s],
                save_path=f"{save_dir}/energy_landscape_surface_{domain}_iter_{t}.pdf"
            )
            surface_frames.append(fig)

            fig = plot_objective_contour(
                objective_mesh=energy_mesh[:, :, t],
                weights=[w1s, w2s],
                vector_field=gradient_field[:, :, t],
                vector_field_scale=GRADIENT_FIELD_SCALE,
                objective_name="energy",
                title=f"PC, t = {t}",
                save_path=f"{save_dir}/energy_landscape_contour_{domain}_iter_{t}.pdf"
            )
            contour_frames.append(fig)

    plot_objective_surface(
        objective_mesh=energy_mesh[:, :, -1],
        weights=[w1s, w2s],
        save_path=f"{save_dir}/equilib_energy_landscape_surface_{domain}_updates.pdf",
        weight_updates=weight_updates,
        train_losses=train_losses
    )

    # frame durations in milliseconds (1000 ms = 1 s per frame); this avoids the
    # `unit=` keyword, which the pinned gif release may not accept
    gif.save(
        frames=surface_frames,
        path=f"{save_dir}/energy_landscape_surface_{domain}_infer_dynamics.gif",
        duration=1000
    )
    gif.save(
        frames=contour_frames,
        path=f"{save_dir}/energy_landscape_contour_{domain}_infer_dynamics.gif",
        duration=1000
    )


def visualise_energy_2mlp_landscape(
        domain: int,
        x: np.ndarray,
        y: np.ndarray,
        n_iters: int,
        dt: float,
        weight_updates: List[np.ndarray],
        save_dir: str
    ) -> None:
    """Plot the PC energy of a scalar linear 2-MLP as a volume over (w1, w2, w3).

    For every grid point in ``[-domain, domain]^3`` (``SAMPLING_RESOLUTION`` per axis):
    - the hidden activities start at their feedforward values;
    - they are relaxed for ``n_iters`` Euler steps of size ``dt`` on
      ``F = mean(0.5*e1**2 + 0.5*e2**2 + 0.5*e3**2)``;
    - ``F`` is recorded at every step.

    Saved outputs:
    - the equilibrated volume with the training trajectory;
    - a volume plot for each step in ``PLOT_INFER_ITERATIONS``, plus a GIF of them.
    """
    w1s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    w2s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)
    w3s = np.linspace(-domain, domain, SAMPLING_RESOLUTION)

    # indexed [w2 index, w1 index, w3 index, inference step]; the step axis and the
    # inference loop follow the n_iters/dt arguments, not the globals
    energy_mesh = np.zeros(
        (
            SAMPLING_RESOLUTION,
            SAMPLING_RESOLUTION,
            SAMPLING_RESOLUTION,
            n_iters+1
        )
    )
    for k, w1 in enumerate(w1s):
        for j, w2 in enumerate(w2s):
            for i, w3 in enumerate(w3s):

                # feedforward initialisation: e1 and e2 start at exactly zero
                Z1 = w1*x
                Z2 = w2*Z1
                e3 = y - w3*Z2
                e2 = Z2 - w2*Z1
                e1 = Z1 - w1*x
                energy = (0.5 * e1**2 + 0.5 * e2**2 + 0.5 * e3**2).mean()
                energy_mesh[j, k, i, 0] = energy

                for t in range(1, n_iters+1):
                    # Euler step of the activities along -dF/dZ
                    dZ1 = - e1 + w2*e2
                    dZ2 = - e2 + w3*e3
                    Z1 += dZ1 * dt
                    Z2 += dZ2 * dt

                    e3 = y - w3*Z2
                    e2 = Z2 - w2*Z1
                    e1 = Z1 - w1*x
                    energy = (0.5 * e1**2 + 0.5 * e2**2 + 0.5 * e3**2).mean()
                    energy_mesh[j, k, i, t] = energy

    # equilibrated energy volume, w/ updates
    plot_objective_volume(
        objective_mesh=energy_mesh[:, :, :, -1],
        input_domain=domain,
        sampling_resolution=SAMPLING_RESOLUTION,
        save_path=f"{save_dir}/equilib_energy_volume_{domain}",
        weight_updates=weight_updates
    )

    # energy volume inference dynamics, w/o updates
    volume_frames = []
    for t in range(n_iters+1):
        if t in PLOT_INFER_ITERATIONS:

            fig = plot_objective_volume(
                objective_mesh=energy_mesh[:, :, :, t],
                input_domain=domain,
                sampling_resolution=SAMPLING_RESOLUTION,
                save_path=f"{save_dir}/energy_landscape_volume_{domain}_iter_{t}.pdf"
            )
            volume_frames.append(fig)

    gif.save(
        frames=volume_frames,
        path=f"{save_dir}/energy_landscape_volume_infer_dynamics_{domain}.gif",
        duration=1000
    )


def visualise_energy_landscape(
        domain: int,
        x: np.ndarray,
        y: np.ndarray,
        n_iters: int,
        dt: float,
        weight_updates: List[np.ndarray],
        train_losses: List[float],
        save_dir: str
    ) -> None:
    """Dispatch to the 1-MLP or 2-MLP PC energy visualisation.

    Network depth is inferred from ``len(weight_updates)``: two means a 1-MLP, three
    means a 2-MLP. Any other count plots nothing. ``n_iters`` and ``dt`` are forwarded
    as given rather than replaced by the module-level defaults.
    ``train_losses`` is only used by the 1-MLP view.
    """
    if len(weight_updates) == 2:
        visualise_energy_1mlp_landscape(
            domain=domain,
            x=x,
            y=y,
            n_iters=n_iters,
            dt=dt,
            weight_updates=weight_updates,
            train_losses=train_losses,
            save_dir=save_dir
        )
    elif len(weight_updates) == 3:
        visualise_energy_2mlp_landscape(
            domain=domain,
            x=x,
            y=y,
            n_iters=n_iters,
            dt=dt,
            weight_updates=weight_updates,
            save_dir=save_dir
        )


In [78]:
#@title BP training script


def train_bp(weights: np.ndarray, save_dir: str):
    """Train a scalar linear 1- or 2-MLP on the y = -x task with backpropagation.

    Each training step:
    - draws a fresh Gaussian batch with ``make_gaussian_dataset``;
    - takes one gradient-descent step of size ``LR`` on the batch-mean loss
      ``0.5 * (Y - w_L*...*w1*X)**2``;
    - prints and records the test loss at the single point x = 1 (target -1).

    There are 15 steps for a 1-MLP and 100 for a 2-MLP. Afterwards the function saves
    loss, weight and gradient plots and a prediction GIF to ``save_dir``. When
    ``N_SEEDS == 1`` it also saves the loss landscape with the weight trajectory.

    Args:
        weights: initial weights ``[w1, w2]`` (1-MLP) or ``[w1, w2, w3]`` (2-MLP).
        save_dir: output directory (created if missing).

    Returns:
        Tuple ``(train_losses, test_losses, grad_norms)``, one entry per batch.
    """
    n_hidden = len(weights)-1
    print(f"------------------------------------------------------")
    print(f"Starting training of {n_hidden}-MLP with BP...\n")
    os.makedirs(save_dir, exist_ok=True)

    if n_hidden == 1:
        w1, w2 = weights[0], weights[1]
    else:
        w1, w2, w3 = weights[0], weights[1], weights[2]

    # metrics
    ys, ys_hat, train_losses, test_losses = [], [], [], []
    w1_updates, w2_updates = [w1], [w2]
    dw1s, dw2s = [], []
    grad_norms = []
    if n_hidden == 2:
        w3_updates = [w3]
        dw3s = []

    max_train_iters = 15 if n_hidden == 1 else 100
    # training
    for batch in range(max_train_iters):
        # data
        X, Y = make_gaussian_dataset(
            mean=DATA_MEAN,
            std=DATA_STD,
            size=BATCH_SIZE
        )
        ys.append(Y)

        # loss
        y_hat = w2*w1*X if n_hidden == 1 else w3*w2*w1*X
        train_loss = ( 0.5 * (Y - y_hat)**2 ).mean()
        train_losses.append(train_loss)
        ys_hat.append(y_hat)

        # weight gradients of the loss, all taken at the pre-update weights;
        # the residual is exactly Y - y_hat, so it is reused rather than recomputed
        error = Y - y_hat
        if n_hidden == 1:
            dw1 = ( - error * w2*X ).mean()
            dw2 = ( - error * w1*X ).mean()
        else:
            dw1 = ( - error * w3*w2*X ).mean()
            dw2 = ( - error * w3*w1*X ).mean()
            dw3 = ( - error * w2*w1*X ).mean()
        dw1s.append(dw1)
        dw2s.append(dw2)
        if n_hidden == 2:
            dw3s.append(dw3)

        grad = np.array([[dw1], [dw2]]) if n_hidden == 1 else np.array([[dw1], [dw2], [dw3]])
        grad_norms.append(norm(grad))

        # weight update
        w1 += - LR * (dw1)
        w2 += - LR * (dw2)
        w1_updates.append(w1)
        w2_updates.append(w2)
        if n_hidden == 2:
            w3 += - LR * (dw3)
            w3_updates.append(w3)

        # test loss at the single point x = 1 (target -1)
        x = 1
        y = -x
        test_loss = ( 0.5 * (y - w2*w1*x)**2 ).mean() if n_hidden == 1 else ( 0.5 * (y - w3*w2*w1*x)**2 ).mean()
        test_losses.append(test_loss)
        print(f"test loss: {test_loss:.5f}")

    print(f"\nTraining stopped at batch {batch+1} with test loss {test_loss:.5f}\n")

    # plot losses, weight updates & gradient dynamics
    plot_losses(
        losses={"train": train_losses, "test": test_losses},
        save_path=f"{save_dir}/losses.pdf"
    )
    weight_updates = [w1_updates, w2_updates] if n_hidden == 1 else [w1_updates, w2_updates, w3_updates]
    plot_updates(
        updates=weight_updates,
        update_type="weights",
        save_path=f"{save_dir}/weights.pdf"
    )
    gradient_updates = [dw1s, dw2s] if n_hidden == 1 else [dw1s, dw2s, dw3s]
    plot_updates(
        updates=gradient_updates,
        update_type="gradient",
        save_path=f"{save_dir}/gradient.pdf"
    )

    # plot prediction dynamics
    n_batches = len(ys)
    pred_frames = []
    for batch in range(n_batches):
        fig = plot_predictions(
            targets=ys[batch],
            predictions=ys_hat[batch],
            title=f"Training iteration {batch+1}"
        )
        pred_frames.append(fig)

    gif.save(
        frames=pred_frames,
        path=f"{save_dir}/prediction_learning_dynamics.gif",
        duration=200
    )

    # visualise learning dynamics on a fresh batch
    X, Y = make_gaussian_dataset(
        mean=DATA_MEAN,
        std=DATA_STD,
        size=BATCH_SIZE
    )
    if N_SEEDS == 1:
        visualise_loss_landscape(
            domain=1 if n_hidden == 1 else 2,
            x=X,
            y=Y,
            weight_updates=weight_updates,
            train_losses=train_losses,
            save_dir=save_dir
        )
    return train_losses, test_losses, grad_norms


In [82]:
#@title PC training script


def train_pc(weights: np.ndarray, save_dir: str):
    """Train a scalar linear 1- or 2-MLP on the y = -x task with predictive coding.

    For each of at most 50 batches:
    - draw a Gaussian batch and initialise the hidden activities at their feedforward
      values;
    - relax the activities for ``N_ITERS`` Euler steps of size ``DT`` on the energy
      ``F = mean(sum_l 0.5*e_l**2)``;
    - take one gradient step of size ``LR`` on the weights using the relaxed errors.

    For a 1-MLP, training stops early once the test loss at x = 1 (target -1) drops
    below 1e-3. Afterwards the function saves loss, energy, weight and gradient plots
    and a prediction GIF to ``save_dir``. When ``N_SEEDS == 1`` it also saves the
    energy landscape with the weight trajectory.

    Args:
        weights: initial weights ``[w1, w2]`` (1-MLP) or ``[w1, w2, w3]`` (2-MLP).
        save_dir: output directory (created if missing).

    Returns:
        Tuple ``(train_losses, test_losses, gradient_norms)``, one entry per batch.
    """
    n_hidden = len(weights)-1
    print(f"------------------------------------------------------")
    print(f"Starting training of {n_hidden}-MLP with PC...\n")
    os.makedirs(save_dir, exist_ok=True)

    if n_hidden == 1:
        w1, w2 = weights[0], weights[1]
    else:
        w1, w2, w3 = weights[0], weights[1], weights[2]

    # learning metrics
    ys, ys_hat, energies1, energies2 = [], [], [], []
    train_losses, test_losses = [], []
    w1_updates, w2_updates = [w1], [w2]
    dw1s, dw2s = [], []
    gradient_norms = []
    if n_hidden == 2:
        energies3 = []
        w3_updates = [w3]
        dw3s = []

    # training
    for batch in range(50):
        # predictions at every inference step, shape (BATCH_SIZE, N_ITERS+1)
        ys_hat_iters = np.zeros((BATCH_SIZE, N_ITERS+1))

        # data
        X, Y = make_gaussian_dataset(
            mean=DATA_MEAN,
            std=DATA_STD,
            size=BATCH_SIZE
        )
        ys.append(Y)

        if n_hidden == 1:
            # initialisation at the feedforward pass (e1 starts at exactly zero)
            Z = w1*X
            e2 = Y - w2*Z
            e1 = Z - w1*X
            energy = (0.5 * e1**2 + 0.5 * e2**2).mean()

            ys_hat_iters[:, 0] = w2*Z
            energies1.append((0.5 * e1**2).mean())
            energies2.append((0.5 * e2**2).mean())

            # iterative inference: Euler steps on Z along -dF/dZ, dF/dZ = e1 - w2*e2
            for t in range(1, N_ITERS+1):
                dZ = e1 - w2*e2
                Z += - dZ * DT

                e2 = Y - w2*Z
                e1 = Z - w1*X
                energy = (0.5 * e1**2 + 0.5 * e2**2).mean()

                ys_hat_iters[:, t] = w2*Z
                energies1.append((0.5 * e1**2).mean())
                energies2.append((0.5 * e2**2).mean())

            print(f"z^*: {Z.mean()}")
            print(f"theory z^*: {( (w1*X + w2*Y)/(1+w2**2) ).mean()}\n")

        elif n_hidden == 2:
            # initialisation at the feedforward pass (e1, e2 start at exactly zero)
            Z1 = w1*X
            Z2 = w2*Z1
            e3 = Y - w3*Z2
            e2 = Z2 - w2*Z1
            e1 = Z1 - w1*X
            energy = (0.5 * e1**2 + 0.5 * e2**2 + 0.5 * e3**2).mean()

            ys_hat_iters[:, 0] = w3*Z2
            energies1.append((0.5 * e1**2).mean())
            energies2.append((0.5 * e2**2).mean())
            energies3.append((0.5 * e3**2).mean())

            # iterative inference: Euler steps on (Z1, Z2) along -dF/dZ
            for t in range(1, N_ITERS+1):
                dZ1 = e1 - w2*e2
                dZ2 = e2 - w3*e3
                Z1 += - dZ1 * DT
                Z2 += - dZ2 * DT

                e3 = Y - w3*Z2
                e2 = Z2 - w2*Z1
                e1 = Z1 - w1*X
                energy = (0.5 * e1**2 + 0.5 * e2**2 + 0.5 * e3**2).mean()

                ys_hat_iters[:, t] = w3*Z2
                energies1.append((0.5 * e1**2).mean())
                energies2.append((0.5 * e2**2).mean())
                energies3.append((0.5 * e3**2).mean())

        ys_hat.append(ys_hat_iters)

        # weight gradients of the energy, evaluated at the relaxed activities
        dw1 = ( - e1*X ).mean()
        dw2 = ( - e2*Z ).mean() if n_hidden == 1 else ( - e2*Z1 ).mean()
        dw1s.append(dw1)
        dw2s.append(dw2)
        if n_hidden == 2:
            dw3 = ( - e3*Z2 ).mean()
            dw3s.append(dw3)

        grad = np.array([[dw1], [dw2]]) if n_hidden == 1 else np.array([[dw1], [dw2], [dw3]])
        gradient_norms.append(norm(grad))

        # train loss (feedforward, pre-update weights)
        train_loss = ( 0.5 * (Y - w2*w1*X )**2 ).mean() if n_hidden == 1 else ( 0.5 * (Y - w3*w2*w1*X )**2 ).mean()
        train_losses.append(train_loss)

        # weight update
        w1 += LR * (- dw1)
        w2 += LR * (- dw2)
        w1_updates.append(w1)
        w2_updates.append(w2)
        if n_hidden == 2:
            w3 += - LR * (dw3)
            w3_updates.append(w3)

        # test loss at the single point x = 1 (target -1)
        x = 1
        y = -x
        test_loss = ( 0.5 * (y - w2*w1*x )**2 ) if n_hidden == 1 else ( 0.5 * (y - w3*w2*w1*x )**2 )
        test_losses.append(test_loss)

        print(f"test loss: {test_loss:.5f}")
        if n_hidden == 1 and test_loss < 0.001:
            break

    print(f"\nTraining stopped at batch {batch+1} with total energy {energy:.5f} and test loss {test_loss:.5f}\n")

    # plot losses, energies & updates
    plot_losses(
        losses={"train": train_losses, "test": test_losses},
        save_path=f"{save_dir}/losses.pdf"
    )
    energies = [energies1, energies2] if n_hidden == 1 else [energies1, energies2, energies3]
    plot_energies(
        energies=energies,
        n_infer_iters=N_ITERS,
        save_path=f"{save_dir}/energies.pdf"
    )
    weight_updates = [w1_updates, w2_updates] if n_hidden == 1 else [w1_updates, w2_updates, w3_updates]
    plot_updates(
        updates=weight_updates,
        update_type="weights",
        save_path=f"{save_dir}/weights.pdf"
    )
    gradient_updates = [dw1s, dw2s] if n_hidden == 1 else [dw1s, dw2s, dw3s]
    plot_updates(
        updates=gradient_updates,
        update_type="gradient",
        save_path=f"{save_dir}/gradient.pdf"
    )

    # plot predictions dynamics at the first, middle and last inference step.
    # Each ys_hat entry is already a (BATCH_SIZE, N_ITERS+1) array, so it is indexed
    # directly instead of re-stacking the whole history once per frame.
    plot_iters = sorted({0, int(N_ITERS/2), N_ITERS})
    pred_frames = []
    for batch, (targets, preds_per_iter) in enumerate(zip(ys, ys_hat)):
        for t in plot_iters:
            fig = plot_predictions(
                targets=targets,
                predictions=preds_per_iter[:, t],
                title=f"Training iteration {batch+1}, t = {t}"
            )
            pred_frames.append(fig)

    gif.save(
        frames=pred_frames,
        path=f"{save_dir}/predictions_infer_learn_dynamics.gif",
        duration=200
    )

    # visualise learning dynamics on a fresh batch
    X, Y = make_gaussian_dataset(
        mean=DATA_MEAN,
        std=DATA_STD,
        size=BATCH_SIZE
    )
    if N_SEEDS == 1:
        visualise_energy_landscape(
            domain=1 if n_hidden == 1 else 2,
            x=X,
            y=Y,
            n_iters=N_ITERS,
            dt=DT,
            weight_updates=weight_updates,
            train_losses=train_losses,
            save_dir=save_dir
        )
    return train_losses, test_losses, gradient_norms


In [80]:
#@title Hessian script


def compute_loss_and_energy_hessian(
        weights: np.ndarray,
        save_dir: str
    ) -> None:
    """Compare the weight Hessian of the loss with that of the PC energy.

    On one Gaussian batch, this function:
    - computes the Hessian of ``mse_loss_fun`` w.r.t. ``weights``;
    - computes the Hessian of ``energy_fun`` after t = 0..``N_ITERS`` inference steps
      of size ``DT``, starting from the feedforward activities;
    - saves each Hessian as a plot, plus a GIF of the energy Hessian over inference;
    - plots the eigenvalues of the loss Hessian, of the final-step energy Hessian, and
      of ``compute_theoretical_energy_eigenvals``.

    Args:
        weights: scalar weights ``[w1, ..., w_{n_hidden+1}]`` of a linear chain of
            any depth.
        save_dir: output directory (created if missing).
    """
    from functools import reduce
    from operator import mul

    n_hidden = len(weights)-1
    print(f"Calculating loss and energy Hessian for {n_hidden}-MLP at the origin...\n")
    os.makedirs(save_dir, exist_ok=True)

    X, Y = make_gaussian_dataset(
        mean=DATA_MEAN,
        std=DATA_STD,
        size=BATCH_SIZE
    )
    # feedforward activities [X, w1*X, w2*w1*X, ..., w_n*...*w1*X, Y] for any depth.
    # The weight product is accumulated as w_k*w_{k-1}*...*w1 (left to right) before
    # multiplying by X, the same evaluation order as writing it out explicitly.
    hidden_activities = [
        reduce(mul, weights[depth-1::-1]) * X
        for depth in range(1, n_hidden+1)
    ]
    activities = np.array([X, *hidden_activities, Y])

    loss_hessian = hessian(mse_loss_fun)(
        weights,
        X,
        Y
    )
    plot_hessian_matrix(
        hessian_matrix=loss_hessian,
        save_path=f"{save_dir}/loss_hessian.pdf"
    )
    loss_hessian_eigenvals, _ = eigh(loss_hessian)

    energy_hessian_frames = []
    for t in range(N_ITERS+1):
        energy_hessian = hessian(energy_fun)(
            weights,
            activities,
            n_iters=t,
            dt=DT
        )
        fig = plot_hessian_matrix(
            hessian_matrix=energy_hessian,
            title=f"Inference iteration = {t}",
            save_path=f"{save_dir}/energy_hessian_iter_{t}.pdf"
        )
        energy_hessian_frames.append(fig)

    gif.save(
        frames=energy_hessian_frames,
        path=f"{save_dir}/energy_hessian_infer_dynamics.gif",
        duration=1000
    )
    # eigenvalues of the energy Hessian after the last inference step (t = N_ITERS)
    energy_hessian_eigenvals, _ = eigh(energy_hessian)

    theory_energy_hessian_eigenvals = compute_theoretical_energy_eigenvals(
        weights,
        X,
        Y
    )

    plot_loss_and_energy_hessian_eigenvals(
        hessian_eigenvals=[
            loss_hessian_eigenvals,
            energy_hessian_eigenvals,
            theory_energy_hessian_eigenvals
        ],
        save_path=f"{save_dir}/hessian_eigenspectrum.pdf"
    )


In [81]:
#@title Main script


def main():
    """Run every configured experiment and plot the across-seed statistics.

    For each ``n_hidden`` in ``HIDDEN_UNITS`` and each of ``N_SEEDS`` seeds:
    - for 1- and 2-MLPs, train with BP and PC from the same random initial weights;
    - compute the loss and energy Hessians at the origin.

    For 1- and 2-MLPs, the mean/std over seeds of train loss, test loss and gradient
    norm is then plotted. All outputs go under ``{RESULTS_DIR}/n_hidden_{n_hidden}/``.
    """
    for n_hidden in HIDDEN_UNITS:
        n_hidden_dir = f"{RESULTS_DIR}/n_hidden_{n_hidden}"
        os.makedirs(n_hidden_dir, exist_ok=True)

        if n_hidden in [1, 2]:
            # one slot per seed for each metric, each container built independently
            bp_train_losses_all_seeds = [[] for _ in range(N_SEEDS)]
            bp_test_losses_all_seeds = [[] for _ in range(N_SEEDS)]
            pc_train_losses_all_seeds = [[] for _ in range(N_SEEDS)]
            pc_test_losses_all_seeds = [[] for _ in range(N_SEEDS)]

            bp_grad_norms_all_seeds = [[] for _ in range(N_SEEDS)]
            pc_grad_norms_all_seeds = [[] for _ in range(N_SEEDS)]

        for seed in range(N_SEEDS):
            print(f"\nSeed {seed+1}/{N_SEEDS}...\n")
            if N_SEEDS == 1:
                # single-seed runs use a fixed, depth-specific seed
                seed = 2 if n_hidden == 2 else 1

            set_seed(seed)
            # BP and PC both start from these same initial weights
            weights = np.random.normal(loc=0, scale=WEIGHT_SCALE, size=n_hidden+1)

            if n_hidden in [1, 2]:
                bp_train_losses, bp_test_losses, bp_grad_norms = train_bp(
                    weights=weights,
                    save_dir=f"{n_hidden_dir}/{str(seed)}/bp"
                )
                pc_train_losses, pc_test_losses, pc_grad_norms = train_pc(
                    weights=weights,
                    save_dir=f"{n_hidden_dir}/{str(seed)}/pc"
                )
                # with a single seed, `seed` was remapped above, so store in slot 0
                idx = seed if N_SEEDS > 1 else 0
                bp_train_losses_all_seeds[idx] = bp_train_losses
                bp_test_losses_all_seeds[idx] = bp_test_losses
                pc_train_losses_all_seeds[idx] = pc_train_losses
                pc_test_losses_all_seeds[idx] = pc_test_losses

                bp_grad_norms_all_seeds[idx] = bp_grad_norms
                pc_grad_norms_all_seeds[idx] = pc_grad_norms

            # NOTE(review): scale=0 gives all-zero weights, but the call presumably still
            # advances the global NumPy RNG and so shifts the batch drawn inside
            # compute_loss_and_energy_hessian; kept as-is to preserve results.
            zero_weights = np.random.normal(loc=0, scale=0, size=n_hidden+1)
            compute_loss_and_energy_hessian(
                weights=zero_weights,
                save_dir=f"{n_hidden_dir}/{str(seed)}/hessians"
            )

        if n_hidden in [1, 2]:
            bp_train_means, bp_train_stds = compute_metric_stats(metric=bp_train_losses_all_seeds)
            bp_test_means, bp_test_stds = compute_metric_stats(metric=bp_test_losses_all_seeds)
            pc_train_means, pc_train_stds = compute_metric_stats(metric=pc_train_losses_all_seeds)
            pc_test_means, pc_test_stds = compute_metric_stats(metric=pc_test_losses_all_seeds)

            bp_grad_norm_means, bp_grad_norm_stds = compute_metric_stats(metric=bp_grad_norms_all_seeds)
            pc_grad_norm_means, pc_grad_norm_stds = compute_metric_stats(metric=pc_grad_norms_all_seeds)

            # raw strings: "\L" and "\m" are invalid escape sequences in normal literals
            plot_bp_and_pc_loss_stats(
                means=[bp_train_means, pc_train_means],
                stds=[bp_train_stds, pc_train_stds],
                loss_title=r"$\LARGE{\mathcal{L}_{\text{train}}}$",
                save_path=f"{n_hidden_dir}/train_loss_stats.pdf"
            )
            plot_bp_and_pc_loss_stats(
                means=[bp_test_means, pc_test_means],
                stds=[bp_test_stds, pc_test_stds],
                loss_title=r"$\LARGE{\mathcal{L}_{\text{test}}}$",
                save_path=f"{n_hidden_dir}/test_loss_stats.pdf"
            )
            plot_bp_vs_pc_grad_norm_stats(
                means=[bp_grad_norm_means, pc_grad_norm_means],
                stds=[bp_grad_norm_stds, pc_grad_norm_stds],
                save_path=f"{n_hidden_dir}/gradient_norm_stats.pdf"
            )


## Run analysis

In [None]:
# Run all experiments, then archive the `results` directory into a zip file
# (assumes RESULTS_DIR points at ./results — TODO confirm).
main()
!zip -r linear_chains_results.zip results

## Copy results to Google Drive

In [None]:
import shutil
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only) so the archive can be copied into it.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Copy the results archive from the Colab workspace into the root of "My Drive";
# the cell displays the destination path returned by shutil.copy.
colab_link = "/content/linear_chains_results.zip"
gdrive_link = "/content/drive/MyDrive/"
shutil.copy(colab_link, gdrive_link)

'/content/drive/MyDrive/linear_chains_results.zip'