In [None]:
# Make the parent directory importable. The path is relative, so it is resolved
# against the kernel's current working directory when modules are looked up.
# NOTE(review): presumably the project modules imported further down live
# there -- confirm against the repository layout.
import sys
sys.path.append('../')

In [None]:
import torch

# Prefer the GPU, but fall back to the CPU so that every later cell still runs
# on hosts where CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import torch
from tqdm import tqdm
from clean_scaffold import GridHippocampalScaffold
from hippocampal_sensory_layers import HippocampalSensoryLayer
from vectorhash import build_initializer
from clean_scaffold import ArgmaxSmoothing, SoftmaxSmoothing, PolynomialSmoothing
from hippocampal_sensory_layers import IterativeBidirectionalPseudoInverseHippocampalSensoryLayer, ExactPseudoInverseHippocampalSensoryLayer
from matplotlib import pyplot as plt

In [None]:
def corrupt_p_1(codebook, p=0.1):
    """Flip the sign of each entry of ``codebook`` independently with probability ``p``.

    Args:
        codebook: tensor of patterns; it is not modified.
        p: per-entry flip probability. ``0.0`` returns ``codebook`` itself
            (the same object, no copy).

    Returns:
        A tensor shaped like ``codebook`` in which the selected entries are negated.
    """
    if p == 0.0:
        return codebook

    noise = torch.rand(size=codebook.shape, device=codebook.device)
    # Build the +1/-1 multiplier from an explicit mask. Taking sign(noise - p)
    # yields 0 whenever a draw equals p exactly, which would erase the entry
    # instead of keeping or flipping it.
    signs = torch.ones_like(noise)
    signs[noise < p] = -1.0
    return torch.multiply(codebook, signs)

def dynamics_patts_signed(
    scaffold: GridHippocampalScaffold,
    sensory_hippocampal_layer: HippocampalSensoryLayer,
    sbook_noisy,  # (Npatts, input_size)
    sbook,
    hbook,
    N_iter=1,
):
    """Recall patterns from noisy cues and score the sign-binarised reconstruction.

    The cue is mapped to a hippocampal state, passed ``N_iter`` times through
    the scaffold loop (hippocampal -> grid -> denoise -> hippocampal), mapped
    back to sensory space and binarised with ``torch.sign``.

    Args:
        scaffold: provides ``grid_from_hippocampal``, ``denoise``,
            ``hippocampal_from_grid`` and ``N_h``.
        sensory_hippocampal_layer: provides ``hippocampal_from_sensory``,
            ``sensory_from_hippocampal`` and ``input_size``.
        sbook_noisy: corrupted sensory cues.
        sbook: reference sensory patterns the output is compared against.
        hbook: reference hippocampal states the recalled state is compared against.
        N_iter: number of passes through the scaffold loop.

    Returns:
        ``(h_l2_err, s_l2_err, s_l1_err)``: the vector norm of the hippocampal
        difference divided by ``scaffold.N_h``, the vector norm of the sensory
        difference divided by ``input_size``, and half the mean absolute
        sensory difference.
    """
    layer = sensory_hippocampal_layer

    # Work on a copy of the layer output so the loop never aliases it.
    hippocampal = layer.hippocampal_from_sensory(sbook_noisy).clone()
    for _ in range(N_iter):
        grid = scaffold.denoise(scaffold.grid_from_hippocampal(hippocampal))
        hippocampal = scaffold.hippocampal_from_grid(grid)

    recalled = torch.sign(layer.sensory_from_hippocampal(hippocampal))
    hippocampal_diff = hippocampal - hbook
    sensory_diff = recalled - sbook

    h_l2_err = torch.linalg.vector_norm(hippocampal_diff) / scaffold.N_h
    s_l2_err = torch.linalg.vector_norm(sensory_diff) / layer.input_size
    # Halved so that, for +/-1-valued patterns (presumably the case here --
    # confirm against the caller), a mismatching entry counts as 1 rather than 2.
    s_l1_err = sensory_diff.abs().mean() / 2

    return h_l2_err, s_l2_err, s_l1_err


