In [1]:
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import poisson, nbinom
import seaborn as sns
import pandas as pd

#from Model_xx import Model
from Model_less_weird import Model
from PlaceInputs import PlaceInputs
from utils import *
import configs

# Parameters and Initialization

In [132]:
# Parameters
N_inp = 5000      # number of input units; passed to PlaceInputs and Model
N_bar = 5000      # passed to Model as its second size argument (presumably network size — confirm)
num_states = 100  # number of discrete states/locations; passed to PlaceInputs, Model and distance()
steps = 100       # passed to Model(steps=...) and used to index the activity traces

# Seed NumPy's global RNG so the stochastic parts of the notebook are
# reproducible under Restart & Run All: scipy.stats' `rvs` (used by the
# negative-binomial resampling) draws from this RNG when no random_state is
# given, as does any global-RNG use inside the project modules (if any).
SEED = 0
np.random.seed(SEED)

In [3]:
# Build the place-input patterns first, then the network that will receive them.
place_input_source = PlaceInputs(N_inp, num_states)
place_inputs = place_input_source.get_inputs()
model = Model(N_inp, N_bar, num_states, steps=steps)

In [4]:
# Five cache sites, evenly spaced 20 states apart: 0, 20, 40, 60, 80.
cache_states = list(range(0, 100, 20))

# Run Task (with learning)

In [33]:
# Cache at each site in turn, applying the learning update after each cache,
# and record network activity both during caching and during a recall run
# that follows it.
model.reset()
cache_acts = []                # acts from each caching run, one entry per cache site
cache_acts_over_time = []      # per-step activity from each caching run (later indexed [time, location, :])
retrieval_acts_over_time = []  # per-step activity from the recall run after each cache

for cache_state in cache_states:
    print("Now I just cached at location:", cache_state)
    # Recurrent pass over all place inputs.
    # NOTE(review): n_zero_input=20 is interpreted by Model.run_recurrent — document its meaning.
    preacts, acts, _, acts_over_time = model.run_recurrent(place_inputs, n_zero_input=20)
    # Learning step: uses only the cached location's input, activity and pre-activation.
    model.update(place_inputs[cache_state], acts[cache_state], preacts[cache_state])
    # Store copies so saved results are independent of any later reuse of these
    # arrays. Note they are taken after model.update has run.
    cache_acts.append(acts.copy())
    cache_acts_over_time.append(acts_over_time.copy())
    # Recall run with the updated model (acts_over_time is rebound here).
    # NOTE(review): the positional 0.1 is a Model.run_recall parameter — name/document it.
    _, _, _, acts_over_time = model.run_recall(0.1, place_inputs)
    retrieval_acts_over_time.append(acts_over_time.copy())

Now I just cached at location: 0
Now I just cached at location: 20
Now I just cached at location: 40
Now I just cached at location: 60
Now I just cached at location: 80


In [34]:
# "Visit" activity: one non-recurrent pass over every place input.
_, place_acts, _, _ = model.run_nonrecurrent(place_inputs)
# Per-unit mean and spread of visit activity, taken across axis 0.
place_mean, place_std = (
    np.mean(place_acts, axis=0).flatten(),
    np.std(place_acts, axis=0).flatten(),
)

# Define cache and retrieval activity at a single time step of each run (not a mean over a time window)

In [140]:
# Cache activity for each cache site: the population vector at time index
# steps + model.seed_steps - 1 of the run in which that site was cached.
# This is a single time step, not an average over a window.
avg_cache_acts = np.array([
    cache_acts_over_time[run][steps + model.seed_steps - 1, site, :]
    for run, site in enumerate(cache_states)
])

In [141]:
# Mean cache population vector across cache sites (average over axis 0).
cache_mean = np.mean(avg_cache_acts, axis=0)

