In [None]:
import datajoint as dj
import numpy as np
import matplotlib.pyplot as plt
import plotly
import plotly.graph_objs as go
from ibl_pipeline.plotting import plotting_utils_ephys as putils
from scipy.signal import convolve, gaussian
from ibl_pipeline.plotting import ephys as ephys_plotting

In [None]:
# DataJoint virtual modules for the ephys and behavior schemas of the test database.
ephys, behavior = (
    dj.create_virtual_module('ephys', 'test_ibl_ephys'),
    dj.create_virtual_module('behavior', 'test_ibl_behavior'),
)

In [None]:
# old, only compute psth
def compute_psth(trials, trial_type, align_event, bin_size=0.025,
                 smoothing=0.025, x_lim=[-1, 1], as_dict=True):
    """Compute a smoothed, trial-averaged firing-rate PSTH (no error band).

    All spikes of all trials are histogrammed on a grid of ``bin_size``-wide
    bins that extends ``5 * ceil(smoothing / bin_size)`` bins beyond ``x_lim``
    on both sides. The counts are converted to a mean rate per trial, smoothed
    with a normalized Gaussian kernel, and trimmed back to ``x_lim`` so that
    smoothing edge effects fall outside the returned window.

    Args:
        trials: query object whose ``fetch('trial_spike_times')`` returns one
            sequence of spike times per trial.
        trial_type: one of 'left', 'right', 'all', 'incorrect'; selects the
            trace color and name.
        align_event: unused; kept for interface compatibility.
        bin_size: bin width, in the units of the spike times.
        smoothing: standard deviation of the Gaussian kernel (same units);
            values <= 0 disable smoothing.
        x_lim: [start, end] of the returned time window (not modified).
        as_dict: if True return a ``go.Scatter`` trace, otherwise return
            ``(time_bins, psth)`` as two lists.

    Raises:
        NameError: if ``trial_type`` is not recognized.
        ValueError: if ``trials`` yields no trials.
    """
    colors = {'left': 'green', 'right': 'blue', 'all': 'black', 'incorrect': 'red'}
    if trial_type not in colors:
        raise NameError('Invalid type name')
    color = colors[trial_type]

    # Pad the window so smoothing has no boundary effects inside x_lim.
    n_offset = 5 * int(np.ceil(max(smoothing, 0) / bin_size))
    n_bins_pre = int(np.ceil(np.negative(x_lim[0]) / bin_size)) + n_offset
    n_bins_post = int(np.ceil(x_lim[1] / bin_size)) + n_offset
    n_bins = n_bins_pre + n_bins_post
    # Explicit edges make every bin exactly bin_size wide, which the rate
    # normalization below relies on (range=x_lim split into n_bins bins does not).
    bins = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size

    spk_times = trials.fetch('trial_spike_times')
    n_trials = len(spk_times)
    if n_trials == 0:
        raise ValueError('No trials to compute a PSTH from')
    all_spikes = np.concatenate([np.ravel(s) for s in spk_times])

    counts, _ = np.histogram(all_spikes, bins=bins)
    mean_fr = counts / (n_trials * bin_size)

    if smoothing > 0:
        # Odd kernel length so the kernel has a single central sample.
        w = n_bins - 1 if n_bins % 2 == 0 else n_bins
        window = gaussian(w, std=smoothing / bin_size)
        window /= np.sum(window)
        psth = convolve(mean_fr, window, mode='same', method='auto')
    else:
        psth = mean_fr

    # Drop the padding bins again; slice(n, n_bins - n) also works for n == 0.
    keep = slice(n_offset, n_bins - n_offset)
    time_bins = ((bins[:-1] + bins[1:]) / 2)[keep]
    psth = psth[keep]

    if not as_dict:
        return list(time_bins), list(psth)
    return go.Scatter(
        x=list(time_bins),
        y=list(psth),
        mode='lines',
        marker=dict(
            size=6,
            color=color),
        name='{} trials'.format(trial_type)
    )

