In [None]:
import sys
# Put the parent directory on the import path; presumably this is where the
# project modules imported further down (data_utils, vectorhash, ...) live —
# TODO confirm against the repository layout.
sys.path.append('../')

In [None]:
import torch

# Prefer the GPU, but fall back to the CPU so that the notebook still runs
# (more slowly) on machines without a usable CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
import torch
import matplotlib.pyplot as plt
from data_utils import load_mnist_dataset, prepare_data
from vectorhash import build_scaffold
from hippocampal_sensory_layers import (
    ExactPseudoInverseHippocampalSensoryLayer,
    IterativeBidirectionalPseudoInverseHippocampalSensoryLayer,
    HippocampalSensoryLayer,
    HebbianHippocampalSensoryLayer,
    HSPseudoInverseSHHebbieanHippocampalSensoryLayer,
)
from tqdm import tqdm
from data_utils import load_mnist_dataset
from matplotlib.axes import Axes

In [None]:
def test_layer(
    layer: HippocampalSensoryLayer, hbook: torch.Tensor, sbook: torch.Tensor
):
    """Teach ``layer`` one (h, s) pair at a time and track reconstruction error.

    After every ``layer.learn`` call, three errors of the
    sensory -> hippocampal -> sensory round trip are recorded:

    * mean absolute error on the very first pattern ``sbook[0]``,
    * mean absolute error on the pattern that was just learned,
    * mean squared error over all patterns learned so far.

    The loop stops early as soon as the first-pattern error or the
    accumulated error exceeds ``10e5`` or is NaN. Entries belonging to steps
    that were never reached keep their initial value of ``-1``.

    Args:
        layer: object providing ``learn``, ``hippocampal_from_sensory`` and
            ``sensory_from_hippocampal``.
        hbook: hippocampal states; ``hbook[i]`` is paired with ``sbook[i]``.
        sbook: sensory patterns to learn, one per step.

    Returns:
        Tuple ``(first_pattern_l1, last_pattern_l1, accumulated_l2)`` of 1-D
        tensors with ``len(sbook)`` entries each.
    """
    n_steps = len(sbook)
    first_img = sbook[0]
    divergence_limit = 10e5  # note: this literal equals 1e6

    # Pre-fill with -1 so steps skipped by an early stop stay recognisable.
    err_l1_first_img_s_h_s = torch.full((n_steps,), -1.0)
    err_l1_last_img_s_h_s = torch.full((n_steps,), -1.0)
    avg_accumulated_err_l2 = torch.full((n_steps,), -1.0)

    def round_trip(patterns):
        # Sensory -> hippocampal -> sensory reconstruction.
        return layer.sensory_from_hippocampal(
            layer.hippocampal_from_sensory(patterns)
        )

    for step in tqdm(range(n_steps)):
        current_img = sbook[step]
        layer.learn(hbook[step], current_img)

        # The [0] suggests the layer answers with a batch even for a single
        # pattern — TODO confirm against the layer implementation.
        first_residual = round_trip(first_img)[0] - first_img
        err_l1_first_img_s_h_s[step] = first_residual.abs().mean()

        last_residual = round_trip(current_img)[0] - current_img
        err_l1_last_img_s_h_s[step] = last_residual.abs().mean()

        learned_so_far = sbook[: step + 1]
        accumulated_residual = round_trip(learned_so_far) - learned_so_far
        avg_accumulated_err_l2[step] = (accumulated_residual**2).mean()

        # Give up once the layer has diverged; later steps would only add
        # meaningless numbers.
        first_err = err_l1_first_img_s_h_s[step]
        accumulated_err = avg_accumulated_err_l2[step]
        exploded = bool(first_err > divergence_limit) or bool(
            accumulated_err > divergence_limit
        )
        undefined = bool(torch.isnan(first_err)) or bool(
            torch.isnan(accumulated_err)
        )
        if exploded or undefined:
            break

    return err_l1_first_img_s_h_s, err_l1_last_img_s_h_s, avg_accumulated_err_l2


