# Notebook for exploring claustrum data

### install and import essential packages and data

In [1]:
#!pip install ipywidgets
#!jupyter nbextension enable --py widgetsnbextension

import scipy as sp
import scipy.io
import scipy.stats
import os
import numpy as np
import pandas as pd

from collections import Iterable
import matplotlib.pylab as mpl
%load_ext autoreload

mat = sp.io.loadmat('Log_Claustrum4.mat')
mat2 = sp.io.loadmat('Opto_log_Claustrum4.mat')
log = mat['fullLog']
log_df = pd.DataFrame(log, columns=['mouse_name', 'date','block_type', 'trial_type', 'touch_stimulus',
                                    'vis_stimulus', 'response','correct', 'trial_num', 'stim_onset', 'stim_offset', 
                                    'licks_right', 'licks_left', 'spike_times', 'cluster_name' ])
opto_log = mat2['fullOptoTable']
opto_log_df = pd.DataFrame(opto_log, columns=['mouse_name', 'date', 'cluster_name', 'opto_stim_onsets',
                                              'opto_stim_offsets'])

In [2]:
log_df

Unnamed: 0,mouse_name,date,block_type,trial_type,touch_stimulus,vis_stimulus,response,correct,trial_num,stim_onset,stim_offset,licks_right,licks_left,spike_times,cluster_name
0,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.4044666666666667], [1.0637666666666667], [...",[TT1clst1]
1,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.15036666666666668], [0.26516666666666666],...",[TT2clst1]
2,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.10716666666666667], [0.22166666666666668],...",[TT2clst2]
3,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.013466666666666667], [0.01686666666666667]...",[TT2clst3]
4,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.15036666666666668], [0.3058666666666667], ...",[TT2clst4]
5,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.10106666666666667], [0.11446666666666667],...",[TT2clst5]
6,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.15036666666666668], [0.3058666666666667], ...",[TT2clst6]
7,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],[],[TT3clst1]
8,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.08076666666666667], [0.20926666666666668],...",[TT4clst1]
9,[Claustrum4],[06-05-17],[Whisker],[Stim_Som_NoCue],[SineAmp0p875Freq20Cyc3],[Amp0p5Dur0p150],[[0]],[[0]],[[[[3]]]],[[[[2.43213333]]]],[[[[2.5796]]]],[[[]]],[[[]]],"[[0.18666666666666668], [0.5661666666666667], ...",[TT4clst2]


In [4]:
import scipy as sp
import scipy.io
import scipy.stats
import os
import numpy as np
import pandas as pd
import glob
import csv
from tqdm import tnrange, tqdm_notebook
from collections import Iterable
import matplotlib.pylab as mpl

%load_ext autoreload

task_data_files = glob.glob("Log_Claustrum*")
opto_data_files = glob.glob("Opto_log_Claustrum*")
opto_spike_files = glob.glob("optoSpikes_log_Claustrum*")
opto_wave_files = glob.glob("waveform_log_Claustrum*")

column_names1 =['mouse_name', 'date','block_type', 'trial_type', 'touch_stimulus',
                'vis_stimulus', 'response','correct', 'trial_num', 'stim_onset', 'stim_offset', 
                'licks_right', 'licks_left', 'spike_times', 'cluster_name' ]
column_names2 = ['mouse_name', 'date', 'cluster_name', 'opto_stim_onsets','opto_stim_offsets']
column_names3 = np.concatenate((['mouse_name', 'date', 'cluster_name', 'spikes'],
                                ['waveform_'+str(i) for i in range(128)]))

log_df = pd.DataFrame([], columns = column_names1)
opto_log_df = pd.DataFrame([], columns = column_names2)
opto_spikes_df = pd.DataFrame()
opto_waves_df = pd.DataFrame()

for file_num in tnrange(len(glob.glob("Log_Claustrum*"))):
    mat = sp.io.loadmat(task_data_files[file_num])
    mat2 = sp.io.loadmat(opto_data_files[file_num])
    
    log = mat['fullLog']
    log2 = mat2['fullOptoTable']

    indv_log_df = pd.DataFrame(log, columns = column_names1)

    log_df = pd.concat([log_df,indv_log_df])
    opto_log_df = pd.concat([opto_log_df,pd.DataFrame(log2, columns = column_names2)])

    opto_waves_df = pd.concat([opto_waves_df, pd.read_csv(opto_wave_files[file_num])])
    opto_spikes_df = pd.concat([opto_spikes_df, pd.read_csv(opto_spike_files[file_num])])