def dynamics_patts(
    scaffold: GridHippocampalScaffold,
    sensory_hippocampal_layer: HippocampalSensoryLayer,
    sbook_noisy,  # (Npatts, input_size)
    sbook,
    hbook,
    N_iter=1,
):
    """Recall patterns from noisy cues and score the raw (un-binarised) reconstruction.

    The cue is mapped to a hippocampal state, run ``N_iter`` times through the
    scaffold loop (hippocampal -> grid -> denoise -> hippocampal) and mapped
    back to sensory space. The sensory output is compared to ``sbook`` as is.

    Args:
        scaffold: provides ``grid_from_hippocampal``, ``denoise``,
            ``hippocampal_from_grid`` and ``N_h``.
        sensory_hippocampal_layer: provides ``hippocampal_from_sensory``,
            ``sensory_from_hippocampal`` and ``input_size``.
        sbook_noisy: corrupted sensory cues.
        sbook: reference sensory patterns.
        hbook: reference hippocampal states.
        N_iter: number of passes through the scaffold loop.

    Returns:
        ``(h_l2_err, s_l2_err, s_l1_err)``: the vector norm of the hippocampal
        difference divided by ``scaffold.N_h``, the vector norm of the sensory
        difference divided by ``input_size``, and the mean absolute sensory
        difference.
    """
    state = torch.clone(
        sensory_hippocampal_layer.hippocampal_from_sensory(sbook_noisy)
    )
    for _step in range(N_iter):
        denoised_grid = scaffold.denoise(scaffold.grid_from_hippocampal(state))
        state = scaffold.hippocampal_from_grid(denoised_grid)

    reconstruction = sensory_hippocampal_layer.sensory_from_hippocampal(state)
    residual = reconstruction - sbook

    hippocampal_error = torch.linalg.vector_norm(state - hbook) / scaffold.N_h
    sensory_l2_error = (
        torch.linalg.vector_norm(residual) / sensory_hippocampal_layer.input_size
    )
    sensory_l1_error = torch.mean(torch.abs(residual))

    return hippocampal_error, sensory_l2_error, sensory_l1_error

In [None]:
def capacity_test_signed(
    scaffold: GridHippocampalScaffold,
    hippocampal_sensory_layer: HippocampalSensoryLayer,
    sbook: torch.Tensor,
    Npatts_list,
    nruns=1,
    device=None,
    p=0.1,
):
    """Measure sign-binarised recall error as more patterns are stored.

    For each entry of ``Npatts_list`` the layer is trained on the first
    ``Npatts`` rows of ``sbook`` paired with the first ``Npatts`` rows of
    ``scaffold.H`` (through ``learn_batch`` when the layer has it, otherwise
    one ``learn`` call per pattern). Recall is then scored ``nruns`` times
    with ``dynamics_patts_signed`` on cues freshly corrupted by ``corrupt_p_1``.

    Args:
        scaffold: scaffold whose ``H`` holds the hippocampal states.
        hippocampal_sensory_layer: layer being trained and evaluated.
        sbook: sensory patterns, one per row.
        Npatts_list: pattern counts to evaluate.
        nruns: noisy trials per pattern count.
        device: device on which the result tensors are allocated.
        p: flip probability passed to ``corrupt_p_1``.

    Returns:
        ``(err_h_l2, err_s_l2, err_s_l1)``, each of shape
        ``(len(Npatts_list), nruns)``.
    """
    result_shape = (len(Npatts_list), nruns)
    # Pre-filled with -1 so entries that were never written are recognisable.
    err_h_l2 = torch.full(result_shape, -1.0, device=device)
    err_s_l1 = torch.full(result_shape, -1.0, device=device)
    err_s_l2 = torch.full(result_shape, -1.0, device=device)

    for row in tqdm(range(len(Npatts_list))):
        n_stored = Npatts_list[row]
        stored_h = scaffold.H[:n_stored]
        stored_s = sbook[:n_stored]

        if not hasattr(hippocampal_sensory_layer, "learn_batch"):
            # No batch API: feed the patterns one at a time.
            for idx in tqdm(range(n_stored)):
                hippocampal_sensory_layer.learn(scaffold.H[idx], sbook[idx])
        else:
            hippocampal_sensory_layer.learn_batch(stored_s, stored_h)

        for run in range(nruns):
            noisy_cues = corrupt_p_1(stored_s, p)
            h_l2, s_l2, s_l1 = dynamics_patts_signed(
                scaffold,
                hippocampal_sensory_layer,
                noisy_cues,
                stored_s,
                stored_h,
                N_iter=1,
            )
            err_h_l2[row, run] = h_l2
            err_s_l2[row, run] = s_l2
            err_s_l1[row, run] = s_l1

    return err_h_l2, err_s_l2, err_s_l1


