In [23]:
from matplotlib.pyplot import axes
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def prepare_data(
    dataset,
    num_imgs=10,
    preprocess_sensory=True,
    noise_level="medium",
):
    """Flatten the first ``num_imgs`` samples of ``dataset`` and build a noisy copy.

    Args:
        dataset: object exposing a ``.data`` tensor or numpy array whose first
            axis indexes samples (e.g. ``torchvision.datasets.MNIST``).
        num_imgs: number of leading samples to use.
        preprocess_sensory: if True, z-score the clean batch with its global
            (all-pixel) mean and standard deviation.
        noise_level: one of ``"none"``, ``"low"``, ``"medium"``, ``"high"``.
            Non-``"none"`` levels add i.i.d. uniform noise in ``[-a, a)`` with
            ``a`` = 1.0, 1.25 and 1.5 respectively.

    Returns:
        ``(data, noisy_data)``: float32 tensors of shape ``(n, features)`` on
        ``device``. For ``"none"`` both elements are the same tensor.

    Raises:
        ValueError: if ``noise_level`` is not one of the values above.
    """
    import math
    import torch

    # Half-width of the uniform noise for each named level.
    noise_amplitudes = {"low": 1.0, "medium": 1.25, "high": 1.5}
    if noise_level != "none" and noise_level not in noise_amplitudes:
        raise ValueError(
            f"noise_level must be 'none' or one of {sorted(noise_amplitudes)}, "
            f"got {noise_level!r}"
        )

    # Slice before flattening so only the requested samples are touched.
    batch = dataset.data[:num_imgs]
    # Always copy: torch.tensor(<tensor>) warns, and we must not alias the
    # dataset's storage with the tensors we hand back.
    if isinstance(batch, torch.Tensor):
        batch = batch.detach().clone()
    else:
        batch = torch.tensor(batch)
    n_features = math.prod(batch.shape[1:])
    data = batch.reshape(batch.shape[0], n_features).float().to(device)

    if preprocess_sensory:
        data = (data - data.mean()) / data.std()

    if noise_level == "none":
        return data, data

    amplitude = noise_amplitudes[noise_level]
    random_noise = torch.empty_like(data).uniform_(-amplitude, amplitude)
    noisy_data = data + random_noise
    # TODO: DO WE PREPROCESS NOISY IMAGES?
    # if preprocess_sensory:
    #     noisy_data = (noisy_data - noisy_data.mean()) / noisy_data.std()

    return data, noisy_data

In [24]:
import numpy as np
import torch
from nd_scaffold import GridScaffold, SparseMatrixBySparsityInitializer
from graph_utils import graph_scaffold, print_imgs_side_by_side
from matrix_initializers import SparseMatrixByScalingInitializer
from vectorhash_functions import solve_mean, spacefillingcurve
import math
from scipy.stats import norm


def test_memory_capacity(
    data,
    noisy_data,
    shapes=[(3, 3, 5), (4, 4, 7)],
    N_h=1000,
    initalization_method="by_scaling",
    percent_nonzero_relu=0.01,
    W_gh_var=1.0,
    sparse_initialization=0.1,
    T=0.01,
    **vectorhash_kwargs,
):
    assert initalization_method in ["by_scaling", "by_sparsity"]
 
    if initalization_method == "by_scaling":
        W_hg_mean = -W_hg_std * norm.ppf(1-percent_nonzero_relu) / math.sqrt(len(shapes))
        W_hg_std = math.sqrt(W_gh_var)
        h_normal_mean=len(shapes)*W_hg_mean
        h_normal_std=math.sqrt(len(shapes))*W_hg_std
        relu_theta = 0
    elif initalization_method == "by_sparsity":
        gamma = 1- sparse_initialization
        relu_theta = math.sqrt(gamma * len(shapes)) * norm.ppf(1-percent_nonzero_relu)
        W_hg_mean = 0
        W_hg_std = math.sqrt(gamma * len(shapes))
        h_normal_mean = -relu_theta
        h_normal_std = (1-sparse_initialization) * len(shapes)


    GS = GridScaffold(
        shapes=shapes,
        N_h=N_h,
        input_size=data.shape[1],
        device=device,
        h_normal_mean=h_normal_mean,
        h_normal_std=h_normal_std,
        sparse_matrix_initializer=(
            SparseMatrixByScalingInitializer(
                mean=W_hg_mean,
                scale=W_hg_std,
                device=device
            )
            if initalization_method == "by_scaling"
            else SparseMatrixBySparsityInitializer(sparsity=sparse_initialization, device=device)
        ),
        relu_theta=relu_theta,
        T=T,
        **vectorhash_kwargs,
    )

    # learn over all images
    v = spacefillingcurve(shapes)

    GS.learn_path(observations=data, velocities=v[: len(data)])
    print(len(v[: len(data)]))
    recalled_imgs = GS.recall(noisy_data)

    # for i in range(min(5, len(data))):
    #     original_img = data[i].reshape(28, 28).cpu().numpy()
    #     noisy_img = noisy_data[i].reshape(28, 28).cpu().numpy()
    #     recalled_img = recalled_imgs[i].reshape(28, 28).cpu().numpy()
    #     print_imgs_side_by_side(
    #         original_img,
    #         noisy_img,
    #         recalled_img, 
    #         captions=["Original", "Noisy", "Recalled"],
    #         out=f"g1/{len(data)}_{percent_nonzero_relu}_imgs_{i}_{vectorhash_kwargs}.png")

    print(recalled_imgs)

    similarity = torch.nn.functional.cosine_similarity(data, recalled_imgs)
    return similarity