In [None]:
# new, psth with error bars
def compute_psth_with_errorbar(trials, trial_type, align_event, bin_size=0.025,
                               smoothing=0.025, x_lim=[-1, 1], as_dict=True):
    """Compute a trial-averaged firing-rate PSTH with a +/- s.e.m. band.

    Each trial's spikes are binned on a common grid of ``bin_size``-wide bins
    extending ``5 * ceil(smoothing / bin_size)`` bins beyond ``x_lim`` on each
    side. Each trial is smoothed with a normalized Gaussian kernel and converted
    to a rate (counts / bin_size). The mean and s.e.m. are then taken across
    ALL trials, including trials with no spikes in the window. Finally the
    padding bins are dropped.

    Args:
        trials: query object whose ``fetch('trial_spike_times')`` returns one
            sequence of spike times per trial.
        trial_type: one of 'left', 'right', 'all', 'incorrect'; selects colors.
        align_event: unused; kept for interface compatibility.
        bin_size: bin width, in the units of the spike times.
        smoothing: standard deviation of the Gaussian kernel (same units);
            values <= 0 disable smoothing.
        x_lim: [start, end] of the returned time window (not modified).
        as_dict: if True return plotly traces ``[lower_bound, mean, upper_bound]``;
            otherwise return ``(time_bins, mean, mean + sem, mean - sem)`` as lists.

    Raises:
        NameError: if ``trial_type`` is not recognized.
        ValueError: if ``trials`` yields no trials.
    """
    colors = {
        'left': ('green', 'rgba(0, 255, 0, 0.2)'),
        'right': ('blue', 'rgba(0, 0, 255, 0.2)'),
        'all': ('black', 'rgba(0, 0, 0, 0.2)'),
        'incorrect': ('red', 'rgba(255, 0, 0, 0.2)'),
    }
    if trial_type not in colors:
        raise NameError('Invalid type name')
    color, err_color = colors[trial_type]

    # set up bins, padded so smoothing has no boundary effects inside x_lim
    n_offset = 5 * int(np.ceil(max(smoothing, 0) / bin_size))
    n_bins_pre = int(np.ceil(np.negative(x_lim[0]) / bin_size)) + n_offset
    n_bins_post = int(np.ceil(x_lim[1] / bin_size)) + n_offset
    n_bins = n_bins_pre + n_bins_post
    bins = np.arange(-n_bins_pre, n_bins_post + 1) * bin_size  # bin edges

    spk_times = trials.fetch('trial_spike_times')
    n_trials = len(spk_times)
    if n_trials == 0:
        raise ValueError('No trials to compute a PSTH from')

    # Flatten the spikes, remembering the trial index of each one. np.repeat
    # keeps an integer dtype even when no trial has any spike.
    per_trial = [np.ravel(s) for s in spk_times]
    trial_ids_flat = np.repeat(np.arange(n_trials), [s.size for s in per_trial])
    spk_times_flat = np.concatenate(per_trial)

    # keep only spikes inside the padded window
    in_window = (spk_times_flat >= bins[0]) & (spk_times_flat <= bins[-1])
    spikes = spk_times_flat[in_window]
    spike_trials = trial_ids_flat[in_window]

    # Count spikes per (trial, bin). bin_num = bins.size is one more than the
    # number of bins: a spike lying exactly on the last edge lands in that
    # overflow column, which is dropped afterwards.
    bin_num = bins.size
    bin_id = np.floor((spikes - bins[0]) / bin_size).astype(np.int64)
    ind2d = np.ravel_multi_index((spike_trials, bin_id), dims=(n_trials, bin_num))
    spike_counts = np.bincount(ind2d, minlength=n_trials * bin_num).reshape(n_trials, bin_num)
    binned_spikes = spike_counts[:, :-1].astype(float)

    # smooth each trial with a normalized Gaussian kernel
    if smoothing > 0:
        w = n_bins - 1 if n_bins % 2 == 0 else n_bins  # odd length: single central sample
        window = gaussian(w, std=smoothing / bin_size)
        window /= np.sum(window)
        binned_spikes = np.array([
            convolve(row, window, mode='same', method='auto') for row in binned_spikes])

    # counts per bin -> firing rate (spikes per unit time)
    rates = binned_spikes / bin_size
    mean_psth = rates.mean(axis=0)
    sem_psth = rates.std(axis=0) / np.sqrt(n_trials)

    # drop the padding bins; slice(n, n_bins - n) also works for n == 0
    keep = slice(n_offset, n_bins - n_offset)
    time_bins = ((bins[:-1] + bins[1:]) / 2)[keep]  # bin centers
    mean_psth = mean_psth[keep]
    sem_psth = sem_psth[keep]
    upper_psth = mean_psth + sem_psth
    lower_psth = mean_psth - sem_psth

    if not as_dict:
        return list(time_bins), list(mean_psth), list(upper_psth), list(lower_psth)

    # Order matters for fill='tonexty': lower bound, then mean, then upper bound.
    upper_bound = go.Scatter(
        x=list(time_bins),
        y=list(upper_psth),
        mode='lines',
        marker=dict(color="#444"),
        fillcolor=err_color,
        line=dict(width=0),
        fill='tonexty',
        showlegend=False,
    )
    psth = go.Scatter(
        x=list(time_bins),
        y=list(mean_psth),
        mode='lines',
        marker=dict(
            size=6,
            color=color),
        fill='tonexty',
        fillcolor=err_color,
        name='{} trials, mean +/- s.e.m'.format(trial_type)
    )
    lower_bound = go.Scatter(
        x=list(time_bins),
        y=list(lower_psth),
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        showlegend=False,
    )
    return [lower_bound, psth, upper_bound]