def capacity_test(
    scaffold: GridHippocampalScaffold,
    hippocampal_sensory_layer: HippocampalSensoryLayer,
    sbook: torch.Tensor,
    Npatts_list,
    nruns,
    device,
    p=0
):
    """Measure raw (un-binarised) recall error as more patterns are stored.

    For every pattern count in ``Npatts_list`` the layer is trained on the
    leading rows of ``sbook`` / ``scaffold.H`` (``learn_batch`` if available,
    else repeated ``learn`` calls), after which ``dynamics_patts`` is
    evaluated ``nruns`` times on cues corrupted by ``corrupt_p_1``.

    Args:
        scaffold: scaffold whose ``H`` holds the hippocampal states.
        hippocampal_sensory_layer: layer being trained and evaluated.
        sbook: sensory patterns, one per row.
        Npatts_list: pattern counts to evaluate.
        nruns: noisy trials per pattern count.
        device: device on which the result tensors are allocated.
        p: flip probability passed to ``corrupt_p_1``.

    Returns:
        ``(err_h_l2, err_s_l2, err_s_l1)``, each of shape
        ``(len(Npatts_list), nruns)``.
    """
    n_settings = len(Npatts_list)

    def _unset():
        # -1 everywhere marks cells that have not been filled in yet.
        return torch.full((n_settings, nruns), -1.0, device=device)

    err_h_l2 = _unset()
    err_s_l1 = _unset()
    err_s_l2 = _unset()

    for setting in tqdm(range(n_settings)):
        count = Npatts_list[setting]
        h_subset = scaffold.H[:count]
        s_subset = sbook[:count]

        can_batch = hasattr(hippocampal_sensory_layer, "learn_batch")
        if can_batch:
            hippocampal_sensory_layer.learn_batch(s_subset, h_subset)
        else:
            for pattern in tqdm(range(count)):
                hippocampal_sensory_layer.learn(scaffold.H[pattern], sbook[pattern])

        for trial in range(nruns):
            errors = dynamics_patts(
                scaffold,
                hippocampal_sensory_layer,
                corrupt_p_1(s_subset, p),
                s_subset,
                h_subset,
                N_iter=1,
            )
            err_h_l2[setting, trial] = errors[0]
            err_s_l2[setting, trial] = errors[1]
            err_s_l1[setting, trial] = errors[2]

    return err_h_l2, err_s_l2, err_s_l1

In [None]:
# Square shapes of side 3, 4 and 5.
# NOTE(review): presumably one shape per grid module -- confirm against GridHippocampalScaffold.
shapes = [(side, side) for side in (3, 4, 5)]
N_h_list = [400]
input_size = 784
nruns = 1
# Pattern counts to evaluate: 1, 201, 401, ... strictly below the product of all shape entries.
Npatts_list = torch.arange(start=1, end=torch.tensor(shapes).prod(), step=200)

In [None]:
# Smoothing strategies to sweep; uncomment entries to compare more of them.
smoothing_methods = [
    # SoftmaxSmoothing(T=1),
    # SoftmaxSmoothing(T=0.1),
    ArgmaxSmoothing(),
    # PolynomialSmoothing(k=2),
    # PolynomialSmoothing(k=5),
    # PolynomialSmoothing(k=8)
]
# Add "iterative_pseudoinverse" to the list to also evaluate the iterative layer.
pseudoinverse_methods = ["exact_pseudoinverse"]
initialization_method = "by_scaling"
relu_options = [False, True]
p = 0  # per-entry sign-flip probability handed to the capacity test

# Seed the global RNG so the random pattern book below (and any noise drawn
# later) is identical after every Restart & Run All.
torch.manual_seed(0)

# Random +/-1 patterns: one row per element of the product of all shape
# entries, `input_size` columns.
data = torch.sign(
    torch.randn(torch.tensor(shapes).prod().item(), input_size, device=device)
)

In [None]:
# One independent zero-initialised tensor per metric, indexed as
# [smoothing, relu, pseudoinverse, N_h, Npatts, run].
err_h_l2_results, err_s_l2_results, err_s_l1_results = (
    torch.zeros(
        len(smoothing_methods),
        len(relu_options),
        len(pseudoinverse_methods),
        len(N_h_list),
        len(Npatts_list),
        nruns,
    )
    for _ in range(3)
)

