## Reproduce Figure 2 of Litwin-Kumar & Doiron (2012)

Use data from the simulated balanced networks (uniform and clustered connectivity) to reproduce Figure 2.

In [None]:
import sys 
sys.path.append('../code/likelihoodfree-models/')
from lfmods.balanced_network_utils import *
%matplotlib inline
import time

In [None]:
# Directory holding the pickled simulation results.
folder = '/Users/Jan/Dropbox/Master/mackelab/code/balanced_clustered_network/data/'

# Result files: uniformly connected network and clustered network.
fn_uni = '15009991276ree10_dur20_brain1.p'
fn_clus = '150099931884ree25_dur20_brain1.p'

# Output figure name "<timestamp>_figure2_dur<dur>" with every '.' stripped.
# Note: time_str holds the float returned by time.time(), not a string.
# NOTE(review): `dur` only feeds the output filename; the data files are
# tagged "dur20" -- confirm the intended value.
time_str = time.time()
dur = '2s'
figure_filename = '{}_figure2_dur{}'.format(time_str, dur)
figure_filename = figure_filename.replace('.', '')

# Load the uniform network first, then the clustered one.
d_uni, d_clus = [load_data(fn, folder) for fn in (fn_uni, fn_clus)]

In [None]:
# assume a single long trial, use only E neurons
# extract parameters 

# set time window for ff  in sec 
window_length_ff = 0.1  
# and for rho 
window_length_rho = 0.05 

# time windows for ff ocer time windows 
time_windows = np.linspace(0.025, 0.2, 8)

# get the params 
params = d_uni['params']
n_rounds = params['n_rounds']
n_trials = params['n_trials']
simulation_time = np.asarray(params['simulation_time'])  # remove unit
NE = params['NE']
NI = params['NI']
n_clusters = 50

# get the total number of neurons pairs that are in a cluster 
Nc = NE / n_clusters
n_pairs_in_cluster = n_clusters * Nc**2
n_random_pairs = int(np.sqrt(n_pairs_in_cluster))

time_offset = 1.  # in sec
delta_t = simulation_time - time_offset  # in sec
recordings_length = delta_t

In [None]:
# Summary statistics per network, keyed by connectivity type.
stat_dict = {'uniform': {}, 'clustered': {}}
keys = ['uniform', 'clustered']

for key, data in zip(keys, (d_uni, d_clus)):
    stats = stat_dict[key]

    # Spike times of the E population in the first trial.
    spiketimedict = data['trial0']['spikes_E']

    # Firing rates: spike counts over the analysed period / its duration.
    counts_total = get_spikecounts_fixed_time_window(spiketimedict, time_offset, delta_t)
    stats['rates'] = counts_total / delta_t

    # Fano factors of spike counts in windows of window_length_ff.
    counts_ff = get_spike_counts_over_time_windows(
        spiketimedict, time_offset, delta_t, window_length=window_length_ff)
    stats['ff'] = calculate_fano_factor(counts_ff)

    # Correlations over all pairs, counts in windows of window_length_rho.
    counts_rho = get_spike_counts_over_time_windows(
        spiketimedict, time_offset, delta_t, window_length=window_length_rho)
    stats['rho'] = calculate_correlation_matrix(counts_rho)

    # Pair correlations depend on the connectivity: the uniform network has
    # no clusters, so a shuffled subset of n_random_pairs rows is used
    # instead (np.random.shuffle permutes along the first axis; unseeded).
    if key == 'uniform':
        shuffled = counts_rho.copy()
        np.random.shuffle(shuffled)
        stats['rho_pairs'] = calculate_correlation_matrix(
            shuffled[np.newaxis, :n_random_pairs, :])
    else:
        stats['rho_pairs'] = calculate_clusterpair_correlations(
            counts_rho, n_clusters, Nc)

    # Mean Fano factor for each window length in time_windows.
    stats['ff_over_windows'] = np.array(
        [calculate_fano_factor(
            get_spike_counts_over_time_windows(
                spiketimedict, time_offset, delta_t, window_length=window)).mean()
         for window in time_windows],
        dtype=float)

In [None]:
def _hist_with_mean(subplot, values, hist_range, color, title, xlabel, **hist_kwargs):
    """Histogram `values` in `subplot` and mark their mean with a dashed line."""
    plt.subplot(subplot)
    plt.hist(values, bins=40, range=hist_range, alpha=.3, color=color, **hist_kwargs)
    mean = np.mean(values)
    plt.axvline(mean, linestyle='--', label='mean={}'.format(np.round(mean, 2)),
                color=color)
    plt.title(title)
    plt.legend()
    plt.xlabel(xlabel)


# 2x3 figure (panel 232 unused): 231 rates, 233 Fano factors,
# 234 all-pair correlations, 235 pair correlations, 236 Fano factor vs window.
# Uniform network drawn in C0 (blue), clustered network in C1 (orange).
plt.figure(figsize=(15, 10))
colors = ['C0', 'C1']
for idx, key in enumerate(keys):
    color = colors[idx]

    # Flatten so hist() sees a single dataset even if a helper returned a
    # 2-D array (a no-op for 1-D data; means are unaffected).
    # NOTE(review): shapes of the helper outputs are not visible here.
    rates = np.ravel(stat_dict[key]['rates'])
    ff = np.ravel(stat_dict[key]['ff'])
    rho = np.ravel(stat_dict[key]['rho'])
    rho_pairs = np.ravel(stat_dict[key]['rho_pairs'])
    ff_over_windows = stat_dict[key]['ff_over_windows']

    _hist_with_mean(231, rates, [0, 15], color,
                    'Rates for uniform (blue) and clustered (orange) connectivity',
                    'Rate (spikes/sec)', lw=3.)
    _hist_with_mean(233, ff, [0, 3.5], color,
                    'Fano factors over trials and {}s windows'.format(window_length_ff),
                    'Fano factor', lw=3.)
    _hist_with_mean(234, rho, [-.5, .5], color,
                    'Corr over {}s windows'.format(window_length_rho),
                    'Correlation (all pairs)')
    _hist_with_mean(235, rho_pairs, [-.5, .5], color,
                    'Corr over {}s windows'.format(window_length_rho),
                    'Correlation (same cluster)')

    # Fano factor as a function of window length. time_windows is in seconds,
    # so convert to ms to match the axis label.
    plt.subplot(236)
    plt.title('Fano factors over different time windows')
    plt.plot(time_windows * 1e3, ff_over_windows, '-o', color=color, label=key)
    plt.legend()
    plt.xlabel('Window (ms)')
    plt.ylabel('Fano factor')

save_figure(filename=figure_filename + '.pdf')