In [None]:
# Enable interactive plots with backend 'notebook'
%matplotlib notebook

# print date and time of script execution
import datetime
print("\nNotebook executed at at {} in following directory:".format(datetime.datetime.now()))
%cd /home/luye/workspace/bgcellmodels/GilliesWillshaw/

In [None]:
import neo

import elephant.spike_train_generation as stg
import elephant.spike_train_dissimilarity as stds
import elephant.spike_train_surrogates as surr

import quantities as pq
from quantities import ms, s, Hz

import numpy as np
from matplotlib import pyplot as plt

In [None]:
def plot_rastergram(spiketrain_list, ax=None):
    """
    Plot a simple rastergram: spike train ``i`` is drawn as a row of dots at y = i.

    Parameters
    ----------
    spiketrain_list : sequence of neo.SpikeTrain
        Spike trains to plot. Each must support ``.rescale(ms)``.
    ax : matplotlib Axes, optional
        Axes to draw on. If None (default), a new figure is created and shown.
    """
    created_here = ax is None
    if created_here:
        _, ax = plt.subplots()

    for i, spiketrain in enumerate(spiketrain_list):
        # Convert to ms, then drop units so matplotlib receives plain floats
        # (np.ones_like on a Quantity would give y-values carrying ms units).
        t_ms = spiketrain.rescale(ms).magnitude
        ax.plot(t_ms, np.full(len(t_ms), i), 'k.', markersize=2)

    ax.axis('tight')

    ax.set_xlabel('Time (ms)', fontsize=16)
    ax.set_ylabel('Spike Train Index', fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=14)

    # Only show the figure if we created it; a caller-supplied Axes is the caller's to show.
    if created_here:
        plt.show()
    
# Color cycle for line plots; entries are either hex strings or RGB triples with
# components in [0, 1], both of which matplotlib accepts as a `color` argument.
# The plotting cells pick colors by index, wrapping around with `% len(allcolors)`.
allcolors = [
    '#7742f4',           # dark purple
    [0.90, 0.76, 0.00],  # ochre
    [0.42, 0.83, 0.59],  # soft pastel green
    [0.90, 0.32, 0.00],  # pastel red brick
    [0.90, 0.59, 0.00],  # orange-brown
    '#f442c5',           # pink
    '#c2f442',           # lime
    [1.00, 0.85, 0.00],  # hard yellow
    [0.33, 0.67, 0.47],  # dark pastel green
    [1.00, 0.38, 0.60],  # rose pink
    [0.57, 0.67, 0.33],  # olive green
    [0.50, 0.20, 0.00],  # dark brown
    [0.71, 0.82, 0.41],  # light yellow-green
    [0.00, 0.20, 0.50],  # navy blue
]

In [None]:
# Make test data (spike trains).
# Seed the global NumPy RNG so that Restart & Run All reproduces the same
# analysis. This assumes elephant's spike-train generators and surrogate
# functions draw from it; verify for the installed elephant version.
RNG_SEED = 42
np.random.seed(RNG_SEED)

st21 = stg.homogeneous_poisson_process(50*Hz, t_start=0*ms, t_stop=2000*ms)
st22 = stg.homogeneous_poisson_process(40*Hz, t_start=0*ms, t_stop=2000*ms)
st23 = stg.homogeneous_poisson_process(30*Hz, t_start=0*ms, t_stop=2000*ms)

all_test_data = [st21, st22, st23]

plot_rastergram(all_test_data)

# Sensitivity to jitter (Victor-Purpura)

In [None]:
st_target = st22

# Values of 1/q for the Victor-Purpura cost parameter q, in ms
# (shifting a spike by dt costs q*|dt|; deleting or inserting a spike costs 1).
all_shift_cost_ms = np.array([5.0, 10.0, 15.0, 20.0, 50.0, 100.0, 150.0, 200.0, 500.0, 1000.0])
all_shift_ms = np.arange(0.5, 50, 0.5)

# Draw one dithered surrogate per jitter amplitude and share it across all cost
# values. Every curve is then computed on the same trains, and the rastergram
# below shows exactly the trains the distances were computed on.
# Each spike is shifted randomly within [-max_shift, +max_shift].
all_sts = [
    surr.dither_spikes(st_target, dither=max_shift_ms * pq.ms)[0]
    for max_shift_ms in all_shift_ms
]

lines = []
fig, ax = plt.subplots()

for j, shift_cost_ms in enumerate(all_shift_cost_ms):

    shift_cost_hz = 1.0 / (shift_cost_ms * 1e-3) * pq.Hz  # q = 1 / shift_cost

    all_vp_dist = np.zeros(len(all_sts))
    for i, st_shifted in enumerate(all_sts):
        # Calculate spike train distance with current shift cost
        dist_mat = stds.victor_purpura_dist(
                        [st_target, st_shifted],
                        q=shift_cost_hz,
                        algorithm='fast')
        all_vp_dist[i] = dist_mat[0, 1]

    line_color = allcolors[j % len(allcolors)]
    line, = ax.plot(all_shift_ms, all_vp_dist, color=line_color, label=str(shift_cost_ms))
    lines.append(line)