In [142]:
# Retrieval activity for each cache site: the population vector at time index
# steps - 1 of the recall run that followed caching at that site.
# This is a single time step, not an average over a window.
avg_retrieval_acts = np.array([
    retrieval_acts_over_time[run][steps - 1, site, :]
    for run, site in enumerate(cache_states)
])

In [143]:
def nb(mu, std_scaling=1.0, mu_scaling=0.75, shift=0.0, random_state=None):
    """Draw negative-binomial counts around a vector of firing rates.

    Each draw has target mean ``rate = mu * mu_scaling + 1e-8 + shift`` and
    target variance ``std_scaling**2 * rate``. With ``std_scaling == 1`` the
    draws are numerically Poisson. Larger values add extra-Poisson variability.

    Parameters
    ----------
    mu : array_like
        Firing rates, one per unit. Lists are accepted. Values are cast to
        float64 so the small numerical nudges below are not lost to rounding.
    std_scaling : float, default 1.0
        Ratio of the target standard deviation to the Poisson value
        ``sqrt(rate)``. Must be >= 1, because a negative binomial cannot be
        under-dispersed.
    mu_scaling : float, default 0.75
        Multiplicative gain applied to ``mu`` before sampling.
    shift : float, default 0.0
        Additive offset applied after ``mu_scaling``.
    random_state : None, int, numpy.random.Generator or RandomState, optional
        Forwarded to ``scipy.stats.nbinom.rvs``. ``None`` uses NumPy's global RNG.

    Returns
    -------
    numpy.ndarray of float
        One count per element of ``mu``.

    Raises
    ------
    ValueError
        If ``std_scaling < 1`` or any effective rate is not positive.
    """
    rate = np.asarray(mu, dtype=float) * mu_scaling + 1E-8 + shift
    if std_scaling < 1:
        raise ValueError(
            f"std_scaling must be >= 1 (a negative binomial cannot be "
            f"under-dispersed), got {std_scaling}"
        )
    if not np.all(rate > 0):  # also rejects NaN rates
        raise ValueError("effective rate mu * mu_scaling + shift must be positive")

    # The +1e-8 keeps var strictly above rate when std_scaling == 1. That keeps
    # n finite and positive, which is the Poisson limit of the negative binomial.
    std = std_scaling * np.sqrt(rate) + 1E-8
    var = std ** 2
    # Moment matching: NB(n, p) has mean n(1-p)/p = rate and variance n(1-p)/p**2 = var.
    n = rate ** 2 / (var - rate)
    p = rate / var
    return nbinom.rvs(n, p, random_state=random_state).astype(float)


def nb_corr(a, b, **nb_kwargs):
    """Pearson correlation between independent NB resamplings of ``a`` and ``b``.

    Keyword arguments (e.g. ``std_scaling``, ``random_state``) are forwarded to
    both ``nb`` calls. When passing ``random_state``, use a shared Generator or
    RandomState instance rather than an int. With an int, both draws reuse the
    same seed and their noise is correlated. Returns NaN (with a NumPy warning)
    if either resampled vector is constant.
    """
    nb_a = nb(a, **nb_kwargs)
    nb_b = nb(b, **nb_kwargs)
    return np.corrcoef(nb_a, nb_b)[0, 1]

# Fig 3ab, Fig 4: cache v. visit correlation

In [144]:
# For every pair of cache sites (i <= j), correlate noise-resampled population
# vectors: cache vs. cache, cache (site i) vs. visit (site j), and visit vs. visit.
# Each appended row records the pair's distance and one correlation sample.
std_scaling = 1      # NB dispersion passed to nb(): variance = std_scaling**2 * mean (>= 1)
n_noise_samples = 3  # independent NB resamplings per pair


def _noisy_corr(a, b):
    """Correlation of independent NB resamplings of a and b at the current std_scaling."""
    return np.corrcoef(nb(a, std_scaling=std_scaling), nb(b, std_scaling=std_scaling))[0, 1]


