In [None]:
from brian2 import *
%matplotlib inline
from sklearn.model_selection import KFold
from brian2tools import *

In [None]:
def get_cluster_connection_probs(REE, k, pee):
    """Split an average connection probability into two probabilities.

    Parameters
    ----------
    REE : ratio of the first returned probability to the second one.
    k : count entering the normalisation (presumably the number of clusters).
    pee : average connection probability to be split.

    Returns
    -------
    (p_in, p_out) : tuple with ``p_out = pee * k / (REE + k - 1)`` and
        ``p_in = REE * p_out``.
    """
    normalisation = REE + k - 1
    between_prob = pee * k / normalisation
    within_prob = REE * between_prob
    return within_prob, between_prob

In [None]:
# total simulated time
simulation_time = 3 * second
# NOTE(review): alpha is not referenced anywhere else in the visible notebook
alpha = 1.

# seed the random number generator
# NOTE(review): this seeds NumPy only; Brian2 also offers its own seed() —
# confirm that the random connectivity below is reproducible with this call
np.random.seed(10)

# create simulation network
net = Network()

# cluster parameters 
# number of clusters: the excitatory population is later split into k groups
# of int(NE/k) neurons each
k = 50 
# ratio of within-cluster to between-cluster connection probability
# (passed as REE to get_cluster_connection_probs); 1.0 selects the uniform case
ree = 1.0
# average ee sparseness
pee = 0.2 
# factor applied to wee for within-cluster synapses
cluster_weight_factor = 1.9 
p_in, p_out = get_cluster_connection_probs(ree, k, pee)

# neuron parameters
# spike threshold and reset value of the dimensionless membrane variable v
vt = 1
vr = 0
# membrane time constants of excitatory / inhibitory neurons
tau_e = 15*ms
tau_i = 10*ms
# decay time constant of the synaptic input variables x_e / x_i
tau1 = 1 * ms
# time constants of the filtered currents I_e / I_i
tau2_e = 3 * ms
tau2_i = 2 * ms
# refractory period
refrac = 5 * ms
# converts the dimensionless currents into a rate of change of v
tau_scale = 1 * ms

# network parameters 
# number of excitatory / inhibitory neurons
NE = 4000
NI = 1000
N = NE + NI

# sparseness
# pie: E -> I, pei: I -> E, pii: I -> I connection probability
pie = 0.5 
pei = 0.5 
pii = 0.5 

# weights 
# wxy is the weight onto population x from population y; inhibitory weights are negative
wee = 0.024
wei = -0.045
wie = 0.014
wii = -0.057

In [None]:
# define neuron equation 
# v relaxes towards the per-neuron drive mu with time constant tau and is
# pushed by the summed synaptic currents I_e + I_i (scaled by tau_scale).
# I_e / I_i follow x_e / x_i with time constants tau2_e / tau2_i, and
# x_e / x_i (the variables incremented by presynaptic spikes) decay with tau1.
# mu and tau are per-neuron parameters set after the group is created.
eqs = '''
dv/dt = (mu-v)/tau + (I_e + I_i) / tau_scale : 1
dI_e/dt = -(I_e - x_e)/tau2_e : 1
dI_i/dt = -(I_i - x_i)/tau2_i : 1
dx_e/dt = -x_e / tau1 : 1
dx_i/dt = -x_i / tau1 : 1
mu : 1
tau : second
'''

In [None]:
# set up the network 
# one group holds all N neurons; a neuron spikes when v exceeds vt, is reset
# to vr and stays refractory for `refrac`
G = NeuronGroup(N, eqs, threshold='v>vt', reset='v=vr', method='euler', refractory=refrac)
net.add(G)
# the first NE neurons are excitatory, the remaining ones inhibitory
Pe = G[:NE]
Pi = G[NE:]
# set E I specific membrane time constants 
Pe.tau = tau_e
Pi.tau = tau_i

# create clusters
# Nc is the number of neurons per cluster; k consecutive slices of Pe are built.
# If NE is not a multiple of k the trailing neurons belong to no cluster.
Nc = int(NE/k)
PeCluster = [Pe[i*Nc:(i+1)*Nc] for i in range(k)]

In [None]:
# set up connectivity, except for ee 
# I -> I: inhibitory spikes increment x_i of the target by w
Sii = Synapses(Pi, Pi, 'w : 1', on_pre='''x_i += w''')
Sii.connect(p=pii)
Sii.w = wii
net.add(Sii)

# I -> E
Sei = Synapses(Pi, Pe, 'w : 1', on_pre='''x_i += w''')
Sei.connect(p=pei)
Sei.w = wei
net.add(Sei)

# E -> I: excitatory spikes increment x_e of the target by w
Sie = Synapses(Pe, Pi, 'w : 1', on_pre='''x_e += w''')
Sie.connect(p=pie)
Sie.w = wie
net.add(Sie)

# E -> E connectivity: uniform when ree == 1, otherwise clustered
if ree == 1: # uniform case 
    print('uniform case')
    See = Synapses(Pe, Pe, 'w : 1', on_pre='''x_e += w''')
    See.connect(p=pee)
    See.w = wee
    net.add(See)

# NOTE(review): this branch is only taken for ree < 0 — confirm whether that
# condition is intended or whether the branch is deliberately disabled.
elif ree < 0.:  
    # list of synapse objects 
    # do the cluster connection like cross validation: cluster neuron := test idx; other neurons := train idx
    kf = KFold(n_splits=k)
    for idx_out, idx_in in kf.split(range(NE)):  # idx_out holds all other neurons; idx_in holds all cluster neurons

        # NOTE(review): slice ends are exclusive, so the neuron at idx_in[-1]
        # is left out of current_cluster.
        current_cluster = G[idx_in[0]:idx_in[-1]]
        # NOTE(review): this takes one contiguous range from the first to the
        # last entry of idx_out; if idx_out surrounds the cluster, the range
        # also contains the cluster neurons — verify.
        other_exc_neurons = G[idx_out[0]:idx_out[-1]]

        # connect current cluster to itself        
        Syn_in = Synapses(current_cluster, current_cluster, 'w : 1', on_pre='''x_e += w''')
        Syn_in.connect(p=p_in)
        Syn_in.w = wee * cluster_weight_factor
        net.add(Syn_in)

        # connect current cluster to other exc neurons
        Syn_out = Synapses(current_cluster, other_exc_neurons, 'w : 1', on_pre='''x_e += w''')
        Syn_out.connect(p=p_out)
        Syn_out.w = wee
        net.add(Syn_out)        
else: 
    # within-cluster connections: probability p_in and strengthened weight.
    # Each Synapses object is registered with net right away because the loop
    # variable is rebound on the next iteration.
    for i in range(k):
        SeeIn = Synapses(PeCluster[i], PeCluster[i], 'w : 1', on_pre='''x_e += w''')
        SeeIn.connect(p=p_in)
        SeeIn.w = wee * cluster_weight_factor
        net.add(SeeIn)

    # cluster-external excitatory connections (cluster only)
    # one Synapses object per ordered pair of distinct clusters, i.e. k*(k-1) objects
    for i in range(k):
        for j in range(k):
            if (i == j): continue
            SeeOut = Synapses(PeCluster[i], PeCluster[j], 'w : 1', on_pre='''x_e += w''')
            SeeOut.connect(p=p_out)
            SeeOut.w = wee
            net.add(SeeOut)

In [None]:
# set initial values of the membrane voltage
# drawn uniformly between the reset value vr and the threshold vt
Pe.v = np.random.rand(NE) * (vt - vr) + vr
Pi.v = np.random.rand(NI) * (vt - vr) + vr

# set the constant drive mu, drawn uniformly per neuron
# NOTE(review): because of the factor 2, the E values lie in
# 2 * [1.1, 1.2] * (vt - vr) + vr and the I values in 2 * [1.0, 1.05] * (vt - vr) + vr,
# i.e. well above vt rather than "around the threshold" — confirm the factor is intended
Pe.mu = 2 * np.random.uniform(1.1, 1.2, NE) * (vt - vr) + vr
Pi.mu = 2 * np.random.uniform(1.0, 1.05, NI) * (vt - vr) + vr

# set up monitors 
# state traces are recorded for indices 0 and NE of G, i.e. the first neuron
# of G[:NE] and the first neuron of G[NE:]
Mn = StateMonitor(G, ['v', 'I_e', 'I_i', 'x_e', 'x_i'], record=[0, NE])
# spikes are recorded for the excitatory population only
sme = SpikeMonitor(Pe)

In [None]:
# finally, run the network 
# Run the explicit Network that all groups and synapses were added to, instead
# of the magic `run()`: the magic network only collects Brian objects that are
# still bound to a name in this namespace, so Synapses created inside a loop
# and rebound on the next iteration (clustered cases) would silently be left
# out of the simulation.
for monitor in (Mn, sme):
    # the membership test keeps this cell safe to execute more than once
    if monitor not in net.objects:
        net.add(monitor)
net.run(simulation_time)

In [None]:
# plot the spikes recorded by the excitatory spike monitor
plt.figure(figsize=(15, 5))
brian_plot(sme, markersize=1.)

# ANALYSIS

## Get spike counts from the spike monitor for a given time window

In [None]:
def select_spiketimes(spiketimes, t, delta_t): 
    """Return the spike times that fall inside ``[t, t + delta_t]``.

    Both window edges are inclusive. The input is converted with
    ``np.asarray`` first, so the result is always a NumPy array.
    """
    times = np.asarray(spiketimes)
    window_end = t + delta_t
    inside_window = (times >= t) & (times <= window_end)
    return times[inside_window]


def get_spike_times_for_time_window(spiketimedict, t, delta_t): 
    """Apply ``select_spiketimes`` to every entry of a spike-train dict.

    Returns a new dict with the same keys whose values hold only the spike
    times inside ``[t, t + delta_t]``.
    """
    windowed = {}
    for neuron_key, neuron_times in spiketimedict.items():
        windowed[neuron_key] = select_spiketimes(neuron_times, t, delta_t)
    return windowed


def get_spike_counts_for_time_window(spiketimedict, t, delta_t): 
    """Count, per dict entry, the spikes inside ``[t, t + delta_t]``.

    Returns a 1-D array with one count per entry, in the dict's iteration order.
    """
    # restrict every spike train to the requested window first
    windowed_trains = get_spike_times_for_time_window(spiketimedict, t, delta_t)
    return np.array([train.size for train in windowed_trains.values()])

In [None]:
# spike trains of the monitored (excitatory) neurons, keyed by neuron index
spiketimedict = sme.spike_trains()
# spike count per neuron in the window starting at t=1 with length 2
# (presumably seconds, the unit of the spike times)
spike_counts = get_spike_counts_for_time_window(spiketimedict, t=1., delta_t=2.)
# divide by the window length (2, matching delta_t above) to get a rate
rates = spike_counts / 2.
plt.figure(figsize=(15, 5))
plt.hist(rates, bins='auto');

## Get spike counts for consecutive time windows to calculate correlations

In [None]:
def calculate_spike_counts_over_windows(spiketime_dict, t, delta_t, window_length): 
    """Count spikes per neuron in consecutive, non-overlapping time windows.

    The interval starting at ``t`` is cut into ``floor(delta_t / window_length)``
    windows of width ``window_length``. Windows are half-open ``[start, stop)``
    except the last one, which also includes its right edge, so a spike that
    sits exactly on a shared window boundary is counted once.

    Parameters
    ----------
    spiketime_dict : dict mapping a neuron key to its array of spike times.
    t : start of the first window.
    delta_t : total length of the analysed interval.
    window_length : width of one window; must be positive.

    Returns
    -------
    Float array of shape ``(n_neurons, n_time_windows)``; rows follow the
    iteration order of ``spiketime_dict``.

    Raises
    ------
    ValueError : if ``window_length`` is not positive.
    """
    if window_length <= 0:
        raise ValueError('window_length must be positive')

    n_neurons = len(spiketime_dict)
    # The small tolerance guards against float truncation: 0.3 / 0.1 evaluates
    # to 2.9999999999999996 and a plain int() would drop the last window.
    n_time_windows = max(int(np.floor(delta_t / window_length + 1e-9)), 0)
    spike_counts_windows = np.zeros((n_neurons, n_time_windows))
    if n_time_windows == 0:
        return spike_counts_windows

    # shared bin edges for all neurons; one histogram per neuron replaces
    # re-filtering every spike train once per window
    window_edges = t + np.arange(n_time_windows + 1) * window_length
    for neuron_idx, spiketimes in enumerate(spiketime_dict.values()):
        counts, _ = np.histogram(np.asarray(spiketimes), bins=window_edges)
        spike_counts_windows[neuron_idx] = counts
    return spike_counts_windows

def calculate_correlation_matrix(spikecount_matrix_windows): 
    """Trial-averaged spike-count correlation matrix of the active neurons.

    Parameters
    ----------
    spikecount_matrix_windows : array of shape
        ``(n_trials, n_neurons, n_time_windows)`` holding spike counts.

    Returns
    -------
    Square matrix of correlation coefficients. Neurons whose trial-averaged
    variance is zero are dropped, so the side length is the number of
    remaining neurons.
    """
    n_trials, n_neurons, _ = spikecount_matrix_windows.shape

    # accumulate the per-trial covariance matrices, then average
    cov_sum = np.zeros((n_neurons, n_neurons))
    for trial_counts in spikecount_matrix_windows:
        cov_sum += np.cov(trial_counts)
    mean_cov = cov_sum / n_trials

    # neurons with zero variance would lead to a division by zero below
    active = np.diag(mean_cov) != 0
    active_cov = mean_cov[active][:, active]

    # normalise each covariance by the geometric mean of the two variances
    variances = np.diag(active_cov).copy()
    return active_cov / np.sqrt(np.outer(variances, variances))

In [None]:
# spike counts per neuron in windows of length 0.1, starting at t=1 and covering a span of 2
spike_counts_windows = calculate_spike_counts_over_windows(spiketimedict, t=1., delta_t=2., window_length=0.1)
# add a leading axis of length 1: calculate_correlation_matrix expects
# (n_trials, n_neurons, n_time_windows)
s = spike_counts_windows[np.newaxis, :, : ]
s.shape

In [None]:
# pairwise spike-count correlations of the neurons with non-zero variance
corr = calculate_correlation_matrix(spikecount_matrix_windows=s)
# sanity check: the correlation matrix must be symmetric
assert (corr.T == corr).all()

In [None]:
# work on a copy and mark the diagonal (self-correlations) as non-finite
# so it can be filtered out with np.isfinite
temp_corr = corr.copy()
temp_corr[np.diag_indices_from(temp_corr)] = np.inf

In [None]:
# flat array of all finite entries, i.e. everything except the masked diagonal
rho = temp_corr[np.isfinite(temp_corr)]

In [None]:
plt.figure(figsize=(15, 5))
# first histogram: all matrix entries, including the diagonal
plt.hist(corr.flatten(), bins='auto')
# second histogram (same axes): entries with the diagonal removed
plt.hist(rho, bins='auto');

### Calculate CV 

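The coefficient of variation (CV) is defined as the standard deviation of the inter-spike intervals (ISI) divided by their mean: $CV = \sigma_{ISI} / \mu_{ISI}$.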

In [None]:
spikemonitor = sme
# get the spike trains from the monitor: a dict with {'neuron_idx' : spike times in sec}
sdict = spikemonitor.spike_trains()

# Analysis window, defined locally so this cell runs on a fresh kernel instead
# of relying on `t` / `delta_t`, which are only assigned in a later cell.
# The values mirror the defaults of calculate_cv.
cv_window_start = 1.0 * second
cv_window_stop = cv_window_start + 1.5 * second

# for every neuron keep the spike times inside the window and compute the CV of its ISIs
cv = []
for idx, n_key in enumerate(sdict):
    spike_times = sdict[n_key]
    in_window = np.logical_and(spike_times >= cv_window_start, spike_times <= cv_window_stop)
    # index with the mask: the boolean mask itself is not a spike train, and
    # taking np.diff of it would not give inter-spike intervals
    spike_times = spike_times[in_window]
    if spike_times.size > 10:         
        isi = np.diff(spike_times)
        cv.append(np.std(isi) / np.mean(isi))
        
def calculate_cv(spikemonitor, t=1.0, delta_t=1.5): 
    """Coefficient of variation of the inter-spike intervals, per neuron.

    Parameters
    ----------
    spikemonitor : object whose ``spike_trains()`` returns a dict mapping a
        neuron index to its spike times.
    t : window start as a plain number; multiplied by ``second`` internally.
    delta_t : window length as a plain number; multiplied by ``second`` internally.

    Returns
    -------
    Array with ``std(ISI) / mean(ISI)`` for every neuron that has more than
    five spikes inside ``[t, t + delta_t]``; other neurons are skipped.
    """
    window_start = t * second
    window_stop = window_start + delta_t * second

    cv_values = []
    for train in spikemonitor.spike_trains().values():
        in_window = np.logical_and(train >= window_start, train <= window_stop)
        windowed = train[in_window]
        # too few spikes give an unreliable ISI statistic
        if windowed.size <= 5:
            continue
        intervals = np.diff(windowed)
        cv_values.append(np.std(intervals) / np.mean(intervals))
    return np.array(cv_values)

In [None]:
# CV of the ISIs for the monitored neurons, using calculate_cv's default window
cvs = calculate_cv(sme)
plt.figure(figsize=(10, 5))
plt.hist(cvs, bins=15)

In [None]:
def calculate_firing_rate(spikemonitor, t=1.0, delta_t=2.0):
    """Firing rate of every monitored neuron inside a time window.

    Parameters
    ----------
    spikemonitor : object whose ``spike_trains()`` returns a dict mapping a
        neuron index to its spike times.
    t : window start as a plain number; multiplied by ``second`` internally.
    delta_t : window length as a plain number; multiplied by ``second`` internally.

    Returns
    -------
    Per-neuron spike count in ``[t, t + delta_t]`` divided by
    ``delta_t * second``, in the iteration order of the spike-train dict.
    """
    window_start = t * second
    window_length = delta_t * second
    window_stop = window_start + window_length

    trains = spikemonitor.spike_trains()
    counts = np.zeros(len(trains))
    for row, train in enumerate(trains.values()):
        in_window = np.logical_and(train >= window_start, train <= window_stop)
        counts[row] = np.sum(in_window)

    # spikes per unit time, one entry per neuron
    return counts / window_length

In [None]:
# per-neuron firing rate in the window starting at 1.0 s with length 2.0 s
rates = calculate_firing_rate(sme, 1.0, 2.0)
plt.hist(rates, bins=15);

In [None]:
# spike trains of the monitored neurons, keyed by neuron index
sdict = sme.spike_trains()
# analysis window with units: start and length
t = 1.5 * second 
delta_t = 1.5 * second 
# plain-number parameters (presumably seconds) used to derive the number of histogram bins
dt = 0.002
# NOTE(review): t0 is not used in the visible cells below
t0 = 1.5 
T = 3.0

In [None]:
# per-neuron results: autocorrelations (spiking neurons only) and rates (all neurons)
corrs = []
rates = []

# sdict holds the neuron idx as key and the spiketimes as values. 
for idx, n_key in enumerate(sdict): 
    # get the spiketimes of the current neuron 
    spiketimes = sdict[n_key]
    # restrict to the times we are interested in: 1.5s - 3.0s 
    spiketimes = spiketimes[np.logical_and(spiketimes >= t, spiketimes <=(t + delta_t))]
    
    # use a histogram to convert it to an array of spike counts with int(T/dt) bins over [1.5, 3.0]
    # NOTE(review): with T = 3.0 and dt = 0.002 this is 1500 bins over a 1.5 s
    # range, i.e. 1 ms per bin rather than the 2 ms that dt suggests — confirm
    # which bin width is intended
    spikeTrain, edges = np.histogram(spiketimes, bins=int(T/dt), range=[1.5, 3.0])

    # get the rate as the total number of spikes over the window length (1.5, matching the range above)
    rate = np.sum(spikeTrain) / 1.5
    rates.append(rate)
    
    # calculate autocorrelation of the spike train 
    Q = np.correlate(spikeTrain, spikeTrain, mode='same') 

    # remove central peak 
    #Q[np.argmax(Q)] = 0
    #Q = Q / np.max(Q)
    # save result
    # only neurons with at least one spike are kept (an all-zero train has Q.max() == 0)
    if Q.max(): 
        corrs.append(Q )

In [None]:
# mean autocorrelation across the neurons that spiked
cm = np.array(corrs).mean(axis=0)
# distribution of the per-neuron rates computed in the loop above
plt.hist(rates)
plt.xlabel('rate in spikes/s')
plt.ylabel('count'); 

In [None]:
# stack the autocorrelations into a 2-D array (one row per spiking neuron)
corrs = np.array(corrs)
plt.figure(figsize=(15, 10))
plt.subplot(211)
# individual autocorrelations of the first 100 neurons in grey
# x axis: 1500 points from -750 to 750 — presumably the lag in ms, assuming 1 ms bins
for c in corrs[:100]: 
    plt.plot(np.linspace(-750, 750, 1500), c, alpha=0.2, color='grey')
# population mean on top
plt.plot(np.linspace(-750, 750, 1500), corrs.mean(axis=0), color='orange')

In [None]:
# take the mean correlation 
plt.figure(figsize=(15, 5))
#c = corrs.mean(axis=0)
#c[np.argmax(c)] = np.mean(c)
cm = corrs.mean(axis=0)
# replace the largest value (the central zero-lag peak) by the mean so the
# remaining structure is visible
cm[np.argmax(cm)] = cm.mean()
plt.plot(cm);