In [None]:
import numpy as np
from numpy.random import randn

# --- Experiment configuration ------------------------------------------------
lambdas = [3, 4, 5]  # module periods
shapes = [(l, l) for l in lambdas]  # one (period, period) shape tuple per module
M = len(lambdas)  # num modules
Ng = np.sum(np.square(lambdas))  # num grid cells (sum of squared periods)
Npos = np.prod(lambdas) ** 2  # (product of periods) squared; 3600 for [3, 4, 5]
Ns = 2000  # num sensory cells (commented alternative: 84*84*3)
Np_lst = [400]  # np.arange(25, 425, 25)     # num place cells
pflip = 0.0  # measure of noise injected in s (prob of flipping if binary, gaussian noise if cts)
Niter = 1  # number of iterations for scaffold dynamics
nruns = 1
sparsity = 0  # Dummy param for older code, not used currently

# Number of patterns to train on, in steps of 200. The upper limit is derived
# from Npos (previously hard-coded as 3*3*4*4*5*5) so that it follows any
# change to `lambdas` instead of silently going stale.
Npatts_lst = np.arange(1, int(Npos), 200)
# With lambdas = [3, 4, 5] this is 1, 201, 401, ... 3401 (18 values)

smoothing_methods = ["argmax", "softmax", "polynomial"]
pseudoinverse_methods = ["exact", "iterative"]

In [None]:
from test_utils import capacity1
from data_utils import load_mnist_dataset, prepare_data