Recreating Capacity Results

In [25]:
# Memory Capacity Tests
import torchvision
from torchvision import transforms
import torch

import math  # used below (ceil/prod); don't rely on an earlier cell having imported it

shapes = [(3, 3), (4, 4), (5, 5)]
# shapes = [(5, 5), (9, 9), (11, 11)]

N_h = 1000
# Target number of hippocampal units that pass the ReLU (as a fraction: / N_h).
TARGET_ACTIVE_H = 10

# Total number of grid cells: sum of the module sizes.
N_g = sum(math.prod(shape) for shape in shapes)


transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x.flatten())]
)
dataset = torchvision.datasets.MNIST(
    root="data", train=True, download=True, transform=transform
)

# dataset = torchvision.datasets.FashionMNIST(
#     root="data", train=True, download=True, transform=transform
# )

# dataset = torchvision.datasets.CIFAR100(
#     root="data", train=True, download=True, transform=transform
# )

# Number of input features per sample (784 for 28x28 MNIST).
input_size = math.prod(dataset.data[0].shape)

theoretical_capacity = N_g * N_h / input_size

percents = [
    # 0.01,  # 1
    # 0.03,  # 2
    # 0.1,  # 3
    # 0.2,  # 4
    0.33,  # 5
    0.5,  # 6
    0.75,  # 7
    0.9,  # 8
    1.0,  # 9
    # 1.1,  # 10
    # 1.5,  # 11
    # 2.0,  # 12
    # 3.0,  # 13
    # 10.0,  # 14
]

num_images = [math.ceil(theoretical_capacity * p) for p in percents]
print(num_images)

preprocess_sensory = True
noise_level = "medium"

# TODO: Need a hyperparam N_p for dimension of the projection from grid cells to place cells
similarities = []
for num_imgs in num_images:
    print("==========================================")
    data, noisy_data = prepare_data(
        dataset,
        num_imgs=num_imgs,
        preprocess_sensory=preprocess_sensory,
        noise_level=noise_level,
    )
    similarity = test_memory_capacity(
        data,
        noisy_data,
        shapes=shapes,
        N_h=N_h,
        initalization_method="by_sparsity",
        percent_nonzero_relu=TARGET_ACTIVE_H / N_h,
        W_gh_var=1.0,
        sparse_initialization=0.6,
        T=0.01,
        continualupdate=False,
        ratshift=False,
        initialize_W_gh_with_zeroes=False,
        pseudo_inverse=False,
        learned_pseudo=False,
    )
    print("Cosine Similarity", torch.mean(similarity).item())
    similarities.append(similarity)

[22, 32, 48, 58, 64]
module shapes:  [(3, 3), (4, 4), (5, 5)]
N_g     :  50
N_patts :  3600
N_h     :  1000


  data = torch.tensor(data[:num_imgs]).float().to(device)