opto_log_df = opto_log_df.reset_index(drop = True) 
opto_spikes_df = opto_spikes_df.reset_index(drop = True)

##spikes/waves were listed in opto_waves_df by finding all spikes within 0.025s of laser pulse
##this can cause spikes to be listed more than once when frequency of laser stim was >50hz
##This will remove double listed spikes:
opto_waves_df = opto_waves_df.drop_duplicates().reset_index(drop = True)
                     

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload



### clean up data tables

In [5]:
for col in [0,1,2,3,4,5,6,6,7,7,8,8,8,8,9,9,9,9,10,10,10,10,11,11,12,12,14]:
        log_df.ix[:,col] = log_df.ix[:,col].str[0]
for col in [0,1,2,3,4]:
        opto_log_df.ix[:,col] = opto_log_df.ix[:,col].str[0]
opto_spikes_df['cluster_name'] = opto_spikes_df['cluster_name'].apply(lambda y: 'T'+ y)
opto_waves_df['cluster_name'] = opto_waves_df['cluster_name'].apply(lambda y: 'T'+ y)

In [9]:
log_df.loc[log_df['mouse_name'] == 'Claustrum4', 'date'].unique()

array(['06-05-17'], dtype=object)

In [10]:
log_df.loc[log_df['mouse_name'] == 'Claustrum5', 'date'].unique()

array(['05-31-17', '06-04-17', '06-05-17'], dtype=object)

In [14]:
opto_waves_df

Unnamed: 0,mouse_name,date,cluster_name,spikes,waveform,waveform_1,waveform_2,waveform_3,waveform_4,waveform_5,...,waveform_118,waveform_119,waveform_120,waveform_121,waveform_122,waveform_123,waveform_124,waveform_125,waveform_126,waveform_127
0,Claustrum4,06-05-17,TT1clst1,2522.276,-0.008088,0.004554,0.004394,0.007818,0.011219,0.021525,...,0.001087,0.000291,0.009148,0.005648,0.013315,0.019692,0.021877,0.008777,0.000690,0.011965
1,Claustrum4,06-05-17,TT1clst1,2524.386,0.001352,0.023582,0.013483,-0.001461,-0.009758,-0.004207,...,-0.013994,-0.008501,0.011865,0.015340,-0.000588,-0.005967,0.016675,0.018509,0.013400,0.014330
2,Claustrum4,06-05-17,TT1clst1,2524.624,0.008524,0.001863,-0.010096,-0.000834,0.002640,-0.003382,...,-0.010672,0.003892,0.007704,0.009081,0.010488,0.016148,0.013898,0.023750,0.023985,0.026323
3,Claustrum4,06-05-17,TT1clst1,2524.808,0.007767,0.007993,0.000245,-0.005635,-0.009261,0.000478,...,-0.013415,-0.017957,-0.016589,-0.011438,0.006000,0.008253,0.005871,-0.001513,-0.004423,-0.002836
4,Claustrum4,06-05-17,TT1clst1,2524.822,-0.006845,-0.010927,-0.005218,0.008345,0.017392,0.007991,...,0.005638,-0.000432,-0.004517,-0.010900,-0.013720,0.013941,0.020712,0.009819,0.010520,0.000982
5,Claustrum4,06-05-17,TT1clst1,2529.243,0.025412,0.012938,0.017403,0.022814,0.021461,0.022111,...,-0.004757,-0.010860,-0.011437,0.006016,0.006035,0.003171,0.015583,0.024867,0.026852,0.012781
6,Claustrum4,06-05-17,TT1clst1,2529.572,-0.000950,0.002971,-0.000791,-0.004428,-0.002277,-0.007588,...,-0.001214,-0.007043,-0.008253,-0.029316,-0.027044,-0.015570,-0.014879,-0.013069,-0.012471,-0.004299
7,Claustrum4,06-05-17,TT1clst1,2529.795,0.015461,0.011365,0.007959,0.001821,0.006104,0.015241,...,0.003850,-0.001828,0.007650,0.017107,0.022186,0.027597,0.032781,0.032108,0.028619,0.022526
8,Claustrum4,06-05-17,TT1clst1,2529.893,0.024277,0.007131,0.003902,0.003075,-0.002575,-0.006523,...,-0.009215,-0.003666,-0.008092,-0.007454,0.012605,0.013986,0.011632,0.019604,0.018032,0.012201
9,Claustrum4,06-05-17,TT1clst1,2529.904,0.015700,0.003786,-0.001432,0.017299,0.014247,0.002761,...,-0.021945,-0.020248,-0.012398,0.004645,0.002819,0.002912,0.006990,0.000837,0.010696,0.012797