import torch

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA (previously hard-coded "cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"

# Alternative sensory input: MNIST images instead of random patterns.
# dataset = load_mnist_dataset()
# data, _ = prepare_data(dataset, num_imgs=3600, preprocess_sensory="false")
# data = data.numpy().T
# sign_output=False

# Random +1/-1 sensory patterns of shape (Ns, Npos).
data = np.sign(randn(Ns, Npos))
sign_output = True  # forwarded to capacity1 as `sign_output`

In [None]:
# All three result arrays share one layout:
# [smoothing method, pseudoinverse method, Np, Npatts, run].
results_shape = (
    len(smoothing_methods),
    len(pseudoinverse_methods),
    len(Np_lst),
    len(Npatts_lst),
    nruns,
)
err_h_l2_results = np.zeros(results_shape)
err_s_l2_results = np.zeros(results_shape)
err_s_l1_results = np.zeros(results_shape)

# Run capacity1 once per (smoothing, pseudoinverse) combination and store the
# three returned error tensors (moved to CPU and converted to NumPy) in the
# matching [i, j] slot.
for i, smoothing_method in enumerate(smoothing_methods):
    for j, pseudoinverse_method in enumerate(pseudoinverse_methods):
        slot = (i, j)
        err_h_l2, err_s_l2, err_s_l1 = capacity1(
            shapes,
            Np_lst,
            Npatts_lst,
            nruns,
            data,
            device,
            pseudoinverse_method=pseudoinverse_method,
            smoothing_method=smoothing_method,
            sign_output=sign_output,
        )
        err_h_l2_results[slot] = err_h_l2.cpu().numpy()
        err_s_l2_results[slot] = err_s_l2.cpu().numpy()
        err_s_l1_results[slot] = err_s_l1.cpu().numpy()


# Place states chosen to be random vectors with same sparsity as base case (teal curves in Fig. 3)
# err_pc, err_gc, err_sens, err_senscup, err_sensl1 = capacity(senstrans_gs_random_sparse_p, lambdas, Ng, Np_lst, pflip, Niter, Npos,
# gbook, Npatts_lst, nruns, Ns, sbook, sparsity)
# Assuming linear hippocampal activations
# err_pc, err_gc, err_sens, err_senscup, err_sensl1 = capacity(senstrans_gs_linear_p, lambdas, Ng, Np_lst, pflip, Niter, Npos,
# gbook, Npatts_lst, nruns, Ns, sbook, sparsity)

# Use gbook as a spiraling outward + linear activation (for SI Fig. S13)
# err_pc, err_gc, err_sens, err_senscup, err_sensl1 = capacity(senstrans_gs_linear_p_spiral, lambdas, Ng, Np_lst, pflip, Niter, Npos,
# gbook, Npatts_lst, nruns, Ns, sbook, sparsity)

In [None]:
import matplotlib.pyplot as plt
import torch

Npatts = np.array(nruns * [Npatts_lst])  # Npatts_lst repeated nruns times
Npatts = Npatts.T  # NOTE: not used by the plot below; kept so the name stays defined

# One curve per (smoothing, pseudoinverse) combination: convert the normalised
# L1 sensory error into an overlap m, then into mutual information per bit
# MI = 1 - H((1 + m) / 2), with H the binary entropy in bits.
for i, smoothing_method in enumerate(smoothing_methods):
    for j, pseudoinverse_method in enumerate(pseudoinverse_methods):
        normlizd_l1 = err_s_l1_results[i, j]
        m = 1 - (2 * normlizd_l1)
        a = np.abs((1 + m) / 2)
        b = np.abs((1 - m) / 2)
        # Use the convention 0 * log2(0) = 0 for BOTH terms. The previous code
        # only masked m == 1 (b == 0), so m == -1 (a == 0) produced NaN.
        with np.errstate(divide="ignore", invalid="ignore"):
            a_term = np.where(a > 0, a * np.log2(a), 0.0)
            b_term = np.where(b > 0, b * np.log2(b), 0.0)
        S = -(a_term + b_term)
        MI = 1 - S

        if pseudoinverse_method == "iterative":
            label = f"iterative pseudoinverse (ε_hs = 0.1, ε_sh=0.1, hidden_layer_factor=1, smoothing={smoothing_method})"
        elif pseudoinverse_method == "exact":
            label = f"analytic pseudoinverse, smoothing={smoothing_method}"
        else:
            # Keep `label` defined if more methods are added to the sweep.
            label = f"{pseudoinverse_method} pseudoinverse, smoothing={smoothing_method}"
        # MI[0] selects the first entry of Np_lst; mean/std are taken over runs.
        plt.errorbar(
            Npatts_lst, MI[0].mean(axis=1), yerr=MI[0].std(axis=1), lw=2, label=label
        )

# Hard-coded reference curve plotted against Npatts_lst, so it must contain
# exactly one value per entry of Npatts_lst.
# NOTE(review): presumably precomputed VectorHash MI values — confirm provenance.
vhash_y = [
    1.000000000000000000e00,
    1.000000000000000000e00,
    1.000000000000000000e00,
    5.988623183160277641e-01,
    3.667958255856974548e-01,
    2.624110436154711845e-01,
    2.042300801824028511e-01,
    1.672434617281599589e-01,
    1.414727808416358368e-01,
    1.225660944022268772e-01,
    1.082352629751366369e-01,
    9.674044810282866891e-02,
    8.747471863732059205e-02,
    7.977915334088647725e-02,
    7.342708729082536578e-02,
    6.793351052792084843e-02,
    6.324575644685004328e-02,
    5.912155577074185153e-02,
]


plt.errorbar(Npatts_lst, vhash_y, lw=2, label="vectorhash")
plt.xscale("log")
plt.yscale("log")
plt.legend()
# NOTE(review): sparsity and relu_theta in the title are hard-coded text, not
# read from any variable in this notebook — confirm they match the run.
plt.title(
    f"MI per inp bit vs num patts (N_h={Np_lst[0]}, sparsity=0.6, relu_theta=0.5)"
)

plt.xlim(xmin=100)
# plt.ylim(ymin=0, ymax=1)
plt.ylabel("MI per inp bit")
plt.xlabel("num patts")
plt.grid(which="both")
plt.show()

In [None]:
###Baselines
import numpy as np
import scipy.sparse as sparse
import matplotlib.pyplot as plt
from tqdm import tqdm as tqdm


def cap(W,bound):
    """Clip every entry of ``W`` to the closed interval ``[-bound, bound]``.

    Args:
        W: weight matrix; a torch tensor or anything ``torch.as_tensor``
            accepts (e.g. a NumPy array).
        bound: non-negative clipping magnitude.

    Returns:
        A new torch tensor with the same shape as ``W``; the input is not
        modified.
    """
    # torch.clamp stays on W's device. The previous torch.where/torch.ones
    # version always built its constants on the CPU and rejected NumPy input.
    return torch.clamp(torch.as_tensor(W), min=-bound, max=bound)

def corrupt_p(codebook,p=0.1,booktype='-11'):
    """Return a noisy copy of ``codebook``.

    Args:
        codebook: torch tensor (or array-like) of patterns.
        p: for ``'-11'`` and ``'01'`` the independent per-entry flip
            probability; for ``'cts'`` the standard deviation of the added
            Gaussian noise.
        booktype: ``'-11'`` (entries are +1/-1, a flip negates the entry),
            ``'01'`` (entries are 0/1, a flip maps x to |x - 1|) or ``'cts'``
            (continuous values, additive noise).

    Returns:
        A tensor shaped like ``codebook``. For an unknown ``booktype`` a
        message is printed and ``0`` is returned.
    """
    codebook = torch.as_tensor(codebook)

    if booktype=='cts':
        # torch.random has no `normal`; torch.randn draws standard normal samples.
        noise = torch.randn(codebook.shape, device=codebook.device)
        return codebook + noise*p

    if booktype not in ('-11', '01'):
        print("codebook should be -11; 01; or cts")
        return 0

    # torch.random has no `uniform`; torch.rand samples from [0, 1). A strict
    # `< p` comparison gives exactly "flip with probability p" and, unlike
    # sign(u - p), can never produce a 0 that would zero an entry out.
    flip = torch.rand(codebook.shape, device=codebook.device) < p
    if booktype=='-11':
        return torch.where(flip, -codebook, codebook)
    # booktype == '01'
    return torch.abs(codebook - flip.to(codebook.dtype))


def get_weights(patterns,connectivity,learning_rule=None,activity=None,weight_bound=None):
    """Build the recurrent weight matrix that stores ``patterns``.

    Args:
        patterns: pattern matrix. For ``connectivity='standard'`` it has shape
            (num_nodes, num_patterns) and may be a torch tensor or anything
            ``torch.as_tensor`` accepts. For sparse connectivity it is indexed
            by node (``patterns[i]``).
        connectivity: the string ``'standard'`` for a fully connected network,
            otherwise a ``scipy.sparse.lil_matrix`` whose non-zero structure
            lists the allowed connections.
        learning_rule: ``'hebbian'``, ``'sparsehebbian'``, ``'pinv'`` or
            ``'bounded_hebbian'``. Defaults to the module-level ``learning``.
        activity: value subtracted from the patterns for ``'sparsehebbian'``.
            Defaults to the module-level ``sparsity``.
        weight_bound: clipping magnitude for ``'bounded_hebbian'``. Defaults
            to the module-level ``bound``.

    Returns:
        ``'standard'``: a (num_nodes, num_nodes) torch tensor with zero
        diagonal. Sparse connectivity: a ``lil_matrix`` with zero diagonal.

    Raises:
        ValueError: for an unknown connectivity string or learning rule.
    """
    if isinstance(connectivity, str):
        # Compare by value; `is 'standard'` only worked through string interning.
        if connectivity != 'standard':
            raise ValueError(f"unknown connectivity {connectivity!r}")
        if learning_rule is None:
            learning_rule = learning  # module-level setting from the config cell

        # Use the argument itself rather than a global that happens to be
        # called `patts`, and keep the computation in torch end to end.
        patts = torch.as_tensor(patterns)
        num_nodes = patts.shape[0]
        num_patts = patts.shape[1]

        if learning_rule == 'hebbian':
            W = patts @ patts.T
        elif learning_rule == 'sparsehebbian':
            prob = sparsity if activity is None else activity
            W = (1/num_nodes) * (patts - prob) @ (patts.T - prob)
        elif learning_rule == 'pinv':
            # pinv needs a floating point input.
            patts_f = patts if patts.is_floating_point() else patts.float()
            W = patts_f @ torch.linalg.pinv(patts_f)
        elif learning_rule == 'bounded_hebbian':
            clip = bound if weight_bound is None else weight_bound
            W = torch.zeros((num_nodes, num_nodes), device=patts.device)
            # Add the patterns one at a time, clipping after every update.
            for i in range(num_patts):
                Wtmp = torch.outer(patts[:, i], patts[:, i]) / num_nodes ** 0.5
                W = cap(Wtmp + W, clip)
        else:
            # Previously an unknown rule fell through with W unbound.
            raise ValueError(f"unknown learning rule {learning_rule!r}")
        # No self-connections.
        W = W - torch.diag(torch.diag(W))
    else:
        N = connectivity.shape[0]
        W = sparse.lil_matrix(connectivity.shape)
        # Hebbian weights restricted to the connections listed per row.
        for i in range(N):
            for j in connectivity.rows[i]:
                W[i,j] = np.dot(patterns[i],patterns[j])
        W.setdiag(0)
    return W


def entropy(inlist):
    """Binary entropy, in bits, of each probability in ``inlist``.

    For an entry ``x`` the value is ``-(x*log2(x) + (1-x)*log2(1-x))``;
    entries equal to 0 or 1 map to 0. Returns a 1-D float array with one
    value per entry of ``inlist``.
    """
    def bits(prob):
        # Degenerate cases are handled explicitly so log2(0) is never evaluated.
        if prob == 0 or prob == 1:
            return 0
        return -1 * ( prob*np.log2(prob) + (1-prob)*np.log2(1-prob) )

    return np.fromiter((bits(prob) for prob in inlist), dtype=float, count=len(inlist))

In [None]:
# --- Hopfield baseline: overlap and mutual information vs number of patterns ---
nruns=1
iterations=100
N = 708  # number of units (rows of the pattern matrix)
corrupt_fraction = 0.0
Npatts_list = np.arange(1,800,10)
connectivity='standard' # Standard fully connected Hopfield network. For sparse connectivity use the next cell
# learning can be 'hebbian', 'bounded_hebbian', 'pinv', or 'sparsehebbian' for sparse hopfield network
learning='bounded_hebbian'
bound=0.3  #Use bound param if learning='bounded_hebbian'

init_overlap = torch.zeros((nruns,*Npatts_list.shape))
final_overlap = torch.zeros((nruns,*Npatts_list.shape))
MI_hc = torch.zeros((nruns,*Npatts_list.shape))


def _hopfield_step(W, state, theta=None):
    """One synchronous update: sign(W @ state), or a thresholded 0/1 update when theta is given."""
    drive = W @ state
    if theta is None:
        return torch.sign(drive)
    return (torch.sign(drive - theta) + 1) / 2


max_patts = int(Npatts_list.max())

for runidx in range(nruns):
    print("runidx = "+str(runidx))

    if learning == 'sparsehebbian':
        # sparse hopfield 0/1 code
        sparsity = 0.2
        # torch.random has no `rand`; cast to float so W @ patterns is well-typed.
        patterns = (torch.rand(N, max_patts) > (1-sparsity)).float()
        theta = 0.05 #0.04 #0
    else:
        # torch.random has no `normal`; torch.randn draws standard normal samples.
        patterns = torch.sign(torch.randn(N, max_patts))
        theta = None

    for idx,Npatts in enumerate(tqdm(Npatts_list)):
        #print(Npatts)
        patts = patterns[:,:Npatts]
        # NOTE(review): corrupt_fraction is never applied here — the cue is the
        # uncorrupted pattern set, whatever the value reported in the plot labels.
        cor_patts = patterns[:,:Npatts]
        # get_weights may hand back a different dtype; align it with the patterns.
        W = torch.as_tensor(get_weights(patts,connectivity)).to(patts.dtype)

        # Overlap after a single update step.
        rep = _hopfield_step(W, cor_patts, theta)
        init_overlap[runidx,idx] = ((rep * patts).sum(dim=0) / N).mean()

        # Iterate until the state stops changing or the iteration budget runs out.
        for ite in range(iterations-1):
            new_rep = _hopfield_step(W, rep, theta)
            converged = torch.equal(new_rep, rep)
            rep = new_rep
            if converged:
                # print("converged at "+str(ite))
                break
        err = (rep * patts).sum(dim=0) / N  # per-pattern overlap
        overlap = err.mean()
        final_overlap[runidx,idx] = overlap #err

        if learning=='sparsehebbian':
            q = (rep.abs().sum(dim=0) / N).numpy()  # sparse hopfield
            m = err.numpy()
            p = (patts.sum(dim=0) / patts.shape[0]).numpy()
            # Clip so rounding can never push a probability outside [0, 1],
            # which would make entropy() return NaN.
            P1e = np.clip(1 - (m/p), 0.0, 1.0)
            P0e = np.clip((q-m)/(1-p), 0.0, 1.0)
            MI_hc[runidx,idx] = float(np.average( entropy(q) - ( p*entropy(P1e) + (1-p)*entropy(P0e) ) ))


# print(init_overlap)
# print(final_overlap)

results_dir = "continuum_results"
# filename = f"sparseconnhopfield__mutualinfo_N={N}_noise={corrupt_fraction}_gamma={gamma}_iter={iterations}_nruns={nruns}"
filename = f"stdhopfield__mutualinfo_N={N}_noise={corrupt_fraction}_iter={iterations}_nruns={nruns}"
# filename = f"pinvhopfield__mutualinfo_N={N}_noise={corrupt_fraction}_iter={iterations}_nruns={nruns}"
# filename = f"sparsehopfield__mutualinfo_N={N}_noise={corrupt_fraction}_p={sparsity}_iter={iterations}_nruns={nruns}"
# filename = f"boundedhopfield__mutualinfo_N={N}_noise={corrupt_fraction}_bound={bound}_iter={iterations}_nruns={nruns}"


fig1 = plt.figure(1)
plt.plot(Npatts_list,init_overlap.mean(dim=0).numpy(), label='single, corrupt='+str(corrupt_fraction));
plt.plot(Npatts_list,final_overlap.mean(dim=0).numpy(), label='final, corrupt='+str(corrupt_fraction));
plt.legend()
plt.xlabel('Number of patterns')
plt.ylabel("Overlap");
plt.title(r"N = "+str(N)+", $W$");
plt.show()
# exit()
# fig1.savefig(f"{results_dir}/Overlap_{filename}.png")

if learning=='sparsehebbian':
    print("MI already calculated in loop")
else:
    # MI = 1 - H((1 + m) / 2) for +1/-1 patterns, with H the binary entropy in bits.
    m = final_overlap.numpy()
    a = (1+m)/2
    b = (1-m)/2

    # Convention 0 * log2(0) = 0 for both terms, so m == 1 and m == -1 give
    # S = 0 instead of NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        a_term = np.where(a > 0, a * np.log2(a), 0.0)
        b_term = np.where(b > 0, b * np.log2(b), 0.0)
    S = -(a_term + b_term)

    MI_hc = 1 - S


# Separate figure number so this plot does not reuse figure 1.
fig2 = plt.figure(2)
# NumPy statistics: with a single run torch's default std would be NaN.
MI_plot = np.asarray(MI_hc)
plt.errorbar(Npatts_list,MI_plot.mean(axis=0),yerr=MI_plot.std(axis=0), label='final, corrupt='+str(corrupt_fraction)); #plt.xscale('log'); plt.yscale('log');
plt.legend()
plt.xlabel('Number of patterns')
plt.ylabel("MI");
plt.title(r"N = "+str(N)+", $W$");
plt.show()
# fig2.savefig(f"{results_dir}/MI_{filename}.png")

# NOTE(review): this rebinds `data`, which an earlier cell uses for the sensory
# pattern matrix; re-run that cell before re-running the capacity sweep.
data = {
    "N": N,
    "init_overlap": init_overlap,
    "m": final_overlap,
    "MI": MI_hc,
    "Npatts_list": Npatts_list,
    "noise": corrupt_fraction,
    # "q": q  #needed for sparse hebbian
    # "bound": bound #needed for bounded hopfield
}
# write_pkl(f"{results_dir}/{filename}", data)