In [2]:
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import cm
from tqdm import tqdm_notebook as tqdm
%matplotlib inline

In [3]:
plt.rcParams['figure.figsize'] = (12,9)
plt.rcParams.update({
    k: 'xx-large' for k in plt.rcParams 
    if (
        k.endswith('labelsize') or
        k.endswith('fontsize')
    )
})

In [4]:
def brute_force(x, y, stim_times, winsize):
    Z, X, Y = [], [], []
    for s in stim_times:
        xx, yy = x - s, y - s
        
        pre = xx[xx <= 0]
        if len(pre) == 0:
            continue
            
        post = xx[xx > 0]
        if len(post) == 0:
            continue
            
        Z.append(pre.max())
        X.append(post.min())
        print(yy)
        print((yy >= 0) & (yy < winsize))
        Y.append(np.sum((yy >= 0) & (yy < winsize)))
    return np.array(Z), np.array(X), np.array(Y)

In [5]:
def calculate_regressors(x, y, stim_times, winsize):
    sub_x_p = np.subtract.outer(x, stim_times)
    sub_x_n = sub_x_p.copy()
    mask_x = sub_x_p > 0

    sub_x_p[~mask_x] = np.inf
    sub_x_n[mask_x] = -np.inf

    Z = np.nanmax(sub_x_n, axis=0)

    X = np.nanmin(sub_x_p, axis=0)

    sub_y = np.subtract.outer(y, stim_times)
    mask_y = (sub_y >= 0) & (sub_y < winsize)
    sub_y[~mask_y] = np.nan
    Y = np.sum(~np.isnan(sub_y), axis=0)
    mask = np.isfinite(Z) & np.isfinite(X)
    return Z[mask], X[mask], Y[mask]

In [6]:
def calculate_regressors_binary(x, y, stim_times, winsize):
    stim_times = stim_times.astype(float)
    
    src_x = np.searchsorted(x, stim_times, side='right')
    
    remove_idxs, = np.where((src_x==len(x)) | (src_x==0))
    src_x = np.delete(src_x, remove_idxs)
    stim_times = np.delete(stim_times, remove_idxs)
    Z = x[src_x-1] - stim_times
    X = x[src_x] - stim_times
    
    stim_win = np.insert(stim_times, np.arange(len(stim_times)) + 1, stim_times + winsize)
    src_y = np.searchsorted(y, stim_win, side='left')
    cnt_y = np.diff(src_y.reshape((int(len(src_y) / 2), 2)))
    Y = cnt_y.flatten()
    return Z, X, Y

In [46]:
s = np.array([1, 2, 3, 4, 5,  6]).astype(float)

x = np.array([0.1, 0.2, 1.2, 1.3, 2.1, 2.4, 4.4, 4.5, 5, 5.05, 6.01])

y = np.array([0.1, 0.2, 1.2, 1.22, 1.23, 1.3, 2.1, 2.2, 2.4, 3, 3.3, 5, 5.1, 6.3])

In [8]:
def prune(a, ref):
    b = np.concatenate(([False], np.diff(a) < ref))
    c = np.concatenate(([False], np.diff(b.astype(int)) > 0))
    d = a[~c]
    if any(np.diff(a) < ref):
        d = prune(d, ref)
    return d


def generate_stim_times(stim_rate, stim_isi_min, stop_time):
    stim_times = np.sort(np.random.uniform(
        0, stop_time, stim_rate * stop_time))
    return prune(stim_times, stim_isi_min)


def generate_neurons(stim_times, make_post=False, **p):
    n_stim = len(stim_times)
    idxs = np.random.permutation(np.arange(n_stim).astype(int))
    n_stim_spikes = int(n_stim * p['stim_hit_chance'])
    idxs_stim_spikes = idxs[:n_stim_spikes]

    spikes = np.sort(np.concatenate([
        stim_times[idxs_stim_spikes] + p['stim_latency'],
        np.random.uniform(0, p['stop_time'], p['pre_rate'] * p['stop_time'])
    ]))
    pre_spikes = prune(spikes, p['refractory'])
    n_pre_spikes = len(pre_spikes)
    if make_post:
        n_post_spikes = int(n_pre_spikes * p['pre_hit_chance'])
        idxs_post_spikes = np.random.permutation(np.arange(n_pre_spikes).astype(int))[:n_post_spikes]

        post_spikes = np.sort(np.concatenate([
            pre_spikes[idxs_post_spikes] + p['latency'],
            np.random.uniform(0, p['stop_time'], int(p['post_rate'] * p['stop_time']))
        ]))
        post_spikes = prune(post_spikes, p['refractory'])
    
        return pre_spikes, post_spikes
    else:
        return pre_spikes

