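Compares the spike time tiling coefficient (STTC) computed by Elephant (`elephant.spike_train_correlation.spike_time_tiling_coefficient`) with my own STTC implementation, both for pairs of spike trains and for an STTC-based autocorrelation function (ACF).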

In [1]:
import numpy as np
import pandas as pd
import csv
import quantities as pq
from elephant.spike_train_correlation import spike_time_tiling_coefficient
import neo
import matplotlib.pyplot as plt
import seaborn as sns

### Functions

In [26]:
def acf_sttc(signal_, n_lags_, lag_shift_=50, sttc_dt_=25, signal_length_=1000, verbose_=True):
    """Autocorrelation function of a spike train computed with STTC (own implementation).

    Lag 0 is the STTC of the train with itself over ``[0, signal_length_]``. For every
    other lag ``shift`` (``lag_shift_, 2*lag_shift_, ..., (n_lags_-1)*lag_shift_``) the
    spikes at ``t >= shift`` are moved to ``t - shift`` and compared with the original
    spikes at ``t < signal_length_ - shift``, both over ``[0, signal_length_ - shift]``.

    Args:
        signal_: spike times (array-like), sorted ascending; the recording starts at 0.
        n_lags_: number of lags in the result, including lag 0.
        lag_shift_: step between consecutive lags (same units as the spike times).
        sttc_dt_: STTC synchrony window ``dt``.
        signal_length_: recording length.
        verbose_: print intermediate arrays for debugging.

    Returns:
        list of ``n_lags_`` STTC values, lag 0 first.
    """
    spikes = np.asarray(signal_)  # allow plain lists; boolean masking below needs an array

    shift_ms_l = np.linspace(lag_shift_, lag_shift_ * (n_lags_ - 1), n_lags_ - 1).astype(int)
    if verbose_:
        print('shift_ms_l {}'.format(shift_ms_l))

    acf_l = [calc_sttc(spikes, spikes, t_start=0, t_stop=signal_length_, dt=sttc_dt_)]

    # correlate the signal with shifted copies of itself
    for shift_ms in shift_ms_l:
        # The overlapping window has length signal_length_ - shift_ms. Bound by the
        # recording length, not n_lags_ * lag_shift_ (they only coincide when the lags
        # exactly span the recording).
        spike_1 = spikes[spikes >= shift_ms]
        spike_2 = spikes[spikes < signal_length_ - shift_ms]
        # align only the first train so both start at 0
        spike_1_aligned = spike_1 - shift_ms
        if verbose_:
            print(shift_ms)
            print(spike_1)
            print(spike_2)
            print(spike_1_aligned)
            print('spike_1 {}, spike_2 {}'.format(spike_1.shape, spike_2.shape))

        acf_l.append(calc_sttc(spike_1_aligned, spike_2, t_start=0,
                               t_stop=signal_length_ - shift_ms, dt=sttc_dt_))

    return acf_l

In [3]:
def calculate_acf_sttc_t(spike_train, n_lags_, acf_lag_ms_, sttc_lag_ms_, rec_length_, verbose=True):
    """STTC-based autocorrelation of a spike train, computed with Elephant (reference version).

    Lag 0 is Elephant's STTC of ``spike_train`` with itself over ``[0, rec_length_]``.
    For the other lags, the train is first binarised into ``rec_length_`` bins (bin index =
    spike time in ms; repeated spike times collapse into one spike). For each lag ``shift``
    (``acf_lag_ms_, 2*acf_lag_ms_, ..., (n_lags_-1)*acf_lag_ms_``) the bins ``[shift:]``
    are compared with the bins ``[:rec_length_ - shift]``.

    Args:
        spike_train: integer spike times in ms, each in ``[0, rec_length_)``.
        n_lags_: number of lags in the result, including lag 0.
        acf_lag_ms_: step between consecutive lags in ms.
        sttc_lag_ms_: STTC synchrony window ``dt`` in ms.
        rec_length_: recording length in ms (integer).
        verbose: print intermediate arrays for debugging.

    Returns:
        list of ``n_lags_`` STTC values, lag 0 first.
    """
    shift_ms_l = np.linspace(acf_lag_ms_, acf_lag_ms_ * (n_lags_ - 1), n_lags_ - 1).astype(int)
    if verbose:
        print('shift_ms_l {}'.format(shift_ms_l))

    spike_train_bin = np.zeros(rec_length_)
    spike_train_bin[spike_train] = 1
    if verbose:
        print(spike_train_bin.shape)
    n_bins = len(spike_train_bin)

    # correlate with itself
    spike_train_neo = neo.SpikeTrain(spike_train, units='ms', t_start=0, t_stop=n_bins)
    sttc_self_l = [spike_time_tiling_coefficient(spike_train_neo, spike_train_neo, dt=sttc_lag_ms_ * pq.ms)]

    # correlate with shifted copies of itself
    for shift_ms in shift_ms_l:
        # Explicit end index: `[:-shift_ms]` would give an empty array for shift_ms == 0.
        spike_train_bin1 = spike_train_bin[shift_ms:]
        spike_train_bin2 = spike_train_bin[:n_bins - shift_ms]
        if verbose:
            print('spike_train_bin1 {}, spike_train_bin2 {}'.format(spike_train_bin1.shape, spike_train_bin2.shape))

        spike_train_bin1_idx = np.flatnonzero(spike_train_bin1)
        spike_train_bin2_idx = np.flatnonzero(spike_train_bin2)
        if verbose:
            print('spike_train_bin1_idx {}'.format(spike_train_bin1_idx))
            print('spike_train_bin2_idx {}'.format(spike_train_bin2_idx))

        spike_train_neo_1 = neo.SpikeTrain(spike_train_bin1_idx, units='ms', t_start=0, t_stop=len(spike_train_bin1))
        spike_train_neo_2 = neo.SpikeTrain(spike_train_bin2_idx, units='ms', t_start=0, t_stop=len(spike_train_bin2))
        if verbose:
            print(spike_train_neo_1)
            print(spike_train_neo_2)

        sttc_self_l.append(spike_time_tiling_coefficient(spike_train_neo_1, spike_train_neo_2, dt=sttc_lag_ms_ * pq.ms))

    return sttc_self_l

In [4]:
def calc_sttc_elephant(spike_train_1, spike_train_2, t_start_, t_stop_, dt_):
    """STTC of two spike trains computed by Elephant.

    Both inputs are wrapped as ``neo.SpikeTrain`` objects in milliseconds on the
    shared interval ``[t_start_, t_stop_]``; ``dt_`` is the synchrony window in ms.
    """
    neo_trains = [
        neo.SpikeTrain(times, units='ms', t_start=t_start_, t_stop=t_stop_)
        for times in (spike_train_1, spike_train_2)
    ]
    first, second = neo_trains
    return spike_time_tiling_coefficient(first, second, dt=dt_ * pq.ms)

In [20]:
def run_T(spiketrain, N, dt, t_start, t_stop):
    """
    Calculate the proportion of the total recording time 'tiled' by spikes.

    Each spike tiles the window [t - dt, t + dt]. Overlapping windows are counted
    once, and the parts that fall outside [t_start, t_stop] are clipped.

    Parameters
    ----------
    spiketrain : sequence of spike times, sorted ascending
    N : number of spikes (len(spiketrain))
    dt : half-width of the tiling window
    t_start, t_stop : recording boundaries

    Returns
    -------
    (time_A, T) : tiled time, and its fraction of (t_stop - t_start)
    """
    time_A = 2 * N * dt  # maximum possible time (no overlaps, no clipping)

    if N == 0:
        # nothing is tiled; also avoids indexing into an empty train below
        return time_A, time_A / (t_stop - t_start)

    if N == 1:  # for just one spike in train
        # A single spike can be within dt of BOTH boundaries (short recording or large
        # dt), so the two clips are independent checks, not if/elif.
        if spiketrain[0] - t_start < dt:
            time_A = time_A - dt + spiketrain[0] - t_start
        if spiketrain[0] + dt > t_stop:
            time_A = time_A - dt - spiketrain[0] + t_stop

    else:  # if more than one spike in train
        # Vectorised overlap removal: consecutive spikes closer than 2*dt share
        # (2*dt - diff) of tiled time.
        diff = np.diff(spiketrain)
        idx = np.where(diff < (2 * dt))[0]
        Lidx = len(idx)
        time_A = time_A - 2 * Lidx * dt + diff[idx].sum()

        # clip the first/last window at the recording boundaries
        if (spiketrain[0] - t_start) < dt:
            time_A = time_A + spiketrain[0] - dt - t_start

        if (t_stop - spiketrain[N - 1]) < dt:
            time_A = time_A - spiketrain[-1] - dt + t_stop

    T = time_A / (t_stop - t_start)

    return time_A, T

def run_P(spiketrain_1, spiketrain_2, N1, N2, dt):
    """
    Count the spikes of train 1 that have at least one spike of train 2 within dt.

    Both trains must be sorted ascending. The pointer j into train 2 only moves
    forward: it skips spikes that lie more than dt before the current spike of
    train 1, which are also too early for every later spike of train 1.

    Parameters
    ----------
    spiketrain_1, spiketrain_2 : sorted sequences of spike times
    N1, N2 : their lengths
    dt : coincidence window (inclusive)

    Returns
    -------
    int : number of spikes in train 1 with a spike of train 2 within dt
    """
    Nab = 0
    j = 0
    for i in range(N1):
        while j < N2:  # don't need to search all j each iteration
            if np.abs(spiketrain_1[i] - spiketrain_2[j]) <= dt:
                # coincidence; keep j, it may also match the next spike of train 1
                Nab += 1
                break
            elif spiketrain_2[j] > spiketrain_1[i]:
                # spiketrain_2[j] is more than dt after spike i: move on to the next i
                break
            else:
                # spiketrain_2[j] is more than dt before spike i: never useful again
                j += 1
    return Nab

def calc_sttc(lag1_l, lag2_l, t_start, t_stop, dt):
    """
    Spike time tiling coefficient (STTC) of two spike trains (own implementation).

    STTC = 0.5 * (P_A - T_B) / (1 - P_A * T_B) + 0.5 * (P_B - T_A) / (1 - P_B * T_A)

    T_X is the fraction of [t_start, t_stop] within dt of a spike of X (run_T).
    P_X is the fraction of spikes of X that have a spike of the other train
    within dt (run_P).

    Parameters
    ----------
    lag1_l, lag2_l : array-like spike times (sorted internally)
    t_start, t_stop : recording boundaries
    dt : synchrony window

    Returns
    -------
    float : the STTC, or 0.0 if either train is empty
    """
    # run_P walks both trains with a forward-only pointer, so they must be sorted
    train_a = np.sort(np.asarray(lag1_l))
    train_b = np.sort(np.asarray(lag2_l))
    n_a = len(train_a)
    n_b = len(train_b)

    if n_a == 0 or n_b == 0:
        return 0.0

    _, t_a = run_T(train_a, n_a, dt, t_start, t_stop)
    _, t_b = run_T(train_b, n_b, dt, t_start, t_stop)

    p_a = run_P(train_a, train_b, n_a, n_b, dt) / float(n_a)
    p_b = run_P(train_b, train_a, n_b, n_a, dt) / float(n_b)

    # A term is 0/0 when its product equals 1 (every spike of that train is
    # within dt of the other train); that term is then taken as 1.
    if t_a * p_b == 1 and t_b * p_a == 1:
        return 1.0
    if t_a * p_b == 1:
        return 0.5 * (p_a - t_b) / (1 - p_a * t_b) + 0.5
    if t_b * p_a == 1:
        return 0.5 + 0.5 * (p_b - t_a) / (1 - p_b * t_a)
    return 0.5 * (p_a - t_b) / (1 - p_a * t_b) + 0.5 * (p_b - t_a) / (1 - p_b * t_a)

### Test on random binary (0/1) spike trains

In [6]:
SEED = 0  # fixed seed so the random spike trains (and every output below) are reproducible
rng = np.random.default_rng(SEED)

n_trains = 10
train_len = 1000  # number of bins per train; bin index is used as spike time (ms) below

# Poisson counts with rate 0.05 per bin, clipped to at most one spike per bin
spike_trains_l = [
    np.clip(rng.poisson(0.05, train_len), a_min=0, a_max=1)
    for _ in range(n_trains)
]

In [16]:
STTC_DT = 10  # synchrony window dt (ms) for the pairwise comparison

# Compare each train with the next one using both implementations.
for i in range(n_trains - 1):
    print('#############')
    print('i {}'.format(i))
    spike_times_1 = np.flatnonzero(spike_trains_l[i] == 1)
    spike_times_2 = np.flatnonzero(spike_trains_l[i + 1] == 1)

    my_sttc = calc_sttc(spike_times_1, spike_times_2, t_start=0, t_stop=train_len, dt=STTC_DT)
    print(my_sttc)

    elephant_sttc = calc_sttc_elephant(spike_times_1, spike_times_2, t_start_=0, t_stop_=train_len, dt_=STTC_DT)
    print(elephant_sttc)

    print('my_sttc - elephant_sttc: {}'.format(my_sttc - elephant_sttc))
    # fail loudly on a mismatch instead of relying on reading the printed differences
    assert np.isclose(my_sttc, elephant_sttc), 'STTC mismatch for trains {} and {}'.format(i, i + 1)

#############
i 0
0.15850876611183662
0.15850876611183662
my_sttc - elephant_sttc: 0.0
#############
i 1
0.1646225264420118
0.1646225264420118
my_sttc - elephant_sttc: 0.0
#############
i 2
0.1656019296531637
0.1656019296531637
my_sttc - elephant_sttc: 0.0
#############
i 3
0.14053063991381132
0.14053063991381132
my_sttc - elephant_sttc: 0.0
#############
i 4
-0.07608761197138186
-0.07608761197138186
my_sttc - elephant_sttc: 0.0
#############
i 5
0.15432781378221572
0.15432781378221572
my_sttc - elephant_sttc: 0.0
#############
i 6
0.09194986338397529
0.09194986338397529
my_sttc - elephant_sttc: 0.0
#############
i 7
0.031008749835819954
0.031008749835819954
my_sttc - elephant_sttc: 0.0
#############
i 8
0.007634156808932377
0.007634156808932377
my_sttc - elephant_sttc: 0.0


In [27]:
n_lags_ = 20
ACF_LAG_MS = 50   # shift between consecutive ACF lags (ms)
ACF_STTC_DT = 50  # STTC synchrony window used at every lag (ms)

# Compare the STTC-based ACF of every train between both implementations.
for i in range(n_trains):
    print('#############')
    print('i {}'.format(i))
    spike_times = np.flatnonzero(spike_trains_l[i] == 1)

    my_acf_sttc = acf_sttc(spike_times, n_lags_, lag_shift_=ACF_LAG_MS, sttc_dt_=ACF_STTC_DT,
                           signal_length_=train_len, verbose_=False)
    print(my_acf_sttc)

    elephant_acf_sttc = calculate_acf_sttc_t(spike_times, n_lags_, acf_lag_ms_=ACF_LAG_MS,
                                             sttc_lag_ms_=ACF_STTC_DT, rec_length_=train_len,
                                             verbose=False)

    print('my_acf_sttc - elephant_acf_sttc: {}'.format(np.asarray(my_acf_sttc) - np.asarray(elephant_acf_sttc)))

#############
i 0
[1.0, 1.0, 0.3764829030006983, -0.2938758886255934, 1.0, 0.4200000000000004, 1.0, 1.0, 1.0, 1.0, 0.428819444444445, 0.4371827411167506, 1.0, 0.4368852459016391, 0.5753424657534247, 1, 1, 1, 1, 1.0]
my_acf_sttc - elephant_acf_sttc: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#############
i 1
[1.0, 1.0, 1.0, 0.24621549421193106, 0.36772908366533846, 0.23904179408766535, 0.1422413793103448, -0.4118969843949635, 1.0, 1.0, 1.0, 0.27750809061488657, 0.3416666666666666, 1, 1.0, 1.0, 1, 1, 1, 1]
my_acf_sttc - elephant_acf_sttc: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#############
i 2
[1.0, 1.0, 1.0, 0.1731707317073165, 0.36396761133603234, 1.0, 1.0, 0.1906249999999995, 1.0, 1.0, 0.5020345879959313, 1.0, 1.0, -0.5477376452293966, 1.0, 0.5782918149466193, 1, 1, 1, 1]
my_acf_sttc - elephant_acf_sttc: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#############
i 3
[1, 1, 1.0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 