def plot_avg_acc_l2_err_on_ax(ax: Axes, avg_accumulated_err_l2: torch.Tensor, label):
    """Draw the run-averaged accumulated L2 error with a +/- one-std band.

    Args:
        ax: axes to draw on.
        avg_accumulated_err_l2: errors indexed as ``[run, step]``; the mean
            and standard deviation are taken over the first dimension.
        label: legend label of the mean curve.

    Returns:
        The same ``ax`` that was passed in.
    """
    x = torch.arange(0, len(avg_accumulated_err_l2[0]))
    mean = avg_accumulated_err_l2.mean(dim=0)
    if len(avg_accumulated_err_l2) > 1:
        std = avg_accumulated_err_l2.std(dim=0)
    else:
        # The sample standard deviation of a single run is NaN (and torch
        # warns about it); with one run there is no spread to display.
        std = torch.zeros_like(mean)
    ax.plot(x, mean, label=label)
    ax.fill_between(x, mean - std, mean + std, alpha=0.2)

    return ax


def plot_first_img_l1_err_on_ax(ax: Axes, err_l1_first_img_s_h_s: torch.Tensor, label):
    """Draw the run-averaged first-pattern L1 error with a +/- one-std band.

    Args:
        ax: axes to draw on.
        err_l1_first_img_s_h_s: errors indexed as ``[run, step]``; the mean
            and standard deviation are taken over the first dimension.
        label: legend label of the mean curve.

    Returns:
        The same ``ax`` that was passed in.
    """
    x = torch.arange(0, len(err_l1_first_img_s_h_s[0]))
    mean = err_l1_first_img_s_h_s.mean(dim=0)
    if len(err_l1_first_img_s_h_s) > 1:
        std = err_l1_first_img_s_h_s.std(dim=0)
    else:
        # The sample standard deviation of a single run is NaN (and torch
        # warns about it); with one run there is no spread to display.
        std = torch.zeros_like(mean)
    ax.plot(x, mean, label=label)
    ax.fill_between(x, mean - std, mean + std, alpha=0.2)

    return ax

def plot_last_img_l1_err_on_ax(ax: Axes, err_l1_last_img_s_h_s: torch.Tensor, label):
    """Draw the run-averaged last-pattern L1 error with a +/- one-std band.

    Args:
        ax: axes to draw on.
        err_l1_last_img_s_h_s: errors indexed as ``[run, step]``; the mean
            and standard deviation are taken over the first dimension.
        label: legend label of the mean curve.

    Returns:
        The same ``ax`` that was passed in.
    """
    x = torch.arange(0, len(err_l1_last_img_s_h_s[0]))
    mean = err_l1_last_img_s_h_s.mean(dim=0)
    if len(err_l1_last_img_s_h_s) > 1:
        std = err_l1_last_img_s_h_s.std(dim=0)
    else:
        # The sample standard deviation of a single run is NaN (and torch
        # warns about it); with one run there is no spread to display.
        std = torch.zeros_like(mean)
    ax.plot(x, mean, label=label)
    ax.fill_between(x, mean - std, mean + std, alpha=0.2)

    return ax



def set_ax_titles(ax: Axes, title, xtitle, ytitle):
    """Label a finished plot: title, both axis labels, then the legend.

    Args:
        ax: axes to annotate.
        title: text passed to ``ax.set_title``.
        xtitle: text passed to ``ax.set_xlabel``.
        ytitle: text passed to ``ax.set_ylabel``.
    """
    text_setters = (
        (ax.set_title, title),
        (ax.set_xlabel, xtitle),
        (ax.set_ylabel, ytitle),
    )
    for apply_text, text in text_setters:
        apply_text(text)
    # Create the legend last, after all texts have been set.
    ax.legend()