In [None]:
from ibl_pipeline import subject

# Primary keys of the DefaultCluster entries restricted to subject CSHL047;
# the cells below analyse the cluster at index 25 of this list.
keys = (
    ephys.DefaultCluster
    & (subject.Subject & 'subject_nickname="CSHL047"')
).fetch('KEY')
key = keys[25]

In [None]:
x_lim = [-1, 1]
cluster = ephys.DefaultCluster & key
event = (ephys.Event & 'event="stim on"').fetch1('KEY')

# Trials of this cluster joined with their aligned spike times, restricted to
# trial_duration < 5 (units presumably seconds — confirm), non-"No Go"
# responses, and the selected event.
trials_all = (
    (behavior.TrialSet.Trial * ephys.AlignedTrialSpikes & cluster).proj(
        'trial_start_time', 'trial_stim_on_time', 'trial_response_time', 'trial_feedback_time',
        'trial_response_choice', 'trial_spike_times',
        trial_duration='trial_end_time-trial_start_time',
        trial_signed_contrast='trial_stim_contrast_right - trial_stim_contrast_left')
    & 'trial_duration < 5'
    & 'trial_response_choice!="No Go"'
    & event
)

trials_left = trials_all & 'trial_response_choice="CW"' & 'trial_signed_contrast < 0'
trials_right = trials_all & 'trial_response_choice="CCW"' & 'trial_signed_contrast > 0'
trials_incorrect = trials_all - trials_right.proj() - trials_left.proj()

align_event = event['event']

# One PSTH (three traces) per non-empty trial group, then always one for all trials.
data = []
for group_trials, group_type in ((trials_left, 'left'),
                                 (trials_right, 'right'),
                                 (trials_incorrect, 'incorrect')):
    if len(group_trials):
        data += putils.compute_psth_with_errorbar(group_trials, group_type, align_event)
data += putils.compute_psth_with_errorbar(trials_all, 'all', align_event)

layout = go.Layout(
    width=700,
    height=370,
    margin=go.layout.Margin(l=50, r=30, b=40, t=80, pad=0),
    title=dict(text='PSTH, aligned to {} time'.format(align_event), x=0.17, y=0.87),
    xaxis=dict(title='Time (sec)', range=x_lim, showgrid=False),
    yaxis=dict(title='Firing rate (spks/sec)', showgrid=False),
)

fig = go.Figure(data=data, layout=layout)
plotly.offline.iplot(fig)

In [None]:
import json

# Serialize before opening the file so a serialization error cannot leave a
# truncated file behind; the context manager guarantees the handle is closed.
s = json.dumps(fig.to_plotly_json())
with open("psth_combined_with_errorbars.json", "w") as f:
    f.write(s)

# Step through the PSTH computation cell by cell for debugging

In [None]:
# Parameters for the step-by-step PSTH computation below.
bin_size, smoothing, x_lim = 0.025, 0.025, [-1, 1]