Unique Gs seen while learning: 22
Unique Hs seen while learning: 22
22
Unique Hs seen while recalling: 22
Unique Gs seen while recalling (before denoising): 22
Unique Gs seen while recalling (after denoising): 8
Unique Hs seen while recalling (after denoising): 8
avg nonzero H: 179.9545440673828
avg nonzero H_denoised: 20.5
tensor([[-2.7006, -2.7006, -2.7006,  ..., -2.7006, -2.7006, -2.7006],
        [-2.7006, -2.7006, -2.7006,  ..., -2.7006, -2.7006, -2.7006],
        [-1.7388, -1.7388, -1.7388,  ..., -1.7388, -1.7388, -1.7388],
        ...,
        [-2.5576, -2.5576, -2.5576,  ..., -2.5576, -2.5576, -2.5576],
        [-2.7006, -2.7006, -2.7006,  ..., -2.7006, -2.7006, -2.7006],
        [-2.7006, -2.7006, -2.7006,  ..., -2.7006, -2.7006, -2.7006]],
       device='cuda:0')
Cosine Similarity 0.5466322898864746
module shapes:  [(3, 3), (4, 4), (5, 5)]
N_g     :  50
N_patts :  3600
N_h     :  1000
Unique Gs seen while learning: 32
Unique Hs seen while learning: 32
32
Unique Hs seen while 

Comparing initialization techniques (at 10%, 50%, 100%, and 150% of capacity)

Compare continual learning vs. learning once at the start (at 10%, 50%, 100%, and 150% of capacity)

Hyperparameter tuning for the number of active hippocampal cells after denoising

Why are we losing grid states? 

Test the effect of using the h fix vs. not using it

In [None]:
# Memory Capacity Tests
import torchvision
from torchvision import transforms
import torch
import math
import numpy as np
from matrix_initializers import ConstantInitializer, SparseMatrixByScalingInitializer
from vectorhash_functions import solve_mean, spacefillingcurve
from nd_scaffold import GridScaffold
from graph_utils import print_imgs_side_by_side
import matplotlib.pyplot as plt
import os


shapes = [(3, 3), (4, 4), (5, 5)]
# Fall back to CPU instead of crashing on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x.flatten())]
)
dataset = torchvision.datasets.MNIST(
    root="data", train=True, download=True, transform=transform
)
# Number of input features per sample (784 for 28x28 MNIST).
input_size = math.prod(dataset.data[0].shape)

# Number of grid states. GridScaffold reports N_patts = 3600 for these shapes,
# i.e. the product of the module sizes (presumably requires co-prime module
# periods -- re-check if `shapes` changes).
N_patts = math.prod(math.prod(shape) for shape in shapes)
# Capacity used for this sweep: 10% of the grid states (360 images here).
theoretical_capacity = N_patts // 10

N_h = 1000
target_N_hs = [10, 100, 300, 500]
percents = [0.01, 0.03, 0.1, 0.33, 0.5, 0.75, 0.9, 1.0]
num_images = [math.ceil(theoretical_capacity * p) for p in percents]
print(num_images)

preprocess_sensory = True
noise_level = "none"

# TODO: Need a hyperparam N_p for dimension of the projection from grid cells to place cells

# Mean cosine similarity, indexed [target_N_h index, percent index].
results_no_fix = np.zeros((len(target_N_hs), len(percents)))
results_fix = np.zeros((len(target_N_hs), len(percents)))

from scipy.stats import norm

