In [81]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly as ply
import ipywidgets as iw
from oneibl import one
import brainbox.plot as bbp
from fitplot_funs import err_wght_sync
from export_funs import trialinfo_to_df

%matplotlib inline

# Display name of each kernel family (used as dropdown labels) -> short key used in the code below.
fit_types = {
    'Stimulus on': 'stim',
    'Feedback': 'fdbck',
    'Prior estimate gain': 'prior',
}

# Iterate through ./fits/ and find which mice/sessions we have fits for.
# Every sub-directory counts as a mouse; inside it, every entry whose last dot-separated
# component is 'p' counts as a fit and is stored with its final two characters removed.
mice = list(filter(lambda entry: os.path.isdir('./fits/' + entry), os.listdir('./fits/')))
fitsess = {
    mouse: [fname[:-2] for fname in os.listdir(f'./fits/{mouse}/') if fname.rsplit('.', 1)[-1] == 'p']
    for mouse in mice
}

# Selection widgets; only the mouse dropdown has options at creation time.
mousewidget = iw.Dropdown(options=mice)
sesswidget, cellwidget, fitwidget = iw.Dropdown(), iw.Dropdown(), iw.Dropdown()

def updatesess(*args):
    """Refresh the session dropdown so it lists the fits of the selected mouse.

    Parameters
    ----------
    *args
        Ignored. Accepts the change notification passed by ``observe`` so the
        function can also be called directly with no arguments.
    """
    # .get() keeps the callback from raising when no mouse is selected
    # (e.g. no sub-directories were found, so the mouse dropdown is empty).
    sesswidget.options = fitsess.get(mousewidget.value, [])


# The session list depends on the *mouse* selection, so the callback must watch the mouse
# dropdown (the original registered it on the session dropdown, which it itself modifies).
# names='value' restricts notifications to actual selection changes.
mousewidget.observe(updatesess, names='value')
# Populate the session list for the initially selected mouse.
updatesess()

# Mouse and fit file analysed in this notebook (hard-coded, independent of the dropdowns).
mouse, sess = 'ZM_2240', '2020-01-23_session_2020-04-11_probe0_fit.p'
# Percentile cutoff: only cells whose |prior / varprior| exceeds this percentile are offered
# for plotting. NOTE(review): whether varprior is a variance or a std dev is not visible here.
percentile = 95.

# NOTE(review): allow_pickle=True unpickles the file, which can execute arbitrary code;
# only load fit files from a trusted source.
fits = np.load(f'./fits/{mouse}/{sess}', allow_pickle=True)
wts_per_kern, kern_length = fits['wts_per_kern'], fits['kern_length']
binw, uuid = fits['glm_binsize'], fits['session_uuid']
trdf = trialinfo_to_df(uuid, maxlen=2.)
# NOTE(review): this rebinds `one` from the imported module to the client instance.
one = one.ONE()
spikes, clu = one.load(uuid, ['spikes.times', 'spikes.clusters'])
fullfits = fits['fits']
# Keep only the rows with a finite prior weight.
fitdf = fullfits.loc[np.isfinite(fullfits['prior'])]
# Threshold on the absolute prior / varprior ratio, then keep the rows strictly above it.
prior_threshold = np.percentile((fitdf['prior'] / fitdf['varprior']).abs(), percentile)
subsetdf = fitdf.loc[(fitdf['prior'] / fitdf['varprior']).abs() > prior_threshold]

