In [None]:
import sys

import torch
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# Make the project-local modules two directories up importable.
sys.path.append("../..")

from fourier_scaffold import FourierScaffold
from data_utils import load_mnist_dataset, prepare_data, determine_input_size
from graph_utils import plot_imgs_side_by_side
from hippocampal_sensory_layers import (
    ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars,
)

# --- Configuration -----------------------------------------------------------
SEED = 0
# The noise tensors below are drawn with `uniform_`; seed so reruns match.
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
shapes = [(5, 5), (7, 7)]  # passed to FourierScaffold as `shapes`
whitened = True  # forwarded to prepare_data as `preprocess_sensory`
D = 400
N = 500  # images learned later; 2*N are loaded so obs[N:] stay unseen
NOISE_AMPLITUDE = 0.5  # additive noise is uniform(-NOISE_AMPLITUDE, NOISE_AMPLITUDE)

# --- Model and data ----------------------------------------------------------
scaffold = FourierScaffold(
    shapes=torch.tensor(shapes), D=D, device=device, _skip_K_calc=True
)
dataset = load_mnist_dataset()
input_size = determine_input_size(dataset)
obs, _ = prepare_data(dataset, N * 2, device=device, preprocess_sensory=whitened)
noise = torch.zeros_like(obs).uniform_(-NOISE_AMPLITUDE, NOISE_AMPLITUDE)
noisy_obs = obs + noise
layer = ComplexIterativeBidirectionalPseudoInverseHippocampalSensoryLayerComplexScalars(
    input_size, D, 0, epsilon_hs=0.01, epsilon_sh=0.01, device=device
)
# A pure-noise probe with the same shape as a single observation.
noise_only = torch.zeros_like(obs[0]).uniform_(-NOISE_AMPLITUDE, NOISE_AMPLITUDE)

In [None]:
# Preview the five probe inputs. obs[:N] are the images learned later;
# obs[N] is the first image that is never learned.
titles = ["learned", "noisy learned", "unseen", "noisy unseen", "pure noise"]
fig, axs = plt.subplots(ncols=5, figsize=(16, 3.2))
plot_imgs_side_by_side(
    imgs=[
        vec.reshape(28, 28).cpu()
        for vec in (obs[0], noisy_obs[0], obs[N], noisy_obs[N], noise_only)
    ],
    titles=titles,
    axs=axs,
    fig=fig,
    use_first_img_scale=False,
)
fig.suptitle("images")

In [None]:
# Store the first N observations, pairing obs[i] with row i of the grid codebook.
# NOTE(review): assumes scaffold.gbook() is (D, n_states), so `.T` makes row i
# the code for state i -- confirm against FourierScaffold.
gbook = scaffold.gbook().T
for i in tqdm(range(N), desc="learning"):
    layer.learn(gbook[i], obs[i])

In [None]:
# Hippocampal probability distributions for the five probe inputs.
# Labels are kept local (not the shared `titles` name) so this cell's output
# does not depend on which other cells have run before it.
probe_titles = ["learned", "noisy learned", "unseen", "noisy unseen", "pure noise"]
probe_inputs = [
    obs[0],  # learned
    noisy_obs[0],  # noisy learned
    obs[N],  # first unlearned (only obs[:N] is learned)
    noisy_obs[N],  # noisy first unlearned
    noise_only.flatten(),  # pure noise
]
# Map each probe to hippocampal space once and reuse it for both H and the plot.
probe_states = [layer.hippocampal_from_sensory(s) for s in probe_inputs]
entropies = [scaffold.entropy(g) for g in probe_states]

fig, axs = plt.subplots(ncols=5, figsize=(16, 3.2))
plot_imgs_side_by_side(
    imgs=[scaffold.get_all_probabilities(g).abs().cpu() for g in probe_states],
    axs=axs,
    titles=[f"{name},H={h:.3f}" for name, h in zip(probe_titles, entropies)],
    fig=fig,
    use_first_img_scale=False,
)
fig.suptitle(
    f"distributions from first image, noisy first image, first unlearned image, noisy first unlearned image, and pure noise\nwhitened={whitened},shapes={shapes}, D={D}, num imgs={N}, noise=uniform(-0.5,0.5)"
)