def add_vertical_bar_on_ax(ax: Axes, x):
    """Mark position ``x`` on ``ax`` with a dashed blue vertical line."""
    line_style = {"color": "b", "linestyle": "--"}
    ax.axvline(x=x, **line_style)

def add_horizontal_bar_on_ax(ax: Axes, y, label):
    """Mark level ``y`` on ``ax`` with a dashed black horizontal line.

    ``label`` is forwarded to ``ax.axhline`` as the line's legend label.
    """
    line_style = {"color": "k", "linestyle": "--"}
    ax.axhline(y=y, label=label, **line_style)

Analytic vs. Iterative pseudoinverse

In [None]:
# --- Experiment configuration ---------------------------------------------
N_patts = 600  # number of patterns requested from prepare_data
N_h = 400  # handed to build_scaffold and to the layer constructors
runs = 1  # repetitions per layer type
shapes = [(3, 3, 3), (4, 4, 4)]  # handed to build_scaffold

# --- Data and scaffold ----------------------------------------------------
dataset = load_mnist_dataset()
data, noisy_data = prepare_data(
    dataset,
    N_patts,
    noise_level='none',
    device=device,
)
scaffold, mean_h = build_scaffold(
    shapes,
    N_h,
    device=device,
    sanity_check=True,
)

In [None]:
# Compare the exact pseudo-inverse layer with the iterative one: every layer
# type is trained from scratch `runs` times on a fresh random ordering of the
# data.
names = ["analytic", "iterative"]

# Results are indexed [layer type, run, step]; -1 marks entries never filled.
result_shape = (len(names), runs, N_patts)
err_l1_first_img_s_h_s = torch.full(result_shape, -1.0)
err_l1_last_img_s_h_s = torch.full(result_shape, -1.0)
avg_accumulated_err_l2 = torch.full(result_shape, -1.0)

scaffold, mean_h = build_scaffold(shapes, N_h, device=device, sanity_check=True)

for i, name in enumerate(names):
    for run in range(runs):
        layer = (
            ExactPseudoInverseHippocampalSensoryLayer(
                784, N_h, N_patts, scaffold.H[:N_patts], device=device
            )
            if name == "analytic"
            else IterativeBidirectionalPseudoInverseHippocampalSensoryLayer(
                784, N_h, 1, True, 0.1, 0.1, device=device
            )
        )
        # Shuffle the presentation order independently for every run.
        shuffled_data = data[torch.randperm(len(data))]
        first_err, last_err, dataset_err = test_layer(
            layer, scaffold.H, shuffled_data
        )
        err_l1_first_img_s_h_s[i, run] = first_err
        err_l1_last_img_s_h_s[i, run] = last_err
        avg_accumulated_err_l2[i, run] = dataset_err

In [None]:
# Baselines: the error made by always answering with the mean image of the
# dataset, measured as L2 over the whole dataset and as L1 against data[0].
mean_of_dataset = torch.mean(data, dim=0)
rand = torch.rand_like(data)  # not used in this cell
err_mean_l2 = torch.mean((mean_of_dataset - data) ** 2).cpu()
err_mean_l1 = torch.mean(torch.abs(mean_of_dataset - data[0])).cpu()

# Figure 1: average L2 error over all patterns learned so far.
fig, ax = plt.subplots(figsize=(15, 9))

for i, name in enumerate(names):
    plot_avg_acc_l2_err_on_ax(ax, avg_accumulated_err_l2[i], label=name)

add_vertical_bar_on_ax(ax, N_h)
# The curves on this figure are L2 errors, so the reference line has to be the
# L2 baseline (the L1 baseline belongs on the per-image figures).
add_horizontal_bar_on_ax(ax, err_mean_l2, label='err using "mean of dataset" image')
set_ax_titles(
    ax,
    f"shapes={shapes}, N_h={N_h}",
    "Number of images learned",
    "Average L2 error over all patterns",
)
ax.set_ylim(0, 1)
fig.savefig("hipp_sens_result_analytic_vs_iterative_dataset_err")