In [15]:
opto_log_df['waveforms1'] = np.nan
opto_log_df['waveforms2'] = np.nan
opto_log_df['waveforms3'] = np.nan
opto_log_df['waveforms4'] = np.nan

opto_log_df['waveforms1'] = opto_log_df['waveforms'].apply(lambda y: y[:,0,:])
opto_log_df['waveforms2'] = opto_log_df['waveforms'].apply(lambda y: y[:,1,:])
opto_log_df['waveforms3'] = opto_log_df['waveforms'].apply(lambda y: y[:,2,:])
opto_log_df['waveforms4'] = opto_log_df['waveforms'].apply(lambda y: y[:,3,:])
del opto_log_df['waveforms']

TypeError: 'float' object is not subscriptable

### sort optogenetic stimulus pulses into groups based on frequency of stimulation

In [None]:
unique_sessions = opto_log_df[['mouse_name', 'date']].drop_duplicates()

opto_log_df['first_last_opto_pulses'] = np.nan
opto_log_df['grouped_opto_pulses'] = np.nan

for session in range(unique_sessions.shape[0]):
    session_row_ind = (opto_log_df.loc[:,['mouse_name', 'date']] == unique_sessions.iloc[session]).all(axis=1)
    rows = opto_log_df.loc[session_row_ind]
    
    ISIs = np.around(rows.iloc[0,3][1:]-rows.iloc[0,3][0:-1], 4)
    ISIs = np.concatenate(([ISIs[0]], ISIs, [ISIs[-1]]))
    unique_ISIs = np.unique(ISIs, return_counts = True)
    ind = unique_ISIs[1]>50
    unique_ISIs = unique_ISIs[0][ind]

    first_last_opto_pulses = {}
    grouped_opto_pulses = {}
    for isi in unique_ISIs:
        ISI_category = np.where(np.absolute(ISIs-isi) <= 0.001)[0]
        opto_pulse_inds= np.array(range(np.min(ISI_category),np.max(ISI_category)))
        grouped_opto_pulses[1/isi] = rows.iloc[0,3][opto_pulse_inds]

        IBI_ind = ISIs[opto_pulse_inds] != isi
        first_pulse_ind = np.concatenate(([opto_pulse_inds[0]],opto_pulse_inds[IBI_ind]))
        last_pulse_ind = np.concatenate((first_pulse_ind[1:]-1, [opto_pulse_inds[-1]]))
        first_last_opto_pulses[1/isi] = [rows.iloc[0,3][first_pulse_ind],rows.iloc[0,3][last_pulse_ind]]
    opto_log_df.loc[session_row_ind,'first_last_opto_pulses'] = [first_last_opto_pulses]
    opto_log_df.loc[session_row_ind,'grouped_opto_pulses'] = [grouped_opto_pulses]

In [None]:
opto_log_df.head(3)

### plot spike rasters for optogenetic stimulus trials
#### define function that will plot rasters

In [None]:
import matplotlib.patches as patches
from matplotlib import gridspec
font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 18}

mpl.rc('font', **font)
mpl.rc('xtick', labelsize=16) 
mpl.rc('ytick', labelsize=16)
mpl.rc('axes', labelsize=18)

    

