Test for logic in acf_sttc_trial_avg 

In [13]:
import numpy as np
from scipy.optimize import curve_fit, OptimizeWarning

# import from scripts
# NOTE: os.chdir changes the working directory of the whole kernel, so every relative path used
# after this cell is resolved against the scripts folder. The path contains no "~", so
# os.path.expanduser returns it unchanged.
# NOTE(review): presumably the chdir is what makes calculate_acf importable - confirm; the path is
# machine specific and has to be adapted on another machine.
import os
os.chdir(os.path.expanduser("D:\\intr_timescales\\isttc\\scripts"))
#os.chdir(os.path.expanduser("C:\\Users\\ipoch\\Documents\\repos\\isttc\\scripts"))
from calculate_acf import sttc_fixed_2t, sttc_calculate_t

In [5]:
def func_single_exp_monkey_like(x, a, b, c):
    """
    Single exponential decay with the offset placed inside the amplitude: a * (exp(-b * x) + c).

    This is the form marked "as in the paper" in the original code. It differs from the plain additive form
    a * exp(-b * x) + c in that the offset c is scaled by a as well.

    :param x: scalar or numpy array, points at which the function is evaluated.
    :param a: amplitude, multiplies both the exponential and the offset.
    :param b: decay rate; 1 / b is the decay constant in units of x.
    :param c: offset added to the exponential before scaling by a.
    :return: a * (exp(-b * x) + c), same shape as x.
    """
    decay = np.exp(-b * x)
    return a * (decay + c)

### Get data

In [6]:
# Load the spike trains used in this test from a hard-coded absolute path (machine specific).
# allow_pickle=True lets np.load unpickle object arrays; unpickling can execute arbitrary code,
# so only load files from a trusted source.
spikes_trials_30 = np.load('Q:\\Personal\\Irina\\projects\\isttc\\results\\synthetic_data\\test_full_split\\trials_30_runs_1\\spikes_trials_30.npy', 
                           allow_pickle=True)

### Run 

In [33]:
# Parameters of the run. NOTE(review): units are presumably ms (the fitted tau is reported as *_ms) - confirm.
num_lags = 20  # passed as n_lags_ to acf_sttc_trial_avg_v3
bin_size = 50  # passed as lag_shift_; also used to convert the fitted decay constant from lags to time
sttc_dt = 49  # passed as sttc_dt_
signal_len = 100000  # not referenced in the cells shown here

n_trials = 30  # not referenced in the cells shown here
trial_len = num_lags * bin_size  # not referenced in the cells shown here

In [16]:
def get_lag_arrays(spike_train_l_: list, lag_1_idx_: int, lag_2_idx_: int, lag_shift_: int, zero_padding_len_: int):
    """
    Collect, across all trials, the spikes that fall into two lag windows and concatenate them per lag.

    A lag window with index k covers the half-open interval (k * lag_shift_, (k + 1) * lag_shift_], i.e. a spike
    exactly at the window start is excluded and a spike exactly at the window end is included. Spikes taken from
    the trial at position n in the list are shifted by n * zero_padding_len_ before concatenation, so that trials
    do not overlap in the concatenated train.

    :param spike_train_l_: list of spike trains, every element of the list contains spikes from 1 trial (numpy
    array), length of the list is equal to the number of trials. Spike times are realigned (each trial starts at
    time 0).
    :param lag_1_idx_: int, index of the first lag.
    :param lag_2_idx_: int, index of the second lag.
    :param lag_shift_: int, shift for a time lag, also the window length (in time points).
    :param zero_padding_len_: int, offset added per trial position (in time points).
    :return: tuple of two flat lists with the spike times for lag 1 and lag 2 (empty lists when there are no trials).
    """

    def collect_lag(lag_idx: int) -> list:
        """Gather the spikes of one lag window from every trial, offset by trial position, as a flat list."""
        window_start = lag_idx * lag_shift_
        window_end = window_start + lag_shift_
        per_trial = []
        for trial_pos, trial in enumerate(spike_train_l_):
            in_window = trial[(trial > window_start) & (trial <= window_end)]
            per_trial.append(in_window + trial_pos * zero_padding_len_)
        if not per_trial:
            # np.hstack cannot handle an empty sequence
            return []
        return np.hstack(per_trial).tolist()

    return collect_lag(lag_1_idx_), collect_lag(lag_2_idx_)

