In [97]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly as ply
import ipywidgets as iw
from oneibl import one
import brainbox.plot as bbp
from fitplot_funs import err_wght_sync
from export_funs import trialinfo_to_df

%matplotlib inline

# Human-readable kernel names (offered in the UI) mapped to the short keys
# that the plotting code branches on.
fit_types = {
    'Stimulus on': 'stim',
    'Feedback': 'fdbck',
    'Prior estimate gain': 'prior',
}

# Iterate through ./fits/ to find which mice (sub-directories) and sessions
# (files whose extension is `p`, stored here with the trailing ".p" removed) have fits.
mice = [entry for entry in os.listdir('./fits/') if os.path.isdir('./fits/' + entry)]
fitsess = {
    mouse: [fname[:-2] for fname in os.listdir(f'./fits/{mouse}/') if fname.rsplit('.', 1)[-1] == 'p']
    for mouse in mice
}

# One dropdown per selection level; only the list of mice is known up front.
mousewidget = iw.Dropdown(options=mice)
sesswidget, cellwidget, fitwidget = iw.Dropdown(), iw.Dropdown(), iw.Dropdown()

def updatesess(*args):
    """Refresh the session dropdown so it lists the sessions of the selected mouse.

    Any positional arguments (such as the change notification passed by the
    widget observer machinery) are accepted and ignored, so the function can be
    used as an observer callback or called directly.
    """
    # .get() keeps this safe when no mouse is selected (e.g. ./fits/ has no mouse folders).
    sesswidget.options = fitsess.get(mousewidget.value, [])


# The session list depends on the *mouse* selection, so the observer belongs on
# mousewidget (it was previously attached to sesswidget itself and therefore never
# reacted to a change of mouse). names='value' limits the callback to selection changes.
mousewidget.observe(updatesess, names='value')
updatesess()  # populate the session dropdown for the initially selected mouse

# Session to inspect. Edit these three values to look at a different fit.
mouse = 'ZM_2240'
sess = '2020-01-22_session_2020-04-12_probe0_fit.p'
probe_idx = 0

# The fit file is loaded with allow_pickle=True, i.e. it is unpickled.
# NOTE(security): unpickling can execute arbitrary code; only load fit files you trust.
fits = np.load(f'./fits/{mouse}/{sess}', allow_pickle=True)
wts_per_kern = fits['wts_per_kern']
kern_length = fits['kern_length']
binw = fits['glm_binsize']
uuid = fits['session_uuid']
p_est = fits['prior_est']
trdf = trialinfo_to_df(uuid, maxlen=2.)

# `one` is rebound here from the imported oneibl.one module to a client instance
# (the name is kept because other cells may rely on it). Fetching the module from
# sys.modules instead of through the name `one` keeps this cell re-runnable:
# `one.ONE()` would fail on a second run, when `one` is already the client.
one = sys.modules['oneibl.one'].ONE()
spikes = one.load(uuid, ['spikes.times'])[probe_idx]
clu = one.load(uuid, ['spikes.clusters'])[probe_idx]

Connected to https://alyx.internationalbrainlab.org as berk.gercek


In [141]:
# Percentile cutoff: only cells whose |prior / varprior| lies above this
# percentile (taken over cells with a finite prior gain) are offered for plotting.
percentile = 50.
fullfits = fits['fits']
# Keep only the cells whose prior gain is finite.
fitdf = fullfits.loc[np.isfinite(fullfits.prior)]
prior_threshold = np.percentile((fitdf.prior / fitdf.varprior).abs(), percentile)
subsetdf = fitdf.loc[(fitdf.prior / fitdf.varprior).abs() > prior_threshold]
# Join p_est onto the trial table unless a 'bias' column is already present
# (the guard makes re-running this cell safe), then keep trials with a finite bias.
if 'bias' not in trdf:
    trdf = trdf.join(p_est)
trdf = trdf.loc[np.isfinite(trdf.bias)]