In [None]:
def entropy_batched(P_batch: torch.Tensor):
    """Return the squared L2 norm along dim 1, ||P_batch[i]||^2, for every i.

    This is the "H" score plotted in this notebook. Despite the name, it is
    not a Shannon entropy. The result is real-valued even for complex input.

    Args:
        P_batch: tensor whose dim 1 is reduced (e.g. one row per input).

    Returns:
        Tensor with dim 1 removed, holding the per-slice squared norms.
    """
    row_norms = torch.linalg.vector_norm(P_batch, dim=1)
    return row_norms ** 2


def entropy(P: torch.Tensor):
    """Return the squared L2 norm over all elements of P, ||P||^2.

    This is the single-input counterpart of the batched "H" score. Despite
    the name, it is not a Shannon entropy. The result is a 0-d real tensor,
    even for complex P.
    """
    total_norm = torch.linalg.vector_norm(P)
    return total_norm.square()


# Scatter the ||P||^2 score ("H") of every seen/unseen, clean/noisy image,
# with the pure-noise probe drawn as a horizontal baseline.
fig, ax = plt.subplots(figsize=(10, 6))

# Local names, so this cell does not rebind the shared `titles` / `subsets`
# globals that other cells read.
norm_sq_series = [
    ("obs seen", obs[:N]),
    ("obs unseen", obs[N : N * 2]),
    ("noisy obs seen", noisy_obs[:N]),
    ("noisy obs unseen", noisy_obs[N : N * 2]),
]

for label, subset in norm_sq_series:
    # One value per image in the subset.
    h_values = entropy_batched(layer.hippocampal_from_sensory(subset)).cpu()
    ax.scatter(torch.arange(len(h_values)), h_values, label=label, s=5)

# NOTE(review): `[0]` assumes hippocampal_from_sensory returns a (1, D) batch for
# a single input; if it returns a flat (D,) vector this picks one element -- confirm.
noise_h = entropy(layer.hippocampal_from_sensory(noise_only.flatten())[0])
ax.axhline(y=noise_h.item(), label="noise", linestyle=":")

ax.set_xlabel("img")
ax.set_ylabel("H calculated by H=||P||^2 for s->P")
ax.legend()
fig.suptitle(
    f"entropy for learned, noisy learned, unlearned, noisy unlearned images with pure noise baseline\nwhitened={whitened},shapes={shapes}, D={D}, num imgs={N}, noise=uniform(-0.50, 0.50)"
)

In [None]:
# Zoomed-in version of the ||P||^2 ("H") scatter, with the y-axis clipped to [0, 10].
fig, ax = plt.subplots(figsize=(10, 6))

# Local names, so this cell does not rebind the shared `titles` / `subsets`
# globals that other cells read.
zoom_series = [
    ("obs seen", obs[:N]),
    ("obs unseen", obs[N : N * 2]),
    ("noisy obs seen", noisy_obs[:N]),
    ("noisy obs unseen", noisy_obs[N : N * 2]),
]

for label, subset in zoom_series:
    # One value per image in the subset.
    h_values = entropy_batched(layer.hippocampal_from_sensory(subset)).cpu()
    ax.scatter(torch.arange(len(h_values)), h_values, label=label, s=5)

# NOTE(review): `[0]` assumes hippocampal_from_sensory returns a (1, D) batch for
# a single input; if it returns a flat (D,) vector this picks one element -- confirm.
noise_h = entropy(layer.hippocampal_from_sensory(noise_only.flatten())[0])
ax.axhline(y=noise_h.item(), label="noise", linestyle=":")

ax.set_xlabel("img")
ax.set_ylabel("H calculated by H=||P||^2 for s->P")
ax.set_ylim(0, 10)
ax.legend()
fig.suptitle(
    f"entropy for learned, noisy learned, unlearned, noisy unlearned images with pure noise baseline\nwhitened={whitened},shapes={shapes}, D={D}, num imgs={N}, noise=uniform(-0.50, 0.50)"
)