ax.set_xlabel('Max spike time shift (ms)')
ax.set_ylabel('Victor-Purpura distance')
fig.suptitle('Spike distance for different values of shift cost parameter')
ax.legend(lines, [l.get_label() for l in lines], title='1/q (ms)')

plot_rastergram(all_sts)

# Sensitivity to deletions (Victor-Purpura)

In [None]:
st_target = st22
st_data = np.array(st_target)  # spike times as plain floats, in st_target.units

# Values of 1/q for the Victor-Purpura cost parameter q, in ms
all_shift_cost_ms = np.array([5.0, 10.0, 15.0, 20.0, 50.0, 100.0, 150.0, 200.0, 500.0, 1000.0])

# Up to 49 deletions, but never more than the number of spikes available
# (np.random.choice with replace=False raises otherwise).
max_deletions = min(49, len(st_target))
all_num_deletions = range(1, max_deletions + 1)
all_del_indices = np.random.choice(len(st_target), max_deletions, replace=False)
print(all_del_indices)

# The deletion surrogates do not depend on the cost parameter, so build them once.
# Train k drops the first k of the same random indices, so the deletions are nested.
all_sts = [
    neo.SpikeTrain(
        np.delete(st_data, all_del_indices[:num_deletions]),
        t_start=st_target.t_start,
        t_stop=st_target.t_stop,
        units=st_target.units)
    for num_deletions in all_num_deletions
]

lines = []
fig, ax = plt.subplots()

for j, shift_cost_ms in enumerate(all_shift_cost_ms):

    shift_cost_hz = 1.0 / (shift_cost_ms * 1e-3) * pq.Hz  # q = 1 / shift_cost

    # Must be a float array: np.zeros_like(range(...)) is integer-typed and
    # would silently truncate the distances.
    all_vp_dist = np.zeros(len(all_sts))
    for i, st_deleted in enumerate(all_sts):
        # Calculate spike train distance with current shift cost.
        # For pure deletions this should equal the number of deleted spikes, for any q.
        dist_mat = stds.victor_purpura_dist(
                        [st_target, st_deleted],
                        q=shift_cost_hz,
                        algorithm='fast')
        all_vp_dist[i] = dist_mat[0, 1]

    line_color = allcolors[j % len(allcolors)]
    line, = ax.plot(all_num_deletions, all_vp_dist, color=line_color, label=str(shift_cost_ms))
    lines.append(line)

ax.set_xlabel('Number of spike deletions')
ax.set_ylabel('Victor-Purpura distance')
fig.suptitle('Spike distance for different values of shift cost parameter')
ax.legend(lines, [l.get_label() for l in lines], title='1/q (ms)')

plot_rastergram(all_sts)

# Sensitivity to insertions

In [None]:
# TODO: measure the sensitivity of the Victor-Purpura distance to spike insertions

# Sensitivity to jitter (Kreuz ISI)

In [None]:
# Quick check of the plain-float inputs used to build pyspike spike trains.
# Only the last expression of a cell is displayed, so show both values together
# (and only the first few spike times rather than the whole array).
st22.as_array()[:10], float(st22.t_start)

In [None]:
import pyspike


def _to_pyspike(st, units=pq.ms):
    """
    Convert a neo.SpikeTrain to a pyspike.SpikeTrain.

    Spike times and edges are rescaled to ``units`` and passed as plain floats,
    so pyspike never sees quantities objects or mixed units.
    """
    return pyspike.SpikeTrain(
        st.rescale(units).magnitude,
        [float(st.t_start.rescale(units).magnitude),
         float(st.t_stop.rescale(units).magnitude)],
        is_sorted=False)  # let pyspike sort: cheap, and correct even if input is unsorted


st_target = st22
spk_target = _to_pyspike(st_target)

all_shift_ms = np.arange(0.5, 50, 0.5)
all_dist = np.zeros_like(all_shift_ms)

all_sts = []

for i, max_shift_ms in enumerate(all_shift_ms):

    # Each spike is shifted randomly within [-max_shift, +max_shift]
    st_shifted = surr.dither_spikes(st_target, dither=max_shift_ms * pq.ms)[0]
    all_sts.append(st_shifted)

    # Calculate Kreuz ISI distance
    all_dist[i] = pyspike.isi_distance(spk_target, _to_pyspike(st_shifted))

fig, ax = plt.subplots()
line, = ax.plot(all_shift_ms, all_dist, color=allcolors[0], label='Kreuz ISI')
lines = [line]

ax.set_xlabel('Max spike time shift (ms)')
ax.set_ylabel('ISI distance')
ax.legend(lines, [l.get_label() for l in lines])

plot_rastergram(all_sts)