# Figure: L1 error when recalling the first learned image, per layer type.
fig, ax = plt.subplots(figsize=(15, 9))
for i, name in enumerate(names):
    first_img_errors = err_l1_first_img_s_h_s[i]
    plot_first_img_l1_err_on_ax(ax, first_img_errors, label=name)

# Reference lines: N_h on the x axis, the "mean image" L1 baseline on the y axis.
add_vertical_bar_on_ax(ax, N_h)
add_horizontal_bar_on_ax(
    ax,
    err_mean_l1,
    label='err using "mean of dataset" image',
)
set_ax_titles(
    ax,
    title=f"shapes={shapes}, N_h={N_h}",
    xtitle="Number of images learned",
    ytitle="Error when recovering first pattern",
)
ax.set_ylim(0, 1)
fig.savefig("hipp_sens_result_analytic_vs_iterative_first_img_err")

# Figure: L1 error when recalling the most recently learned image.
fig, ax = plt.subplots(figsize=(15, 9))
for i, name in enumerate(names):
    plot_last_img_l1_err_on_ax(ax, err_l1_last_img_s_h_s[i], label=name)

add_vertical_bar_on_ax(ax, N_h)
add_horizontal_bar_on_ax(ax, err_mean_l1, label='err using "mean of dataset" image')
set_ax_titles(
    ax,
    f"shapes={shapes}, N_h={N_h}",
    "Number of images learned",
    # This figure shows the last-pattern error; the label used to say "first".
    "Error when recovering last pattern",
)
ax.set_ylim(0, 1)
fig.savefig("hipp_sens_result_analytic_vs_iterative_last_img_err")

Different Hebbian variations

In [None]:
# --- Experiment configuration ---------------------------------------------
N_patts = 600
N_h = 400
shapes = [(3, 3, 3), (4, 4, 4)]
means = [None, 0]  # only "is it None?" is forwarded to the layer
scaling_updates = [True, False]
runs = 10

# --- Data -----------------------------------------------------------------
dataset = load_mnist_dataset()
data, noisy_data = prepare_data(dataset, N_patts, noise_level="none", device=device)

# Results are indexed [mean setting, scaling setting, run, step]; -1 marks
# entries never filled.
result_shape = (len(means), len(scaling_updates), runs, N_patts)
err_l1_first_img_s_h_s = torch.full(result_shape, -1.0)
err_l1_last_img_s_h_s = torch.full(result_shape, -1.0)
avg_accumulated_err_l2 = torch.full(result_shape, -1.0)

for j, mean in enumerate(means):
    for k, scaling_update in enumerate(scaling_updates):
        for run in range(runs):
            # A fresh scaffold and a fresh layer for every single run.
            scaffold, mean_h = build_scaffold(
                shapes, N_h, device=device, sanity_check=True
            )
            layer = HebbianHippocampalSensoryLayer(
                784,
                N_h,
                "norm",
                mean is not None,
                mean_h,
                scaling_update,
                device,
            )
            shuffled_data = data[torch.randperm(len(data))]
            first_err, last_err, dataset_err = test_layer(
                layer, scaffold.H, shuffled_data
            )
            err_l1_first_img_s_h_s[j, k, run] = first_err
            err_l1_last_img_s_h_s[j, k, run] = last_err
            avg_accumulated_err_l2[j, k, run] = dataset_err

In [None]:
# Baselines: the error made by always answering with the mean image of the
# dataset, measured as L2 over the whole dataset and as L1 against data[0].
mean_of_dataset = torch.mean(data, dim=0)
rand = torch.rand_like(data)
err_mean_l2 = torch.mean((mean_of_dataset - data) ** 2).cpu()
err_mean_l1 = torch.mean(torch.abs(mean_of_dataset - data[0])).cpu()