cache_cache_corr = {'Distance': [], 'Correlation': []}
cache_visit_corr = {'Distance': [], 'Correlation': []}
visit_visit_corr = {'Distance': [], 'Correlation': []}

for i in range(len(cache_states)):
    for j in range(i, len(cache_states)):
        _distance = distance(cache_states[i], cache_states[j], num_states)

        # nb() does not modify its input, so no defensive copies are needed.
        cache_act_i = avg_cache_acts[i]
        cache_act_j = avg_cache_acts[j]
        place_act_i = place_acts[cache_states[i]]
        place_act_j = place_acts[cache_states[j]]

        for _ in range(n_noise_samples):
            c_c_corr = _noisy_corr(cache_act_i, cache_act_j)
            c_v_corr = _noisy_corr(cache_act_i, place_act_j)
            v_v_corr = _noisy_corr(place_act_i, place_act_j)

            cache_cache_corr['Distance'].append(_distance)
            cache_cache_corr['Correlation'].append(c_c_corr)

            cache_visit_corr['Distance'].append(_distance)
            cache_visit_corr['Correlation'].append(c_v_corr)

            visit_visit_corr['Distance'].append(_distance)
            visit_visit_corr['Correlation'].append(v_v_corr)

# One cache-vs-retrieval correlation per cache site (a single resampling each).
cache_retrieval_corr = [
    _noisy_corr(avg_cache_acts[k], avg_retrieval_acts[k])
    for k in range(len(cache_states))
]

In [145]:
# cc: cache-vs-cache correlations at Distance == 0

In [146]:
# Long-form table (one row per pair and noise sample); show the Distance == 0 rows.
cache_cache_corr = pd.DataFrame(cache_cache_corr)
cache_cache_corr.query("Distance == 0")

Unnamed: 0,Distance,Correlation
0,0,0.440366
1,0,0.534453
2,0,0.468233
15,0,0.476849
16,0,0.524313
17,0,0.473271
27,0,0.591039
28,0,0.541148
29,0,0.542982
36,0,0.467484


In [147]:
# Column means over the Distance == 0 rows. Uses DataFrame.mean() (column-wise)
# instead of np.mean(df): newer pandas deprecates/changes np.mean on a DataFrame
# to reduce over both axes, which would return one scalar that averages in the
# Distance column.
cache_cache_corr[cache_cache_corr['Distance'] == 0].mean()

Distance       0.000000
Correlation    0.507731
dtype: float64

In [148]:
# cr: cache-vs-retrieval correlation, one value per cache site

In [149]:
# Display the per-site cache-vs-retrieval correlations (one value per cache site).
cache_retrieval_corr

[0.5170703521332835,
 0.4365240387214454,
 0.5382730294700442,
 0.43743560525738023,
 0.49196558082200437]

In [150]:
# Average cache-vs-retrieval correlation across cache sites.
np.asarray(cache_retrieval_corr).mean()

0.4842537212808315

In [151]:
# vv: visit-vs-visit correlations at Distance == 0

In [152]:
# Tabulate the visit-vs-visit correlations and show the Distance == 0 rows.
visit_visit_corr = pd.DataFrame(visit_visit_corr)
visit_visit_corr.loc[lambda frame: frame['Distance'] == 0]

Unnamed: 0,Distance,Correlation
0,0,0.185426
1,0,0.174673
2,0,0.193182
15,0,0.216849
16,0,0.18325
17,0,0.1779
27,0,0.212617
28,0,0.164877
29,0,0.167711
36,0,0.213199


In [153]:
# Column means over the Distance == 0 rows. Uses DataFrame.mean() (column-wise)
# instead of np.mean(df): newer pandas deprecates/changes np.mean on a DataFrame
# to reduce over both axes, which would return one scalar that averages in the
# Distance column.
visit_visit_corr[visit_visit_corr['Distance'] == 0].mean()

Distance       0.000000
Correlation    0.188578
dtype: float64