In [94]:
@iw.interact
def plot_kern(cell=subsetdf.index.to_list(), kern=fit_types.keys()):
    """Show the PSTHs and fitted kernel weights of one cell for one kernel family.

    Wrapped by ``iw.interact``, so both arguments are rendered as dropdowns.

    Parameters
    ----------
    cell : str
        Index label of the cell in ``fitdf``. Everything after the first four
        characters is parsed as the integer cluster id handed to the PSTH plot.
    kern : str
        A key of ``fit_types`` selecting which kernel family to display.

    Returns
    -------
    None
        For 'stim' and 'fdbck' a two-panel figure is drawn: both PSTHs on the top
        axis, both kernels with error bars on the bottom axis. For 'prior' the
        fitted value and its ``varprior`` entry are printed and no figure is made.

    Raises
    ------
    ValueError
        If ``fit_types[kern]`` is not one of 'stim', 'fdbck' or 'prior'.
    """
    currfit = fitdf.loc[cell]
    kern_type = fit_types[kern]

    # The prior gain is a single number: report it before creating any figure.
    if kern_type == 'prior':
        print('Gain modulation of prior estimate:', currfit['prior'], 'Std dev of fit:', currfit['varprior'])
        return

    if kern_type == 'stim':
        weight1 = currfit['stim_L']
        weight2 = currfit['stim_R']
        err1 = err_wght_sync(currfit['varstim_L'], weight1)
        err2 = err_wght_sync(currfit['varstim_R'], weight2)
        labels = ('Left stimulus onset', 'Right stimulus onset')
        title = 'Kernels fit to right and left stimulus onset'
        event_times = (trdf[np.isfinite(trdf.contrastLeft)].stimOn_times,
                       trdf[np.isfinite(trdf.contrastRight)].stimOn_times)
    elif kern_type == 'fdbck':
        weight1 = currfit['fdbck_corr']
        weight2 = currfit['fdbck_incorr']
        err1 = err_wght_sync(currfit['varfdbck_corr'], weight1)
        err2 = err_wght_sync(currfit['varfdbck_incorr'], weight2)
        labels = ('Correct feedback', 'Incorrect feedback')
        title = 'Kernels fit to correct and incorrect feedback'
        # Feedback PSTHs should be aligned to the feedback event, not to stimulus onset.
        # NOTE(review): assumes the trial table exposes a 'feedback_times' column; when it
        # does not, fall back to the stimulus-onset alignment used previously.
        align_col = 'feedback_times' if 'feedback_times' in trdf.columns else 'stimOn_times'
        event_times = (trdf.loc[trdf.feedbackType == 1, align_col],
                       trdf.loc[trdf.feedbackType == -1, align_col])
    else:
        raise ValueError(f'Unknown kernel type: {kern_type!r}')

    fig, axes = plt.subplots(2, 1, figsize=(15, 9))
    cluster_id = int(cell[4:])

    # Draw both PSTHs on the top axis. The y-limits present after each call are recorded so
    # their union can be applied afterwards (presumably each call rescales the axis to its own
    # trace, which is what the original compensation code suggests).
    ylims = []
    for events, label, color in zip(event_times, labels, ('navy', 'orange')):
        events = np.asarray(events, dtype=float)
        events = events[np.isfinite(events)]  # trials without a recorded event time cannot be aligned
        bbp.peri_event_time_histogram(spikes, clu, events, cluster_id, t_before=0., t_after=0.6, ax=axes[0],
                                      error_bars='sem',
                                      pethline_kwargs={'color': color, 'lw': 2, 'label': label + ' PSTH'},
                                      errbar_kwargs={'color': color, 'alpha': 0.2})
        ylims.append(axes[0].get_ylim())
    lows, highs = zip(*ylims)
    axes[0].set_ylim(min(lows), max(highs))
    axes[0].legend()
    axes[0].set_title('PSTH about event')

    # Build the time axis from the number of weights rather than np.arange(0, kern_length, binw):
    # with a float step arange can return one element more or less than expected, which would
    # make errorbar fail on mismatched lengths. Values are identical whenever the lengths agree.
    timescale = np.arange(len(weight1)) * binw
    axes[1].errorbar(timescale, weight1, yerr=err1, label=labels[0])
    axes[1].errorbar(timescale, weight2, yerr=err2, label=labels[1])
    axes[1].legend()
    axes[1].set_xlabel('Time (s)')
    axes[1].set_ylabel('Weight value')
    axes[1].set_title(title)
    
    

interactive(children=(Dropdown(description='cell', options=('cell8', 'cell10', 'cell20', 'cell27', 'cell35', '…

1         27.753291
2         31.336966
5         44.437422
6         48.420591
8         56.618471
           ...     
1069    4800.805471
1071    4823.438375
1072    4827.736678
1073    4832.136605
1074    4837.037581
Name: stimOn_times, Length: 579, dtype: float64