In [34]:
def acf_sttc_trial_avg_v3(spike_train_l_: list, n_lags_: int, lag_shift_: int, sttc_dt_: int, zero_padding_len_: int,
                          verbose_: bool = True):
    """
    Trial average autocorrelation using STTC. T term is calculated as in sttc_trail_concat but now trials are chunks of
    lag_shift size.

    Every trial is cut into n_lags_ consecutive windows of lag_shift_ time points. For each pair of windows (i, j)
    with i < j the spikes of both windows are collected across trials with get_lag_arrays (trials are offset by
    zero_padding_len_), shifted so that each window starts at 0 and passed to sttc_fixed_2t together with the
    T terms of the two windows. The T term of a window is computed without zero padding: sttc_calculate_t is called
    once per trial on the window's aligned spikes, the first element of each result is summed and the sum is divided
    by n_trials * lag_shift_.

    The T term of a window depends only on the window index, so it is computed once per window and reused for all
    pairs. As a consequence sttc_calculate_t is called (and, with verbose_, can report) once per window and trial
    instead of once per pair.

    :param spike_train_l_: list of spike trains, every element of the list contains spikes from 1 trial, length of the
    list is equal to the number of trials. Spike times are realigned (each trial starts at time 0).
    :param n_lags_: int, number of lags
    :param lag_shift_: int, length of one lag window, also the shift between consecutive lags (in time points).
    :param sttc_dt_: int, dt forwarded to sttc_calculate_t and sttc_fixed_2t.
    :param zero_padding_len_: int, offset between consecutive trials in the concatenated trains (in time points).
    :param verbose_: bool, print a summary line; the flag is also forwarded to sttc_calculate_t and sttc_fixed_2t.
    :return: tuple (acf_matrix, acf_average). acf_matrix is an (n_lags_, n_lags_) array with ones on the diagonal,
    the STTC value of window pair (i, j) at [i, j] for i < j and zeros below the diagonal. acf_average is a 1D array
    of length n_lags_ whose k-th element is the nanmean of the k-th upper diagonal of acf_matrix.
    """
    def extract_lag(spike_train: np.ndarray, lag_idx: int, lag_shift: int) -> np.ndarray:
        """Extract spikes corresponding to a specific lag, window is (start, end]."""
        start = lag_idx * lag_shift
        end = start + lag_shift
        return spike_train[(spike_train > start) & (spike_train <= end)]

    t_term_by_lag = {}

    def lag_t_term(lag_idx):
        """T term of one lag window (no zero padding), computed on first use and cached afterwards."""
        if lag_idx not in t_term_by_lag:
            # align the window to start at 0, T is calculated on [0, lag_shift_]
            aligned_l = [extract_lag(trial, lag_idx, lag_shift_) - lag_shift_ * lag_idx for trial in spike_train_l_]
            abs_time_sum = sum(sttc_calculate_t(spikes, len(spikes), sttc_dt_, 0, lag_shift_, verbose_)[0]
                               for spikes in aligned_l)
            t_term_by_lag[lag_idx] = abs_time_sum / (len(aligned_l) * lag_shift_)
        return t_term_by_lag[lag_idx]

    if verbose_:
        print('Processing {} trials: n lags {}, lag shift {}, sttc dt {}, zero padding len {}'.
              format(len(spike_train_l_), n_lags_, lag_shift_, sttc_dt_, zero_padding_len_))
    acf_matrix = np.zeros((n_lags_, n_lags_))

    for i in np.arange(n_lags_ - 1):
        for j in np.arange(i + 1, n_lags_):  # filling i-th row
            # T terms - without zero padding
            l1_t = lag_t_term(i)
            l2_t = lag_t_term(j)

            # get arrays for sttc - with zero padding
            lag_1_spikes_l, lag_2_spikes_l = get_lag_arrays(spike_train_l_, i, j,
                                                            lag_shift_=lag_shift_, zero_padding_len_=zero_padding_len_)
            # shift both windows so that they start at 0
            l1_aligned = [spike - lag_shift_ * i for spike in lag_1_spikes_l]
            l2_aligned = [spike - lag_shift_ * j for spike in lag_2_spikes_l]
            acf_matrix[i, j] = sttc_fixed_2t(l1_aligned, l2_aligned, sttc_dt_, t_a_=l1_t, t_b_=l2_t,
                                             verbose_=verbose_)

    np.fill_diagonal(acf_matrix, 1)

    # average over all window pairs that are k lags apart (k-th upper diagonal)
    acf_average = np.zeros((n_lags_,))
    for k in range(n_lags_):
        acf_average[k] = np.nanmean(np.diag(acf_matrix, k=k))

    return acf_matrix, acf_average