def test_h_fix(
    data,
    noisy_data,
    shapes=[(3, 3, 5), (4, 4, 7)],
    N_h=1000,
    percent_nonzero_relu=0.01,
    W_hg_std=1.0,
    T=0.01,
    **vectorhash_kwargs,
):
    """Compare recall of two scaffolds that differ only in ``use_h_fix``.

    Both scaffolds start from the same W_hg: the second is initialised with a
    copy of the first one's W_hg via ``ConstantInitializer``. Each learns
    ``data`` along the same path and recalls from ``noisy_data``.

    Args:
        data: clean observations; ``data.shape[1]`` is the input size.
        noisy_data: corrupted observations passed to ``recall``.
        shapes: grid-module shapes.
        N_h: number of hippocampal units.
        percent_nonzero_relu: target fraction of hippocampal units above the
            ReLU threshold (0); sets the W_hg mean shift via ``norm.ppf``.
        W_hg_std: scale of the W_hg entries.
        T: passed through to GridScaffold.
        **vectorhash_kwargs: extra GridScaffold keyword arguments (must not
            contain ``use_h_fix`` or ``sparse_matrix_initializer``).

    Returns:
        ``(similarity_no_fix, similarity_fix)``: row-wise cosine similarity
        between ``data`` and each scaffold's recalled images.
    """
    n_modules = len(shapes)
    W_hg_mean = -W_hg_std * norm.ppf(1 - percent_nonzero_relu) / math.sqrt(n_modules)

    # Constructor arguments shared by both scaffolds.
    common_kwargs = dict(
        shapes=shapes,
        N_h=N_h,
        input_size=data.shape[1],
        device=device,
        h_normal_mean=n_modules * W_hg_mean,
        h_normal_std=math.sqrt(n_modules) * W_hg_std,
        relu_theta=0,
        T=T,
        **vectorhash_kwargs,
    )

    GS_no_fix = GridScaffold(
        **common_kwargs,
        sparse_matrix_initializer=SparseMatrixByScalingInitializer(
            mean=W_hg_mean,
            scale=W_hg_std,
            device=device,
        ),
        use_h_fix=False,
    )
    # Clone so that if either scaffold updates W_hg in place while learning,
    # the change cannot leak into the other model (assumes W_hg is a tensor).
    GS_fix = GridScaffold(
        **common_kwargs,
        sparse_matrix_initializer=ConstantInitializer(value=GS_no_fix.W_hg.clone()),
        use_h_fix=True,
    )

    # Both scaffolds follow the same space-filling-curve path, one step per image.
    velocities = spacefillingcurve(shapes)[: len(data)]
    GS_no_fix.learn_path(observations=data, velocities=velocities)
    GS_fix.learn_path(observations=data, velocities=velocities)

    recalled_imgs_no_fix = GS_no_fix.recall(noisy_data)
    recalled_imgs_fix = GS_fix.recall(noisy_data)

    similarity_no_fix = torch.nn.functional.cosine_similarity(data, recalled_imgs_no_fix)
    similarity_fix = torch.nn.functional.cosine_similarity(data, recalled_imgs_fix)
    return similarity_no_fix, similarity_fix


# Scaffold options are identical for every run, so build them once.
vectorhash_kwargs = dict(
    continualupdate=False,
    ratshift=False,
    initialize_W_gh_with_zeroes=False,
    pseudo_inverse=False,
    learned_pseudo=False,
    calculate_update_scaling_method="norm",
)

for i, target_N_h in enumerate(target_N_hs):
    for j, num_imgs in enumerate(num_images):
        data, noisy_data = prepare_data(
            dataset,
            num_imgs=num_imgs,
            preprocess_sensory=preprocess_sensory,
            noise_level=noise_level,
        )
        similarity_no_fix, similarity_fix = test_h_fix(
            data,
            noisy_data,
            shapes=shapes,
            N_h=N_h,
            percent_nonzero_relu=target_N_h / N_h,
            W_hg_std=1.0,
            T=0.01,
            **vectorhash_kwargs,
        )
        mean_no_fix = torch.mean(similarity_no_fix).item()
        mean_fix = torch.mean(similarity_fix).item()
        results_no_fix[i, j] = mean_no_fix
        results_fix[i, j] = mean_fix
        print("Cosine Similarity no fix", mean_no_fix)
        print("Cosine Similarity    fix", mean_fix)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# set_theme() resets the style (default "darkgrid"), so a set_style() call made
# before it is overwritten; pass context and style together instead.
sns.set_theme(context="paper", style="whitegrid")

fig, ax = plt.subplots(1, 1, figsize=(10, 10))

# num_images = p * theoretical_capacity, and theoretical_capacity = 3600 // 10,
# so p * 0.1 is the fraction of the 3600 grid states that hold an image.
fraction_of_patts = [p * 0.1 for p in percents]

for i, target_N_h in enumerate(target_N_hs):
    # Same colour per target: dashed = without h fix, solid = with h fix.
    (no_fix_line,) = ax.plot(
        fraction_of_patts,
        results_no_fix[i],
        linestyle="--",
        label=f"no fix, target active h = {target_N_h}",
    )
    ax.plot(
        fraction_of_patts,
        results_fix[i],
        color=no_fix_line.get_color(),
        label=f"fix, target active h = {target_N_h}",
    )

ax.set_xlabel("Stored images / N_patts (3600 grid states)")
ax.set_ylabel("Cosine Similarity")
ax.set_title("Recall quality with vs. without the h fix")
ax.legend()

plt.show()