In [142]:
@iw.interact
def plot_kern(cell=subsetdf.index.to_list(), kern=fit_types.keys()):
    """Plot PSTHs (and, where available, fitted kernels) for one cell and kernel type.

    Reads the notebook-level globals ``fitdf``, ``trdf``, ``spikes``, ``clu``,
    ``kern_length``, ``binw`` and ``fit_types``.

    Parameters
    ----------
    cell : str
        Index label of the cell in ``fitdf``. Everything after the first four
        characters is parsed as an integer and passed to the PSTH function as
        the cluster ID (e.g. ``'cell12'`` -> 12).
    kern : str
        One of the keys of ``fit_types``.

    Behaviour
    ---------
    * 'Stimulus on' / 'Feedback': two panels. The top panel overlays two PSTHs
      (left vs. right stimulus onset, or correct vs. incorrect feedback), each
      aligned to its own event time; the bottom panel shows the two fitted
      kernels with error bars from ``err_wght_sync``.
    * 'Prior estimate gain': a single panel with PSTHs around stimulus onset for
      trials split into low / middle / high ``bias``; the fitted gain is printed.

    Raises
    ------
    ValueError
        If ``fit_types[kern]`` is not one of 'stim', 'fdbck' or 'prior'.
    """
    currfit = fitdf.loc[cell]
    cluster_id = int(cell[4:])
    kern_key = fit_types[kern]

    def draw_psth(ax, events, color, label, t_before):
        """Overlay one PSTH on `ax` and return the axis y-limits afterwards."""
        bbp.peri_event_time_histogram(spikes, clu, events, cluster_id, t_before=t_before, t_after=0.6, ax=ax,
                                      error_bars='sem',
                                      pethline_kwargs={'color': color, 'lw': 2, 'label': label},
                                      errbar_kwargs={'color': color, 'alpha': 0.2})
        return ax.get_ylim()

    def enclosing_ylim(ylims):
        """Smallest [low, high] interval containing every (low, high) pair in `ylims`."""
        lows, highs = zip(*ylims)
        return [min(lows), max(highs)]

    if kern_key == 'prior':
        # Single-panel figure: there is no time-varying kernel to draw for the prior gain.
        fig, ax = plt.subplots(1, 1, figsize=(15, 9))
        low, mid, high = np.percentile(trdf.bias, [33, 66, 99.99])
        bias_groups = [
            (trdf[trdf.bias < low].stimOn_times, 'navy', 'Low bias (toward left)'),
            (trdf[(low < trdf.bias) & (trdf.bias < mid)].stimOn_times, 'black', 'Neutral bias'),
            (trdf[(mid < trdf.bias) & (trdf.bias < high)].stimOn_times, 'maroon', 'High bias (toward right)'),
        ]
        print('Gain modulation of prior estimate:', currfit['prior'], 'Std dev of fit:', currfit['varprior'])
        # y-limits are recorded after every call so the final range covers all
        # three traces even if a later call changes the limits.
        ylims = [draw_psth(ax, events, color, label, 0.4) for events, color, label in bias_groups]
        ax.set_title('PSTH about stimulus on at low, middle, and high bias estimate')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Firing rate')
        ax.set_ylim(enclosing_ylim(ylims))
        ax.legend()
        return

    if kern_key == 'stim':
        weight1 = currfit['stim_L']
        weight2 = currfit['stim_R']
        err1 = err_wght_sync(currfit['varstim_L'], weight1)
        err2 = err_wght_sync(currfit['varstim_R'], weight2)
        label1 = 'Left stimulus onset'
        label2 = 'Right stimulus onset'
        title = 'Kernels fit to right and left stimulus onset'
        event_t1 = trdf[np.isfinite(trdf.contrastLeft)].stimOn_times
        event_t2 = trdf[np.isfinite(trdf.contrastRight)].stimOn_times
    elif kern_key == 'fdbck':
        weight1 = currfit['fdbck_corr']
        weight2 = currfit['fdbck_incorr']
        err1 = err_wght_sync(currfit['varfdbck_corr'], weight1)
        err2 = err_wght_sync(currfit['varfdbck_incorr'], weight2)
        label1 = 'Correct feedback'
        label2 = 'Incorrect feedback'
        title = 'Kernels fit to correct and incorrect feedback'
        # Align to the feedback event itself (previously stimOn_times was used,
        # which showed a stimulus-locked PSTH under a feedback label).
        event_t1 = trdf[trdf.feedbackType == 1].feedback_times
        event_t2 = trdf[trdf.feedbackType == -1].feedback_times
    else:
        raise ValueError(f'Unknown kernel type: {kern_key!r}')

    fig, axes = plt.subplots(2, 1, figsize=(15, 9))

    # Top panel: both PSTHs on shared axes. get_ylim() always yields a (low, high)
    # pair, unlike unpacking a slice of the tick array, whose length is not fixed.
    ylims = [
        draw_psth(axes[0], event_t1, 'navy', label1 + ' PSTH', 0.),
        draw_psth(axes[0], event_t2, 'orange', label2 + ' PSTH', 0.),
    ]
    axes[0].set_ylim(enclosing_ylim(ylims))
    axes[0].legend()
    axes[0].set_title('PSTH about event')

    # Bottom panel: the two fitted kernels with their error bars.
    timescale = np.arange(0, kern_length, binw)
    axes[1].errorbar(timescale, weight1, yerr=err1, label=label1)
    axes[1].errorbar(timescale, weight2, yerr=err2, label=label2)
    axes[1].legend()
    axes[1].set_xlabel('Time (s)')
    axes[1].set_ylabel('Weight value')
    axes[1].set_title(title)
    
    

interactive(children=(Dropdown(description='cell', options=('cell1', 'cell2', 'cell3', 'cell4', 'cell5', 'cell…

In [130]:
# Display the trial table (pandas truncates long frames in the rich display).
# NOTE(review): this cell's execution count (130) is lower than that of the cells
# above it (141, 142), so its saved output reflects an earlier kernel state -- it
# still contains a row with a missing `bias` value. Re-run after the filtering cell.
trdf

Unnamed: 0,choice,response_times,probabilityLeft,feedbackType,feedback_times,contrastLeft,contrastRight,goCue_times,stimOn_times,bias
0,-1,17.879011,0.5,1,17.879212,,1.000,17.330767,17.437556,
1,1,21.219040,0.5,1,21.219145,0.2500,,20.989734,21.019431,-7.339601
2,-1,24.929371,0.5,1,24.929473,,0.250,24.412359,24.437156,-7.340415
3,-1,29.252499,0.5,1,29.252608,,0.125,27.680332,27.701763,-7.340949
4,-1,32.474228,0.5,1,32.474352,,0.250,31.983968,32.003166,-7.339941
...,...,...,...,...,...,...,...,...,...,...
639,-1,3568.994364,0.2,1,3568.994498,,0.125,3568.136317,3568.143117,-4.493479
640,-1,3572.878115,0.2,1,3572.878243,,0.125,3572.382293,3572.394658,-4.499667
642,-1,3593.259743,0.2,-1,3593.260771,0.0625,,3592.800117,3592.809983,-4.500962
644,1,3608.220022,0.2,-1,3608.220975,,0.000,3607.904307,3607.924405,-4.502257
