In [None]:
from vectorhash import GridHippocampalScaffold

# Periods of the grid modules; each module is given a square (lambda, lambda) shape.
lambdas = [3, 4, 5]
shapes = [(lam, lam) for lam in lambdas]

# NOTE(review): the positional 400 presumably is the hippocampal size and should
# match N_h=400 passed to the sensory layer in the next cell — confirm against
# GridHippocampalScaffold's signature.
scaffold = GridHippocampalScaffold(
    shapes,
    400,
    sanity_check=False,
    calculate_g_method='hairpin',
)

In [None]:
from data_utils import prepare_data, load_mnist_dataset
from hippocampal_sensory_layers import IterativeBidirectionalPseudoInverseHippocampalSensoryLayer


# Load MNIST and prepare 400 images for storage.
# NOTE(review): with noise_level="none", `noisy_data` presumably mirrors `data` — confirm.
dataset = load_mnist_dataset()
data, noisy_data = prepare_data(
    dataset,
    num_imgs=400,
    preprocess_sensory=True,
    noise_level="none",
    across_dataset=True,
)

# Bidirectional sensory <-> hippocampal mapping. The sensory side is a flattened
# 28x28 image (decoded outputs are reshaped to (28, 28) further down); N_h should
# agree with the scaffold's hippocampal size.
sh = IterativeBidirectionalPseudoInverseHippocampalSensoryLayer(
    input_size=28 * 28,
    N_h=400,
    hidden_layer_factor=1,
    epsilon_hs=0.1,
    epsilon_sh=0.1,
)

In [None]:
from tqdm.auto import tqdm

# Associate the idx-th scaffold hippocampal state with the idx-th image.
# Iterate over however many images prepare_data returned instead of a hard-coded
# count, so changing `num_imgs` above cannot desynchronize this loop.
# NOTE(review): assumes scaffold.H has at least len(data) rows — confirm.
for pattern_idx in tqdm(range(len(data)), desc="learning"):
    sh.learn(scaffold.H[pattern_idx], data[pattern_idx])

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import matplotlib.axes
import torch
matplotlib.use("ipympl")
%matplotlib widget
from test_utils import get_action

ls = [module.l for module in scaffold.modules]
g_ticks = np.cumsum(ls)
g_ticks = np.insert(g_ticks, 0, 0)
g_ticks = g_ticks[:-1]


pos = torch.tensor((0,0))
g = scaffold.grid_state_from_cartesian_coordinates(
    torch.tensor(pos, device=scaffold.device)
)
scaffold.g = g
im_data = sh.sensory_from_hippocampal(scaffold.hippocampal_from_grid(g)).reshape(28,28)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
im_artist = ax[0].imshow(im_data)
g_artist = ax[1].imshow(scaffold.g.unsqueeze(0).T.cpu().numpy())
j = torch.nonzero(torch.all(scaffold.G == g, dim=1)).flatten().item()
g_text_artist = ax[1].set_title(f"g_idx: {g.nonzero().flatten().cpu().numpy()}; j: {j}", fontsize=12)
text_artist = ax[0].set_title(f"pos: {pos}", fontsize=12)
# disabley ticks
ax[1].xaxis.set_visible(False)
ax[1].set_yticks(g_ticks)

plt.show()

In [None]:
# Interactive loop: read actions from get_action() until it returns None, move
# `pos` by each action, and redraw the decoded image and the grid state.
while True:
    action = get_action()
    if action is None:
        break
    # as_tensor accepts tuples/lists/arrays as well as tensors, so this does not
    # fail if get_action returns a plain sequence.
    pos = pos + torch.as_tensor(action)
    g = scaffold.grid_state_from_cartesian_coordinates(
        # as_tensor avoids the copy-construct warning torch.tensor emits for a tensor input
        torch.as_tensor(pos, device=scaffold.device)
    )
    # Row index of the new grid state within scaffold.G (assumes exactly one match).
    j = torch.nonzero(torch.all(scaffold.G == g, dim=1)).flatten().item()
    scaffold.g = g
    # Detach and move to CPU numpy so set_data works for device / grad-tracking tensors.
    im_data = (
        torch.as_tensor(sh.sensory_from_hippocampal(scaffold.hippocampal_from_grid(g)))
        .reshape(28, 28)
        .detach()
        .cpu()
        .numpy()
    )
    im_artist.set_data(im_data)
    # imshow fixes its colour limits from the first frame; set_data does not
    # update them, so rescale to this frame to avoid clipping other value ranges.
    im_artist.autoscale()
    text_artist.set_text(f"pos: {pos}")
    g_artist.set_data(scaffold.g.unsqueeze(0).T.cpu().numpy())
    g_text_artist.set_text(f"g_idx: {g.nonzero().flatten().cpu().numpy()}; j: {j}")
    fig.canvas.draw()