In [None]:
spk_times = trials_all.fetch('trial_spike_times')
# Trial index of each spike, aligned element-wise with spk_times_flat.
# np.repeat keeps an integer dtype even when no trial has spikes; hstack of
# empty Python lists would produce floats, which ravel_multi_index rejects.
trial_ids_flat = np.repeat(np.arange(len(spk_times)),
                           [len(spk_time) for spk_time in spk_times])
spk_times_flat = np.hstack(spk_times)

In [None]:
# Set up bins. The requested window is padded on each side by five Gaussian
# standard deviations (rounded up to whole bins) so smoothing has no boundary
# effects inside x_lim.
n_offset = 5 * int(np.ceil(smoothing / bin_size))
n_bins_pre, n_bins_post = (
    int(np.ceil(np.negative(x_lim[0]) / bin_size)) + n_offset,
    int(np.ceil(x_lim[1] / bin_size)) + n_offset,
)
n_bins = n_bins_pre + n_bins_post

# bin edges: n_bins + 1 values from -n_bins_pre * bin_size to n_bins_post * bin_size
bins = bin_size * np.arange(-n_bins_pre, n_bins_post + 1)

In [None]:
# Keep only the spikes that fall inside the padded window [bins[0], bins[-1]].
rel_idxs = (spk_times_flat >= bins[0]) & (spk_times_flat <= bins[-1])
filtered_spike_times_flat, filtered_trial_ids_flat = (
    spk_times_flat[rel_idxs], trial_ids_flat[rel_idxs])

In [None]:
# Bin index of each spike. bin_num = bins.size is one more than the number of
# bins: a spike sitting exactly on the last edge lands in that overflow column,
# which is dropped at the end of this cell.
bin_id = np.floor((filtered_spike_times_flat - bins[0]) / bin_size).astype(np.int64)

# Count every fetched trial, including trials with no spike inside the window.
# Taking np.unique over the spiking trials would silently drop such trials and
# inflate the across-trial mean.
trial_num = len(spk_times)
trial_id = filtered_trial_ids_flat.astype(np.int64)

bin_num = bins.size
ind2d = np.ravel_multi_index((trial_id, bin_id), dims=(trial_num, bin_num))

# spike counts of each trial (rows) and each bin (columns)
spike_counts = np.bincount(ind2d, minlength=bin_num * trial_num).reshape(trial_num, bin_num)

binned_spikes = spike_counts[:, :-1]

In [None]:
# Convolution: smooth each trial's row with a normalized Gaussian kernel
# (kernel length forced to be odd, so it has a single central sample).
w = n_bins if n_bins % 2 else n_bins - 1
window = gaussian(w, std=smoothing / bin_size)
window = window / window.sum()

binned_spikes_conv = np.zeros([trial_num, bin_num - 1])
for j, trial_counts in enumerate(binned_spikes):
    binned_spikes_conv[j] = convolve(trial_counts, window, mode='same', method='auto')

In [None]:
mean_spikes = np.mean(binned_spikes_conv, axis=0)
# Standard error of the mean across trials (rows), despite the variable name.
# The divisor is the number of trials; mean_spikes.shape[0] would be the number
# of time bins.
std_spikes = np.std(binned_spikes_conv, axis=0) / np.sqrt(binned_spikes_conv.shape[0])
time = (bins[:-1] + bins[1:]) / 2  # bin centers

In [None]:
plt.plot(time, mean_spikes)
plt.fill_between(time, mean_spikes + std_spikes, mean_spikes - std_spikes, alpha=0.5)
plt.xlabel('Time (sec)')
plt.ylabel('Smoothed spike count per bin')
plt.title('PSTH, all trials (step-by-step computation)');

In [None]:
# Display the padded bin edges computed in the set-up cell, for inspection.
bins

In [None]:
# Compare to np.histogram using the SAME bin edges as the manual binning above.
# range=x_lim with bins=n_bins would produce bins of a different width that
# also omit the padding. Note that hist sums over all trials, whereas
# mean_spikes is a per-trial mean.
hist, time = np.histogram(spk_times_flat, bins=bins)

time = (time[:-1] + time[1:]) / 2  # bin centers

In [None]:
plt.plot(time, hist)
plt.xlabel('Time (sec)')
plt.ylabel('Spike count per bin (all trials)')
plt.title('np.histogram of all spike times');