def opto_identify(unit_num, log):
    
    unit = opto_log_df.iloc[unit_num]
    mpl.close('all')
    fig = mpl.figure(figsize=(15, 8))
    fig.suptitle(unit['mouse_name'] + ', ' +  unit['date'] + ', ' + unit['cluster_name'])
    gs = gridspec.GridSpec(5, 2, height_ratios=[1, 15, 5, 1, 15]) 
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])
    ax3 = fig.add_subplot(gs[6])
    ax4 = fig.add_subplot(gs[7])

    rs1 = fig.add_subplot(gs[2])
    rs2 = fig.add_subplot(gs[3])
    rs3 = fig.add_subplot(gs[8])
    rs4 = fig.add_subplot(gs[9])

    rasters = [rs1, rs2, rs3, rs4]
    opto_axes = [ax1, ax2, ax3, ax4]
    frequencies = [1,5,10,40]

    for i, frequency in enumerate(frequencies):
        flop = unit['first_last_opto_pulses'][frequency]

        trial_dur = flop[1][0] - flop[0][0]
        trial_total = 0
        for trial in range(len(flop[0])):
            trial_spike_inds = (flop[0][trial]-0.5 < unit['spikes']) & (unit['spikes'] < flop[1][trial]+.5)
            trial_spikes = unit['spikes'][trial_spike_inds] - flop[0][trial]
            rasters[i].vlines(trial_spikes, trial + .5, trial + 1.3, linewidth = 0.5)
            trial_total += 1

        figure_pulse_inds = [(unit['grouped_opto_pulses'][frequency] >= flop[0][1])
                  & (unit['grouped_opto_pulses'][frequency] <= flop[1][1])]
        figure_pulses = unit['grouped_opto_pulses'][frequency][figure_pulse_inds] - flop[0][1]
        example_pulse_ind = np.where(unit['opto_stim_onsets'] == flop[0][1])
        stim_duration = unit['opto_stim_offsets'][example_pulse_ind] - unit['opto_stim_onsets'][example_pulse_ind]

        for p in figure_pulses:
            opto_axes[i].add_patch(patches.Rectangle((p,0), stim_duration*5, 4, color = 'xkcd:sky blue')) 

        rasters[i].autoscale(enable=True, tight=True)
        rasters[i].spines['right'].set_visible(False)
        rasters[i].spines['top'].set_visible(False)
        rasters[i].xaxis.set_ticks_position('bottom')
        rasters[i].yaxis.set_ticks_position('left')
        rasters[i].set_xlim(-trial_dur*0.1, trial_dur + trial_dur*0.1)
        rasters[i].set_xlabel('Time(s)')
        rasters[i].set_ylabel('Trials')

        opto_axes[i].set_ylim(0, 4, )
        opto_axes[i].set_xlim(-trial_dur*0.1, trial_dur + trial_dur*0.1 )
        opto_axes[i].axis('off')
        opto_axes[i].set_title(str(frequency) + ' Hz')

    mpl.subplots_adjust(left=0.1, right=.9, top=0.9, bottom=0.1)

    return fig

#### create widget that will make exploring data easier - drag slider to switch between units

In [None]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.html import widgets
from IPython.display import display
import traitlets

log = opto_log_df
unit_num = 12
num = widgets.IntSlider(min = 1, max = opto_log_df.shape[0], step = 1, width='75%', height='30px', 
                    description= "Unit", continuous_update=False)
s = interactive(opto_identify, unit_num = num, log = fixed(log), continuous_update=False)
display(s)

### plot reliability and latency of optogenetic responses
#### calculate identification metrics for each unit