In [31]:
# Display the loaded per-trial spike arrays for inspection
spikes_trials_30

array([array([247, 271, 279, 291, 294, 334, 390, 395, 412, 419, 421, 428, 433,
              453, 483, 496, 499, 573, 686, 716, 722, 731, 773, 776, 778, 805,
              817, 828, 829, 846, 847, 861, 864, 867, 869, 875, 878, 879, 883,
              895], dtype=int64)                                              ,
       array([ 25,  40,  50,  54,  59,  61,  63, 166, 167, 173, 175, 452, 453,
              461, 462, 463, 474, 505, 523, 528, 532, 534, 537, 544, 545, 551,
              557, 562, 564, 569, 575, 576, 585, 587, 594, 598, 620, 661, 665,
              911], dtype=int64)                                              ,
       array([128, 144, 170, 195, 209, 210, 216, 223, 224, 232, 237, 244, 245,
              250, 251, 259, 265, 280, 288, 303, 304, 331, 357, 358, 381, 392,
              418, 422, 428, 433, 438, 442, 443, 450, 452, 453, 481, 483, 487,
              488, 489, 492, 503, 507, 511, 516, 517, 519, 521, 526, 529, 531,
              533, 537, 542, 544, 550, 553, 561, 5

In [35]:
# Trial-averaged STTC autocorrelation; trials are offset by 250 time points when concatenated
sttc_matrix_trail_avg, sttc_average_trial_avg = acf_sttc_trial_avg_v3(spikes_trials_30,
                                                                   n_lags_=num_lags,
                                                                   lag_shift_=bin_size,
                                                                   zero_padding_len_=250,
                                                                   sttc_dt_=sttc_dt,
                                                                   verbose_=False)
# Fit the exponential over lag indices 0..num_lags-1; the x axis is derived from num_lags so that it always
# matches the length of sttc_average_trial_avg (it was hard-coded to 20 points before)
spike_train_trial_avg_sttc_popt, _ = curve_fit(func_single_exp_monkey_like, np.arange(num_lags, dtype=float),
                                               sttc_average_trial_avg, maxfev=5000)
# popt[1] is the decay rate b per lag; 1 / b is the decay constant in lags, times bin_size converts it to time
spike_train_trial_avg_sttc_tau_ms = (1/spike_train_trial_avg_sttc_popt[1]) * bin_size
print(spike_train_trial_avg_sttc_tau_ms)

66.55308493335758


In [17]:
# Display the loaded per-trial spike arrays for inspection
spikes_trials_30

array([array([247, 271, 279, 291, 294, 334, 390, 395, 412, 419, 421, 428, 433,
              453, 483, 496, 499, 573, 686, 716, 722, 731, 773, 776, 778, 805,
              817, 828, 829, 846, 847, 861, 864, 867, 869, 875, 878, 879, 883,
              895], dtype=int64)                                              ,
       array([ 25,  40,  50,  54,  59,  61,  63, 166, 167, 173, 175, 452, 453,
              461, 462, 463, 474, 505, 523, 528, 532, 534, 537, 544, 545, 551,
              557, 562, 564, 569, 575, 576, 585, 587, 594, 598, 620, 661, 665,
              911], dtype=int64)                                              ,
       array([128, 144, 170, 195, 209, 210, 216, 223, 224, 232, 237, 244, 245,
              250, 251, 259, 265, 280, 288, 303, 304, 331, 357, 358, 381, 392,
              418, 422, 428, 433, 438, 442, 443, 450, 452, 453, 481, 483, 487,
              488, 489, 492, 503, 507, 511, 516, 517, 519, 521, 526, 529, 531,
              533, 537, 542, 544, 550, 553, 561, 5