# Figure: average L2 error over all patterns learned so far, one curve per
# (scaling_update, mean) combination.
fig, ax = plt.subplots(figsize=(15, 9))
for k, scaling_update in enumerate(scaling_updates):
    for j, mean in enumerate(means):
        curve_label = f"mean_fix={mean != None}, scaling_update={scaling_update}"
        plot_avg_acc_l2_err_on_ax(ax, avg_accumulated_err_l2[j, k], label=curve_label)

add_vertical_bar_on_ax(ax, N_h)
add_horizontal_bar_on_ax(ax, err_mean_l2, label='err using "mean of dataset" image')
set_ax_titles(
    ax,
    title=f"shapes={shapes}, N_h={N_h}",
    xtitle="Number of images learned",
    ytitle="Average L2 error over all patterns",
)

# Save three views of the same figure: zoomed in y, full range, zoomed in x and y.
for y_top, x_right, out_name in (
    (2, None, "hipp_sens_result_hebb_dataset_err_zoom_y"),
    (70**2, None, "hipp_sens_result_hebb_dataset_err"),
    (2, 50, "hipp_sens_result_hebb_dataset_err_zoom_xy"),
):
    ax.set_ylim(0, y_top)
    if x_right is not None:
        ax.set_xlim(0, x_right)
    fig.savefig(out_name)

# Figure: L1 error when recalling the first learned image, one curve per
# (scaling_update, mean) combination.
fig, ax = plt.subplots(figsize=(15, 9))
for k, scaling_update in enumerate(scaling_updates):
    for j, mean in enumerate(means):
        curve_label = f"mean_fix={mean != None}, scaling_update={scaling_update}"
        plot_first_img_l1_err_on_ax(ax, err_l1_first_img_s_h_s[j, k], label=curve_label)

add_vertical_bar_on_ax(ax, N_h)
add_horizontal_bar_on_ax(ax, err_mean_l1, label='err using "mean of dataset" image')
set_ax_titles(
    ax,
    title=f"shapes={shapes}, N_h={N_h}",
    xtitle="Number of images learned",
    ytitle="Error when recovering first pattern",
)

# Save three views of the same figure: zoomed in y, full range, zoomed in x and y.
for y_top, x_right, out_name in (
    (2, None, "hipp_sens_result_hebb_first_img_err_zoom_y"),
    (70, None, "hipp_sens_result_hebb_first_img_err"),
    (2, 50, "hipp_sens_result_hebb_first_img_err_zoom_xy"),
):
    ax.set_ylim(0, y_top)
    if x_right is not None:
        ax.set_xlim(0, x_right)
    fig.savefig(out_name)

# Figure: L1 error when recalling the most recently learned image, one curve
# per (scaling_update, mean) combination.
fig, ax = plt.subplots(figsize=(15, 9))
for k, scaling_update in enumerate(scaling_updates):
    for j, mean in enumerate(means):
        curve_label = f"mean_fix={mean != None}, scaling_update={scaling_update}"
        plot_last_img_l1_err_on_ax(ax, err_l1_last_img_s_h_s[j, k], label=curve_label)

add_vertical_bar_on_ax(ax, N_h)
add_horizontal_bar_on_ax(ax, err_mean_l1, label='err using "mean of dataset" image')
set_ax_titles(
    ax,
    title=f"shapes={shapes}, N_h={N_h}",
    xtitle="Number of images learned",
    ytitle="Error when recovering last pattern",
)

# Save three views of the same figure: zoomed in y, full range, zoomed in x and y.
for y_top, x_right, out_name in (
    (2, None, "hipp_sens_result_hebb_last_img_err_zoom_y"),
    (70, None, "hipp_sens_result_hebb_last_img_err"),
    (2, 50, "hipp_sens_result_hebb_last_img_err_zoom_xy"),
):
    ax.set_ylim(0, y_top)
    if x_right is not None:
        ax.set_xlim(0, x_right)
    fig.savefig(out_name)