In [9]:
stim_params = {
    'stop_time': 1000, # seconds
    'stim_rate': 30, # rate of stimulation (gets reduced by pruning for minimum inter stimulus interval)
    'stim_isi_min': 50e-3, # minimum inter stimulus interval
}
neuron_params = {
    'refractory': 4e-3, # 4 ms
    'latency': 6e-3, # post response delay
    'pre_hit_chance': .5, # fraction of spikes that are driven by the presynaptic neuron
    'post_rate': 5, # Hz
    'pre_rate': 5, # base rate
    'stim_hit_chance': .8, # fraction of spikes that are driven by the stimulation
    'stim_latency': 5e-4, # latency from stim to pre response
    'stop_time': stim_params['stop_time'],
}

s = generate_stim_times(**stim_params) 

x, y = generate_neurons(s, make_post=True, **neuron_params)

In [47]:
# %%timeit 
Zb, Xb, Yb = brute_force(x,y,s,0.3)
# Zb, Xb, Yb

[-0.9  -0.8   0.2   0.22  0.23  0.3   1.1   1.2   1.4   2.    2.3   4.
  4.1   5.3 ]
[False False  True  True  True False False False False False False False
 False False]
[-1.9  -1.8  -0.8  -0.78 -0.77 -0.7   0.1   0.2   0.4   1.    1.3   3.
  3.1   4.3 ]
[False False False False False False  True  True False False False False
 False False]
[-2.9  -2.8  -1.8  -1.78 -1.77 -1.7  -0.9  -0.8  -0.6   0.    0.3   2.
  2.1   3.3 ]
[False False False False False False False False False  True  True False
 False False]
[-3.9  -3.8  -2.8  -2.78 -2.77 -2.7  -1.9  -1.8  -1.6  -1.   -0.7   1.
  1.1   2.3 ]
[False False False False False False False False False False False False
 False False]
[-4.9  -4.8  -3.8  -3.78 -3.77 -3.7  -2.9  -2.8  -2.6  -2.   -1.7   0.
  0.1   1.3 ]
[False False False False False False False False False False False  True
  True False]
[-5.9  -5.8  -4.8  -4.78 -4.77 -4.7  -3.9  -3.8  -3.6  -3.   -2.7  -1.
 -0.9   0.3 ]
[False False False False False False False False False 

In [48]:
# %%timeit 
Zv, Xv, Yv = calculate_regressors(x,y,s,0.3)
# Zv, Xv, Yv

In [49]:
# %%timeit 
Zv2, Xv2, Yv2 = calculate_regressors_binary(x,y,s,0.3)

In [50]:
np.array_equal(Zb, Zv), np.array_equal(Xb, Xv), np.array_equal(Yb, Yv)

(True, True, True)

In [51]:
np.array_equal(Zb, Zv2), np.array_equal(Xb, Xv2), np.array_equal(Yb, Yv2)

(True, True, False)

In [52]:
Yb

array([3, 2, 2, 0, 2, 1])

In [53]:
Yv

array([3, 2, 2, 0, 2, 1])

In [54]:
Xb

array([0.2 , 0.1 , 1.4 , 0.4 , 0.05, 0.01])

In [55]:
Yv2

array([3, 2, 1, 0, 2, 0])

## 

In [56]:
Zv2

array([-0.8 , -0.7 , -0.6 , -1.6 ,  0.  , -0.95])

In [57]:
Xv2

array([0.2 , 0.1 , 1.4 , 0.4 , 0.05, 0.01])