In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy import stats

plt.style.use("./src/presentation.mplstyle")
from vectorhash_convered import *
import torch

In [None]:
# Grid-module shapes: each (λ, λ) tuple is one 2D module with period λ
# (linear dimension).
shapes = [(3, 3), (4, 4), (5, 5)]
# Np_lst = np.arange(25, 350, 250)  # alternative sweep
Np_lst = [25, 350]  # numbers of place cells to sweep (x-axis of the capacity plot)
pflip = 0.25  # parameter controlling the injected noise
Niter = 2  # number of iterations for scaffold dynamics
nruns = 1  # number of runs to average the results over

# Seed the global numpy and torch RNGs so results are reproducible under
# Restart & Run All (assumes capacity_gcpc_vectorized draws from these RNGs;
# not verifiable from this notebook).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Npos = product of every module dimension (3*3 * 4*4 * 5*5 = 3600 for the
# default shapes). np.prod over the nested list already multiplies all entries,
# so a single reduction suffices.
Npos = int(np.prod(shapes))
Npatts_lst = np.arange(1, Npos + 1, 100)  # pattern counts to test: 1, 101, 201, ...

In [None]:
# Run the capacity sweep (capacity_gcpc_vectorized comes from vectorhash_convered).
# Downstream cells index err_gcpc as [Np index, Npatts index, run] and
# num_correct as [Np index][Npatts index, run]; shapes inferred from that use.
err_gcpc, num_correct = capacity_gcpc_vectorized(
    # geometry of the sweep
    shapes=shapes,
    Npos=Npos,
    Np_lst=Np_lst,
    Npatts_lst=Npatts_lst,
    # noise and dynamics
    pflip=pflip,
    Niter=Niter,
    # averaging / mode (set "yes" for the generalization plot at the end)
    nruns=nruns,
    test_generalization="no",
)

In [None]:
# Dump both result arrays, one after the other.
print(err_gcpc, num_correct, sep="\n")

In [None]:
## Compute capacity across values of Np. This cell assumes the standard capacity
## sweep (test_generalization="no"); it does not apply when generalization is
## being tested -- use the cell below in that case.

# Error tolerance: a tiny nonzero value, safely above floating-point round-off.
errthresh = 1e-3


def add_labels(ax, title, xlabel, ylabel):
    """Set the title and axis labels on ``ax``, add a legend, and return ``ax``."""
    for set_text, text in (
        (ax.set_title, title),
        (ax.set_xlabel, xlabel),
        (ax.set_ylabel, ylabel),
    ):
        set_text(text)
    ax.legend(loc="best")
    return ax


def capacity_from_validity(valid_row, npatts_list, full_capacity):
    """Conservative capacity for a single (Np, run) sweep.

    Parameters
    ----------
    valid_row : 1-D bool tensor/array, one entry per value in ``npatts_list``;
        True where the recall error was within ``errthresh``.
    npatts_list : pattern counts that were tested (assumed increasing, as built
        by ``np.arange``).
    full_capacity : value reported when no tested count failed.

    Returns
    -------
    The largest tested pattern count preceding the first failure; 0 if the very
    first tested count already failed; ``full_capacity`` if nothing failed.
    """
    failures = torch.argwhere(~torch.as_tensor(valid_row)).flatten()
    if failures.numel() == 0:
        # Note: this reports full_capacity even though the largest tested count
        # is npatts_list[-1].
        return full_capacity
    first_fail = int(failures[0])
    if first_fail == 0:
        # Even the smallest tested count failed: report 0 rather than
        # npatts_list[0], which would overstate the capacity.
        return 0
    return npatts_list[first_fail - 1]


capacity = np.full((len(Np_lst), nruns), -1.0)  # -1 marks "not computed"
# True where recall error is within tolerance; as_tensor accepts numpy or torch.
valid = torch.as_tensor(err_gcpc <= errthresh)

for i_np in range(len(Np_lst)):
    for r in range(nruns):
        capacity[i_np, r] = capacity_from_validity(valid[i_np, :, r], Npatts_lst, Npos)

avg_cap = np.mean(capacity, axis=1)  # mean capacity over runs
# std_cap = stats.sem(capacity, axis=1)  # alternative: standard error over runs
std_cap = np.std(capacity, axis=1)  # std dev of capacity over runs


# Derive title annotations from the configuration rather than hard-coding them,
# so the figure stays accurate when `shapes` changes.
# Grid-cell count = total sites over all modules (assumes one grid cell per
# lattice site; gives the previously hard-coded 50 for the default shapes).
n_grid_cells = sum(int(np.prod(s)) for s in shapes)
# One period per square module; fall back to the full shape for non-square ones.
grid_periods = [s[0] if len(set(s)) == 1 else tuple(s) for s in shapes]

fig, ax = plt.subplots()
ax.errorbar(Np_lst, avg_cap, yerr=std_cap, fmt="ko--", label="2D grid code network")
add_labels(
    ax,
    f"Grid cells={n_grid_cells}; Grid periods={grid_periods}; errthresh={errthresh};",
    "number of place cells",
    "number of patterns",
)
# savefig(fig, ax, f"{results_dir}/{filename}")
plt.show()

In [None]:
# Generalization plot: mean ± std of num_correct versus the number of trained
# patterns, one curve per Np. Only meaningful if capacity_gcpc_vectorized was
# run with test_generalization="yes" -- make sure to pass that when doing this!

fig, ax = plt.subplots()
for i, n_place in enumerate(Np_lst):
    # Rows of num_correct[i] are pattern counts, columns are runs (inferred from
    # the axis=1 reductions).
    counts = num_correct[i]
    ax.errorbar(
        x=Npatts_lst,
        y=counts.mean(axis=1),
        yerr=counts.std(axis=1),  # each curve uses its own Np's spread
        label="Np=" + str(n_place),
        marker="o",
        mew=0,
    )
ax.set_xlabel("num of trained patterns")
ax.set_ylabel("num of generated fixed points")
ax.legend(loc="best")  # without this the per-Np labels are never shown
plt.show()