In [None]:
all_opto_metrics_df = None
for row in range(len(opto_log_df)):
    unit = opto_log_df.iloc[row]

    first_pulse = np.min(np.concatenate(list(unit['grouped_opto_pulses'].values())))
    last_pulse = np.max(np.concatenate(list(unit['grouped_opto_pulses'].values())))
    cont_spike_inds = (unit['spikes'] < first_pulse) | (unit['spikes'] > last_pulse)

    frequencies = [1,5,10,40]

    opto_windows_df = pd.DataFrame(dict([(k,pd.Series(v)) for k,v in unit['grouped_opto_pulses'].items()]))
    del opto_windows_df[2]
    names = [str(freq) + 'end' for freq in frequencies]
    evoked_spike_inds_df = opto_windows_df.applymap(lambda y: np.where((y <= unit['spikes']) & (y+0.01>= unit['spikes']))[0]) 

    opto_metrics_df= pd.DataFrame(frequencies, columns = ['frequency'])
    opto_metrics_df['unit_number'] = row
    opto_metrics_df['mouse_name'] = unit['mouse_name']
    opto_metrics_df['date'] = unit['date']
    opto_metrics_df['cluster_name'] = unit['cluster_name']

    spike_inds = []
    for freq in frequencies:
        spike_inds.append(np.concatenate(evoked_spike_inds_df[freq].as_matrix()))
    opto_metrics_df['spike_inds'] = pd.Series(spike_inds)
    pulses = pd.DataFrame([[unit['grouped_opto_pulses'][freq]] for freq in frequencies],
                          columns = ['pulses'])
    opto_metrics_df['pulses'] = pulses

    def find_latencies(s):
        spikes = unit['spikes']
        idx = spikes.searchsorted(s['pulses'], side = 'right')
        idx2 = idx[idx < len(unit['spikes'])] #otherwise will return an error if no spikes follow the last laser pulse
        raw_latencies = spikes[idx2] - s['pulses'][0:len(idx2)]
        if len(raw_latencies <= 0.01) > 0:
            reliability = sum(raw_latencies <= 0.01)/len(raw_latencies <= 0.01)
        else:
            reliability = 0
        latencies = raw_latencies[raw_latencies <= 0.01]
        return [[latencies], reliability]
    # raw_latencies, latencies
    opto_metrics_df['reliability'] = opto_metrics_df.apply(lambda y: find_latencies(y)[1], axis = 1)
    opto_metrics_df['latencies'] = opto_metrics_df.apply(lambda y: find_latencies(y)[0], axis = 1)
    opto_metrics_df['latencies'] = opto_metrics_df['latencies'].apply(lambda y: y[0])
    opto_metrics_df['mean_latencies'] = opto_metrics_df['latencies'].apply(lambda y: np.mean(y) if len(y)>0 else None)
    
    waveform_corr = []
    for waveform in ['waveforms1','waveforms2','waveforms3','waveforms4']:    
        opto_metrics_df[waveform] = opto_metrics_df['spike_inds'].apply(
            lambda y: np.mean(unit[waveform][y], axis=0))
        opto_metrics_df['cont_'+waveform] = [np.mean(unit[waveform][cont_spike_inds], axis=0)]*4
        waveform_corr.append([sp.stats.pearsonr(opto_metrics_df[waveform][i],
                           opto_metrics_df['cont_'+waveform][i])[0] for i in range(4)])
    opto_metrics_df['waveform_corr'] = np.around(np.mean(waveform_corr),2)

    
    if all_opto_metrics_df is None: 
        all_opto_metrics_df = opto_metrics_df
    else:
        all_opto_metrics_df = all_opto_metrics_df.append(opto_metrics_df)
all_opto_metrics_df = all_opto_metrics_df.reset_index(drop = True)

In [None]:
all_opto_metrics_df.head(3)

#### define function that will plot identification metrics for each unit

In [None]:
import seaborn as sns
sns.set_style("ticks")

