Compare the spike time tiling coefficient (STTC) computed by Elephant's `spike_time_tiling_coefficient` with my own implementation of the STTC.

In [2]:
import numpy as np
import pandas as pd
import csv
import quantities as pq
from elephant.spike_train_correlation import spike_time_tiling_coefficient
import neo
import matplotlib.pyplot as plt
import seaborn as sns

### Functions

In [54]:
def calc_sttc_elephant(spike_train_1, spike_train_2, t_start_, t_stop_, dt_):
    """
    Reference STTC computed with Elephant's ``spike_time_tiling_coefficient``.

    Parameters
    ----------
    spike_train_1, spike_train_2 : array-like
        Spike times, interpreted in milliseconds.
    t_start_, t_stop_ : number
        Boundaries of the recording, passed to ``neo.SpikeTrain``.
    dt_ : number
        Synchronicity window in milliseconds.

    Returns
    -------
    The STTC value returned by Elephant.
    """
    # Wrap both raw time arrays as neo SpikeTrains sharing the same window.
    neo_trains = [
        neo.SpikeTrain(times, units='ms', t_start=t_start_, t_stop=t_stop_)
        for times in (spike_train_1, spike_train_2)
    ]
    return spike_time_tiling_coefficient(*neo_trains, dt=dt_ * pq.ms)

In [14]:
def run_T(spiketrain, N, dt, t_start, t_stop):
    """
    Calculate the proportion of the total recording time 'tiled' by spikes.

    Each spike tiles the interval ``[t - dt, t + dt]``. Overlapping tiles are
    counted once, and tiles are clipped to ``[t_start, t_stop]``.

    Parameters
    ----------
    spiketrain : sequence of numbers
        Spike times, assumed sorted ascending and inside [t_start, t_stop].
    N : int
        Number of spikes in ``spiketrain``.
    dt : number
        Half-width of the tile placed around each spike.
    t_start, t_stop : number
        Boundaries of the recording.

    Returns
    -------
    time_A : number
        Total tiled time.
    T : float
        ``time_A`` as a fraction of ``t_stop - t_start``.
    """
    time_A = 2 * N * dt  # maximum possible time (no overlaps, no clipping)

    if N == 1:
        # Two independent checks: a lone spike can be within dt of BOTH edges
        # when the recording is shorter than 2*dt, and then both sides must be
        # clipped (an if/elif here would let T exceed 1).
        if spiketrain[0] - t_start < dt:
            time_A = time_A - dt + spiketrain[0] - t_start
        if spiketrain[0] + dt > t_stop:
            time_A = time_A - dt - spiketrain[0] + t_stop

    elif N > 1:
        # Vectorised overlap correction: for every pair of consecutive spikes
        # closer than 2*dt, count their gap instead of the full 2*dt.
        diff = np.diff(spiketrain)
        overlapping = diff[diff < 2 * dt]
        time_A = time_A - 2 * len(overlapping) * dt + overlapping.sum()

        # Clip the first/last tile if it sticks out of the recording.
        if (spiketrain[0] - t_start) < dt:
            time_A = time_A + spiketrain[0] - dt - t_start

        if (t_stop - spiketrain[N - 1]) < dt:
            time_A = time_A - spiketrain[N - 1] - dt + t_stop

    # N == 0 falls through: nothing is tiled, so time_A stays 0.
    T = time_A / (t_stop - t_start)

    return time_A, T

def run_P(spiketrain_1, spiketrain_2, N1, N2, dt):
    """
    Check every spike in train 1 to see if there's a spike in train 2
    within dt.

    Uses a single pointer ``j`` into train 2 that only moves forward, so both
    trains must be sorted in ascending order.

    Parameters
    ----------
    spiketrain_1, spiketrain_2 : sequence of numbers
        Sorted spike times.
    N1, N2 : int
        Number of spikes in each train.
    dt : number
        Coincidence window; ``|t1 - t2| <= dt`` counts as a match.

    Returns
    -------
    int
        Number of spikes in ``spiketrain_1`` with at least one spike of
        ``spiketrain_2`` within ``dt``.
    """
    Nab = 0
    j = 0  # shared across all i: never rewinds
    for i in range(N1):
        while j < N2:
            if np.abs(spiketrain_1[i] - spiketrain_2[j]) <= dt:
                Nab += 1
                break
            elif spiketrain_2[j] > spiketrain_1[i]:
                # spiketrain_2[j] is more than dt AFTER this spike: no match.
                break
            else:
                # spiketrain_2[j] is more than dt BEFORE this spike, hence also
                # before every later spike of train 1; skip it for good.
                j += 1
    return Nab

def calc_sttc(lag1_l, lag2_l, t_start, t_stop, dt, verbose=True):
    """
    Spike time tiling coefficient (STTC) between two spike trains.

    Parameters
    ----------
    lag1_l, lag2_l : sequence of numbers
        Spike times of trains A and B, sorted ascending (``run_P`` relies on it).
    t_start, t_stop : number
        Boundaries of the recording.
    dt : number
        Synchronicity window.
    verbose : bool, optional
        If True (default), print the intermediate T and P values.

    Returns
    -------
    float
        The STTC; 0 if either train is empty.
    """
    n_a = len(lag1_l)
    n_b = len(lag2_l)

    if n_a == 0 or n_b == 0:
        return 0

    log = print if verbose else (lambda *args: None)

    time_a, t_a = run_T(lag1_l, n_a, dt, t_start, t_stop)
    log('time_a {}, t_a {}'.format(time_a, t_a))

    time_b, t_b = run_T(lag2_l, n_b, dt, t_start, t_stop)
    log('time_b {}, t_b {}'.format(time_b, t_b))

    p_a_count = run_P(lag1_l, lag2_l, n_a, n_b, dt)
    p_a = p_a_count / float(n_a)
    log('p_a_count {}, p_a {}'.format(p_a_count, p_a))

    p_b_count = run_P(lag2_l, lag1_l, n_b, n_a, dt)
    p_b = p_b_count / float(n_b)
    log('p_b_count {}, p_b {}'.format(p_b_count, p_b))

    def half_term(p, t):
        # 1 - p*t is 0 only when p*t == 1 (p == t == 1 for valid proportions:
        # every spike has a partner and the other train tiles the whole
        # recording). The term is then 0/0; recent Elephant releases define the
        # partial index as 1 in that case, i.e. a contribution of 0.5.
        denom = 1 - p * t
        if denom == 0:
            return 0.5
        return 0.5 * (p - t) / denom

    return half_term(p_a, t_b) + half_term(p_b, t_a)

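### Test on random binary (0/1) spike trains

Each array has one entry per 1 ms bin; a 1 marks a spike, and the indices of the 1s are used as spike times.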

In [39]:
# Seeded generator so Restart & Run All reproduces the same spike trains.
SEED = 0
rng = np.random.default_rng(SEED)

spike_trains_l = []
n_trains = 10
train_len = 1000

for _ in range(n_trains):
    # Poisson counts per bin, clipped so a bin holds at most one spike.
    poisson = rng.poisson(.1, train_len)
    bounded_poisson = np.clip(poisson, 0, 1)
    spike_trains_l.append(bounded_poisson)

In [57]:
DT = 10  # synchronicity window (ms) passed to both implementations

# Compare each train with the next one; fail loudly on any disagreement
# instead of relying on reading the printed difference.
for i in range(n_trains - 1):
    print('#############')
    print('i {}'.format(i))
    spike_times_1 = np.where(spike_trains_l[i] == 1)[0]
    spike_times_2 = np.where(spike_trains_l[i + 1] == 1)[0]

    my_sttc = calc_sttc(spike_times_1, spike_times_2, t_start=0, t_stop=train_len, dt=DT)
    print(my_sttc)

    elephant_sttc = calc_sttc_elephant(spike_times_1, spike_times_2, t_start_=0, t_stop_=train_len, dt_=DT)
    print(elephant_sttc)

    print('my_sttc - elephant_sttc: {}'.format(my_sttc - elephant_sttc))
    assert np.isclose(my_sttc, elephant_sttc), (i, my_sttc, elephant_sttc)

#############
i 0
time_a 865, t_a 0.865
time_b 825, t_b 0.825
p_a_count 97, p_a 0.8981481481481481
p_b_count 77, p_b 0.8850574712643678
0.18397759092595073
0.18397759092595073
my_sttc - elephant_sttc: 0.0
#############
i 1
time_a 825, t_a 0.825
time_b 820, t_b 0.82
p_a_count 73, p_a 0.8390804597701149
p_b_count 80, p_b 0.8163265306122449
0.01730091654384685
0.01730091654384685
my_sttc - elephant_sttc: 0.0
#############
i 2
time_a 820, t_a 0.82
time_b 909, t_b 0.909
p_a_count 95, p_a 0.9693877551020408
p_b_count 84, p_b 0.875
0.3514456050490732
0.3514456050490732
my_sttc - elephant_sttc: 0.0
#############
i 3
time_a 909, t_a 0.909
time_b 822, t_b 0.822
p_a_count 77, p_a 0.8020833333333334
p_b_count 75, p_b 0.8928571428571429
-0.07207371257878956
-0.07207371257878956
my_sttc - elephant_sttc: 0.0
#############
i 4
time_a 822, t_a 0.822
time_b 848, t_b 0.848
p_a_count 67, p_a 0.7976190476190477
p_b_count 83, p_b 0.8058252427184466
-0.10179458460995568
-0.10179458460995568
my_sttc - elephan