In [None]:
from vectorhash_imported import *
from vectorhash_convered import *
from nd_scaffold import GridScaffold
import math
from scipy.stats import norm


import random

import numpy as np
import torch

# --- Reproducibility ---------------------------------------------------------
# The sparse weight initializer below (and possibly the data sampling in the
# next cell) is presumably stochastic; seed every global RNG so that
# Restart & Run All rebuilds the same scaffold.
# NOTE(review): confirm which RNGs the project code actually draws from.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Grid modules ------------------------------------------------------------
lambdas = [3, 4, 5]  # one entry per grid module
shapes = [(i, i) for i in lambdas]  # square (lambda, lambda) shape per module
num_modules = len(lambdas)

# --- Weight / threshold statistics ------------------------------------------
percent_nonzero_relu = 0.5  # presumably the target fraction of nonzero ReLU outputs
W_gh_var = 1  # weight variance; only its square root (W_hg_std) is used below
sparse_initialization = 0.1  # passed as `sparsity` to the sparse initializer
T = 0.01  # forwarded unchanged to GridScaffold(T=...)

# Standard-normal quantile leaving `percent_nonzero_relu` of the mass above it
# (exactly 0.0 when percent_nonzero_relu == 0.5).
relu_quantile = norm.ppf(1 - percent_nonzero_relu)
W_hg_std = math.sqrt(W_gh_var)
W_hg_mean = -W_hg_std * relu_quantile / math.sqrt(num_modules)
# Scaled as the mean/std of a sum of `num_modules` terms with the stats above.
h_normal_mean = num_modules * W_hg_mean
h_normal_std = math.sqrt(num_modules) * W_hg_std
relu_theta = math.sqrt((1 - sparse_initialization) * num_modules) * relu_quantile
num_imgs = 10  # number of images requested from prepare_data

# Labelled so the reader can tell which number is which.
print(
    f"{percent_nonzero_relu=}, {W_hg_mean=}, {W_hg_std=}, "
    f"{h_normal_mean=}, {h_normal_std=}, {relu_theta=}"
)

scaffold = GridScaffold(
    shapes=shapes,
    N_h=1000,
    input_size=784,  # flattened 28x28 image (recalls are reshaped to 28x28 later)
    device=None,
    learned_pseudo="bidirectional",
    hidden_layer_factor=0,
    sparse_matrix_initializer=SparseMatrixBySparsityInitializer(
        sparsity=sparse_initialization, device="cpu"
    ),
    relu_theta=relu_theta,
    # Alternative initializer driven by the mean/std computed above:
    # sparse_matrix_initializer=SparseMatrixByScalingInitializer(
    #     scale=W_hg_std, mean=W_hg_mean, device="cpu"
    # ),
    T=T,
    # h fix
    calculate_update_scaling_method="n_h",
    use_h_fix=False,  # true if norm scaling, false 25/32, true 21/32 11/32
    h_normal_mean=h_normal_mean,
    h_normal_std=h_normal_std,
    # epsilon=0.01,
    scaling_updates=False,  # only relevant when using hebbian false 21/32, true 15/32 7/32
    sanity_check=True,
    # dream_fix=1,
    # zero_tol=1,
)

In [None]:
from vectorhash_functions import spacefillingcurve
from data_utils import prepare_data, load_mnist_dataset

# Load MNIST and build `num_imgs` preprocessed observations with no added noise.
dataset = load_mnist_dataset()
data, noisy_data = prepare_data(
    dataset,
    num_imgs=num_imgs,
    noise_level="none",
    preprocess_sensory=True,
    across_dataset=True,
)

# Velocities along a space-filling curve over the module shapes; the scaffold
# learns the path formed by the first len(data) steps, pairing them with the
# observations.
v = spacefillingcurve(shapes)
path_velocities = v[: len(data)]
g_positions, g_positions2, g_points, g_points_2 = scaffold.learn_path(
    observations=data,
    velocities=path_velocities,
)

In [None]:
import matplotlib
import matplotlib.axes
import torch
import numpy as np
matplotlib.use("ipympl")
%matplotlib widget
from test_utils import get_action

# Explicit import: `plt` was otherwise only reachable via a `*` import.
import matplotlib.pyplot as plt

# Y-tick offsets 0, l_0, l_0 + l_1, ... — one tick per module, placed at the
# running total of the preceding modules' `l` values. Presumably `l` is the
# number of grid-state entries each module contributes (TODO confirm).
ls = [module.l for module in scaffold.modules]
g_ticks = np.cumsum(ls)
g_ticks = np.insert(g_ticks, 0, 0)
g_ticks = g_ticks[:-1]

# Start at the origin and show what the scaffold recalls there.
pos = torch.tensor((0, 0))
# torch.as_tensor accepts an existing tensor without the copy-construct
# UserWarning that torch.tensor(<tensor>) emits.
g = scaffold.cartesian_coordinates_to_grid_state(
    torch.as_tensor(pos, device=scaffold.device)
)
scaffold.g = g
im_data = scaffold.recall_from_position(g).reshape(28, 28).cpu().numpy()

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Left panel: the recalled 28x28 image, titled with the current position.
im_artist = ax[0].imshow(im_data)
text_artist = ax[0].set_title(f"pos: {pos}", fontsize=12)
# Right panel: the grid state as a single column (assuming g is 1-D), titled
# with the indices of its nonzero entries.
g_artist = ax[1].imshow(scaffold.g.unsqueeze(0).T.cpu().numpy())
g_text_artist = ax[1].set_title(f"g_idx: {g.nonzero().flatten().cpu().numpy()}", fontsize=12)
# Hide the width-1 x axis; put y ticks at the module offsets computed above.
ax[1].xaxis.set_visible(False)
ax[1].set_yticks(g_ticks)

plt.show()

In [None]:
# Interactive walk: each action returned by get_action() is added to `pos`, and
# the figure created in the previous cell is updated in place. A `None` action
# ends the loop.
# NOTE: `pos` is carried over between runs, so re-running this cell continues
# from the last visited position; re-run the previous cell to reset to (0, 0).
while True:
    action = get_action()
    if action is None:
        break
    pos = pos + action
    # torch.as_tensor avoids the copy-construct UserWarning that
    # torch.tensor(<tensor>) emits on every step.
    g = scaffold.cartesian_coordinates_to_grid_state(
        torch.as_tensor(pos, device=scaffold.device)
    )
    scaffold.g = g
    im_data = scaffold.recall_from_position(g).reshape(28, 28).cpu().numpy()
    im_artist.set_data(im_data)
    # set_data keeps the colour limits of the first image; rescale so each
    # recall is shown with its own value range.
    im_artist.autoscale()
    text_artist.set_text(f"pos: {pos}")
    g_artist.set_data(scaffold.g.unsqueeze(0).T.cpu().numpy())
    g_text_artist.set_text(f"g_idx: {g.nonzero().flatten().cpu().numpy()}")
    fig.canvas.draw()