def opto_metrics(unit_num, log, opto_metrics_log):

    unit = all_opto_metrics_df[all_opto_metrics_df['unit_number'] == unit_num]
    sub_log_df = opto_log_df.iloc[unit_num]
    frequencies = [1,5,10,40]
    stim_duration = 0.002

    mpl.close('all')
    fig2 = mpl.figure(figsize=(18, 10))
    fig2.suptitle(sub_log_df['mouse_name'] + ', ' +  sub_log_df['date'] + ', ' + sub_log_df['cluster_name'])

    ax1 = mpl.subplot2grid((2,2), (0,0), rowspan=2, colspan=1)
    ax2 = mpl.subplot2grid((2,2), (0,1), rowspan=1, colspan=1)
    ax3 = mpl.subplot2grid((2,2), (1,1), rowspan=1, colspan=1)

    ax4 = fig2.add_axes([0.34, 0.425, 0.1, 0.4])
    sns.despine()

    data={}
    total_stims = 0

    laser_evoked_spike_inds = []
    for freq in frequencies:
        opto_stims = sub_log_df['grouped_opto_pulses'][freq]
        for stim_num in range(90):
            spike_inds = (sub_log_df['spikes'] >= opto_stims[stim_num]-0.02) & (sub_log_df['spikes'] 
                                                                                < opto_stims[stim_num] + .03)
            spikes = (sub_log_df['spikes'][spike_inds] - opto_stims[stim_num])*1000
            ax1.vlines(spikes, total_stims + stim_num + .5, total_stims + stim_num + 1.3, linewidth = 2)
            x = total_stims+stim_num

        ax1.plot([0,stim_duration*1000], [total_stims+stim_num, total_stims+stim_num], color = 'k', linewidth = 3)
        ax1.text(-5, total_stims + stim_num/2, str(freq) + ' Hz', fontsize = 14),
        total_stims = total_stims + stim_num

    frequencies = unit['frequency']
    values = np.concatenate(list(unit['latencies']))
    labels = np.concatenate([[freq] * len(unit['latencies'].iloc[i])
                             for i, freq in enumerate(frequencies)])
    data_df = pd.DataFrame([values*1000, labels], index = ['Latency (ms)', 'Frequency (Hz)']).T
    sns.violinplot(x = 'Frequency (Hz)', y = 'Latency (ms)', data = data_df, ax=ax3, color = 'xkcd:sky blue')

    sns.pointplot(x = 'frequency', y = 'reliability', data = unit, ax=ax2, join = False, color = 'xkcd:sky blue')

    offset = 0
    waveform_corr = []
    for electrode in ['waveforms1','waveforms2','waveforms3','waveforms4']:
        laser_evoked_waveforms_df = pd.DataFrame(np.mean(unit[electrode]))+offset
        cont_waveforms_df = pd.DataFrame(np.mean(unit['cont_'+electrode]))+offset
        ax4.plot(cont_waveforms_df, color = 'k', alpha = 0.8, linewidth = 3)
        ax4.plot(laser_evoked_waveforms_df, color = 'xkcd:sky blue',alpha = 0.8, linewidth = 3)
        offset += 0.2
    ax4.set_title('r = ' + str(unit['waveform_corr'].iloc[0]), fontsize = 14)

    ax1.set_xlabel('Time(ms)')
    ax1.set_ylabel('Stim trial')
    ax1.autoscale(enable=True, tight=True)
    ax1.xaxis.set_ticks_position('bottom')
    ax1.yaxis.set_ticks_position('left')
    ax1.add_patch(patches.Rectangle((0,0), stim_duration*1000, total_stims, color = 'xkcd:sky blue'))
    ax1.set_xlim(-20,30)

    ax3.set_xlabel('Frequency (Hz)')
    ax3.set_ylabel('Mean spike latency (ms)')
    ax3.set_ylim(0,20)
    ax3.yaxis.set_ticks(np.arange(0, 20, 2.5))

    ax2.set_xlabel('Frequency (Hz)')
    ax2.set_ylabel('P(spike)')
    ax2.set_ylim(0,1)

    ax4.patch.set_alpha(0)
    ax4.axis('off')

    ax5 = fig2.add_axes([0.40, 0.40, 0.05, 0.1])
    ax5.spines['left'].set_visible(False)
    ax5.spines['top'].set_visible(False)
    ax5.yaxis.set_ticks_position('right')
    ax5.spines['right'].set_smart_bounds(True)
    ax5.spines['bottom'].set_smart_bounds(True)
    mpl.yticks((0,0.2,0.2), fontsize = 13)
    mpl.xticks((0,0.5,0.5), fontsize = 13)
    ax5.set_ylim(0,0.2)
    ax5.set_xlim(0,0.5)
    ax5.patch.set_alpha(0)
    ax5.set_xlabel('ms',fontsize = 13)
    ax5.set_ylabel('mV',fontsize = 13)
    ax5.yaxis.set_label_position('right')

    return fig2

#### create widgets to make exploring data easier

In [None]:
num = widgets.IntSlider(min = 1, max = opto_log_df.shape[0], step = 1, width='75%', height='30px', 
                    description= "Unit", continuous_update=False)
s = interactive(opto_metrics, unit_num = num, log = fixed(log), opto_metrics_log = fixed(all_opto_metrics_df), continuous_update=False)
display(s)

In [None]:
mpl.savefig('Cl3_03-27-17_TT3c2(highReliability).png',format='png', dpi=900)

In [None]:
import colorlover as cl
import plotly.plotly as py
import plotly.graph_objs as go
py.sign_in('efinkel11', 'yhtY1t0dIUH4hX37eXdO')
traces = []