# Sweep every (N_h, smoothing, relu, pseudoinverse) combination and store the
# per-pattern-count errors in the result tensors.
for k, N_h in enumerate(N_h_list):
    for i, smoothing_method in enumerate(smoothing_methods):
        for l, relu in enumerate(relu_options):
            # A fresh initializer and scaffold are built for each combination.
            # NOTE(review): the `relu_theta` and `mean_h` returned here are
            # never used -- the scaffold below is given a hard-coded
            # relu_theta=0.5. Confirm that this is intentional.
            initializer, relu_theta, mean_h = build_initializer(
                shapes,
                initalization_method=initialization_method,
                percent_nonzero_relu=0.8,
                sparse_initialization=0.1,
                device=device,
            )
            scaffold = GridHippocampalScaffold(
                shapes,
                N_h,
                sparse_matrix_initializer=initializer,
                smoothing=smoothing_method,
                device=device,
                relu_theta=0.5,
                sanity_check=True,
                relu=relu,
            )
            for j, pseudoinverse_method in enumerate(pseudoinverse_methods):
                if pseudoinverse_method == "exact_pseudoinverse":
                    layer = ExactPseudoInverseHippocampalSensoryLayer(
                        input_size=input_size,
                        N_h=N_h,
                        N_patts=scaffold.N_patts,
                        hbook=scaffold.H,
                        device=device,
                    )
                elif pseudoinverse_method == "iterative_pseudoinverse":
                    layer = IterativeBidirectionalPseudoInverseHippocampalSensoryLayer(
                        input_size=input_size,
                        N_h=N_h,
                        epsilon_hs=0.1,
                        epsilon_sh=0.1,
                        hidden_layer_factor=1,
                        device=device,
                    )
                else:
                    # Fail loudly instead of silently reusing the layer left
                    # over from a previous iteration.
                    raise ValueError(
                        f"Unknown pseudoinverse method: {pseudoinverse_method!r}"
                    )

                err_h_l2, err_s_l2, err_s_l1 = capacity_test_signed(
                    scaffold=scaffold,
                    hippocampal_sensory_layer=layer,
                    sbook=data,
                    Npatts_list=Npatts_list,
                    nruns=nruns,
                    device=device,
                    p=p,
                )
                # Each slice has shape (len(Npatts_list), nruns).
                err_h_l2_results[i, l, j, k] = err_h_l2
                err_s_l2_results[i, l, j, k] = err_s_l2
                err_s_l1_results[i, l, j, k] = err_s_l1

In [None]:
def _binary_entropy_bits(prob):
    """Entropy in bits of a Bernoulli(prob) variable, taking 0 * log2(0) as 0."""
    prob = prob.clamp(0.0, 1.0)

    def _plogp(q):
        # Substitute 1 where q == 0 so log2 yields 0 there instead of -inf,
        # which would otherwise turn 0 * log2(0) into NaN.
        safe_q = torch.where(q > 0, q, torch.ones_like(q))
        return q * torch.log2(safe_q)

    return -_plogp(prob) - _plogp(1 - prob)


fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title("MI per inp bit vs num patts")
ax.set_xlabel("num patts")
ax.set_ylabel("MI per inp bit")
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid(which="both")


for k, N_h in enumerate(N_h_list):
    for i, smoothing_method in enumerate(smoothing_methods):
        for l, relu in enumerate(relu_options):
            for j, pseudoinverse_method in enumerate(pseudoinverse_methods):
                # Halved L1 error stored by the sweep, used here as the
                # per-entry error rate; shape (len(Npatts_list), nruns).
                bit_error_rate = err_s_l1_results[i, l, j, k]
                # One minus the binary entropy of the error rate.
                MI = 1 - _binary_entropy_bits(bit_error_rate)

                if pseudoinverse_method == "iterative_pseudoinverse":
                    # Label with the method name, not the global `layer`,
                    # which only holds whichever layer the sweep built last.
                    label = f"layer={pseudoinverse_method}, hidden_layer_factor=1, smoothing={smoothing_method} relu={relu}"
                elif pseudoinverse_method == "exact_pseudoinverse":
                    label = f"analytic pseudoinverse, smoothing={smoothing_method}, relu={relu}"
                else:
                    raise ValueError(
                        f"Unknown pseudoinverse method: {pseudoinverse_method!r}"
                    )

                # The sample standard deviation of a single run is NaN, so
                # error bars are drawn only when there are several runs.
                yerr = MI.std(dim=-1) if MI.shape[-1] > 1 else None
                ax.errorbar(
                    Npatts_list,
                    MI.mean(dim=-1),
                    yerr=yerr,
                    lw=2,
                    label=label,
                )

# Hard-coded reference curve plotted under the label "vectorhash".
# NOTE(review): its provenance is not shown in this notebook. It holds 18
# values, one per entry of Npatts_list as configured above, so it has to be
# regenerated whenever the pattern-count sweep changes.
vhash_y = [
    1.000000000000000000e00,
    1.000000000000000000e00,
    1.000000000000000000e00,
    5.988623183160277641e-01,
    3.667958255856974548e-01,
    2.624110436154711845e-01,
    2.042300801824028511e-01,
    1.672434617281599589e-01,
    1.414727808416358368e-01,
    1.225660944022268772e-01,
    1.082352629751366369e-01,
    9.674044810282866891e-02,
    8.747471863732059205e-02,
    7.977915334088647725e-02,
    7.342708729082536578e-02,
    6.793351052792084843e-02,
    6.324575644685004328e-02,
    5.912155577074185153e-02,
]


ax.errorbar(Npatts_list, vhash_y, lw=2, label="vectorhash")
ax.legend()

# plt.ylim(ymin=0, ymax=1)
plt.show()

In [None]:
# Write the plot next to the notebook; the file name records the initialization
# method, the first N_h value and the noise probability of this run.
fig.savefig(f"capacity_test_signed_{initialization_method}_N_h_{N_h_list[0]}_p_{p}.png")