colors = ['rgb(236,231,242)','rgb(166,189,219)','rgb(43,140,190)', 'rgb(10,50,110)']
frequencies = [1,5,10,40]
for i, freq in enumerate(frequencies):
    rows = all_opto_metrics_df['frequency'] == freq
    x = all_opto_metrics_df['mean_latencies'][rows]
    y = all_opto_metrics_df['waveform_corr'][rows]
    z = all_opto_metrics_df['reliability'][rows]
    n = all_opto_metrics_df.loc[rows, ['mouse_name',
                                'date', 'cluster_name']].as_matrix()
    clust_names = [' '.join(name) for name in n]
    clust_names = ['.'.join((ns[0:2]+ns[9],ns[12],ns[14:16],ns[17:])) for ns in clust_names]

    trace = go.Scatter3d(
        text = clust_names,
        x=x, y=y, z=z,
        mode = 'markers',
        marker=dict(
            size=6,
            opacity = 0,
            color = colors[i]),
        line = dict(
            width = 0,
            color = colors[i]),
        name = str(freq)+'Hz',)
    traces.append(trace)
data = traces

layout = dict(
    width=1000,
    height=800,
    autosize=False,
    legend = dict(orientation= "h"),
    scene=dict(
        xaxis=dict(title = 'mean spike latency'),
        yaxis=dict(title = 'waveform corr'),
        zaxis=dict(
            title = 'p(spike)')))

fig = dict(data=data, layout=layout)

py.iplot(fig, filename='pandas-brownian-motion-3d', height=700, validate=False)

In [None]:
log_df['stim_onset'] = log_df['stim_onset'].fillna(0)
log_df['spike_times'] = log_df['spike_times'] - log_df['stim_onset']
licks = pd.concat([log_df['licks_right'] - log_df['stim_onset'] , log_df['licks_left']-log_df['stim_onset']], axis=1)
licks = licks.applymap(lambda y: y[[0.1<y]] if len(y) > 0 else y)
licks = licks.applymap(lambda y: y[[1>y]] if len(y) > 0 else y)
licks = licks.applymap(lambda y: min(y) if len(y) > 0 else np.nan)
log_df['first_lick'] = licks.min(axis=1)

log_df[['licks_right','licks_left', 'spike_times']] = \
log_df[['licks_right','licks_left', 'spike_times']].applymap(lambda x: np.concatenate(x) if len(list(x)) > 0 else x)

log_df = log_df.sort_values(['mouse_name', 'date', 'cluster_name', 'first_lick'], ascending = [1,1,1,1])
log_df['identified'] = 'unidentified'
log_df

In [None]:
subset_dict = {}
size_dict = {}
subset_dict['None'] = 0
size_dict['None'] = 0

categories = np.concatenate([log_df['mouse_name'].unique(), log_df['identified'].unique()])

for cat in categories:
        subset = log_df[log_df['mouse_name'] == cat]
        if subset.size == 0:
            subset = log_df[log_df['identified'] == cat]
            print(cat)
        subset_dict[cat] = subset
        unique_units = subset[['mouse_name', 'date', 'cluster_name']].drop_duplicates()
        size_dict[cat] = len(unique_units)
        #print(unique_units.size)

In [None]:
%autoreload 2
%timeit
import ind_unit as iu
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.html import widgets
from IPython.display import display
import traitlets


w = widgets.Dropdown(options = list(subset_dict.keys()), value = 'None', description='cellType:')

num = widgets.IntSlider(min = 1, max = 50, step = 1, width='75%', height='30px', 
                    description= "Unit", continuous_update=False)
x1 = widgets.FloatSlider(min = -.9, max = 0, step = 0.05, value = -0.25,width='75%',
                          height='30px', description= "xMin", continuous_update=False)
x2 = widgets.FloatSlider(min = 0, max = 2, step = 0.05, value=0.75,width='75%', height='30px',
                          description= "xMax",continuous_update=False)
s = interactive(iu.plot_unit, df_dict = fixed(subset_dict), s_key = w, n=num, x_min = x1, x_max = x2, continuous_update=False)


def update_max(*args):
    num.value = 1
    num.max = size_dict[w.value]
w.observe(update_max, 'value')

display(s)

In [None]:
mpl.savefig('Cl3_03-15-17_TT6c4_rasters.png',format='png', dpi=900)

In [None]:
ar

In [None]:
len(ar)