In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

### Download IBL data from the Flatiron server

In [None]:
# Must be run inside the ibllibenv environment to access ONE
from oneibl.one import ONE
one = ONE()

# Every session of each subject falling inside this date window is fetched
date_range = ['2018-01-01', '2019-05-30']
subjects = [
    'CSHL_003', 'CSHL_005', 'CSHL_007',
    'IBL-T1', 'IBL-T4',
    'ibl_witten_04', 'ibl_witten_05',
    'IBL_10', 'IBL_13',
]

# Trial-level dataset types requested for each session
var_extract = [
    '_ibl_trials.feedbackType',
    '_ibl_trials.choice',
    '_ibl_trials.contrastLeft',
    '_ibl_trials.contrastRight',
    '_ibl_trials.included',
    '_ibl_trials.probabilityLeft',
]

# one.load's return value is discarded: the loop runs only for its side effect.
# NOTE(review): presumably this fills the local cache that the processing cell
# below reads from disk — confirm.
for s in subjects:
    eids = one.search(subjects=s, date_range=date_range)
    for eid in eids:
        one.load(eid, dataset_types=var_extract)

### Process Flatiron data

In [None]:
from PBups.IBL_Data_Processing import getAllData_IBL

# Directory handed to getAllData_IBL as its data source.
# NOTE(review): absolute, user-specific path — adjust for your machine.
sourceDir = '/Users/nicholasroy/FlatIron/'
mouse, lab = 'CSHL_003', "churchlandlab"

# Build the processed dataset for a single subject/lab.
# NOTE(review): forceNew=True presumably skips any previously processed copy —
# confirm in PBups.IBL_Data_Processing.
outData = getAllData_IBL(sourceDir, mouse, labs=lab, forceNew=True)

### Example Figure

In [None]:
### Define a custom function for plotting weights for the manuscript
### A generally more flexible version of plotting is available in PsyTrack as:
###    from psytrack.plot.analysisFunctions import makeWeightPlot

def weightPlot_IBL(wMode, outData, weights, colors, zorder,
                    START=0, END=0, errorbar=None, stimbias=None,
                    maxval=7.9, figsize=(7.5, 3)):
    """Plot weight trajectories with session boundaries and optional block shading.

    Draws into a new figure, which becomes the current pyplot figure. Returns None.

    Parameters
    ----------
    wMode : np.ndarray
        2-D array of weights, unpacked as (K, N): one row per weight, one column per trial.
    outData : dict
        Dataset dict. ``outData['dayLength']`` (trials per session) is used for the
        vertical session lines. ``outData['probL']`` is read only when ``stimbias``
        is truthy.
    weights : dict
        Maps a weight label to the number of rows it occupies. Rows of ``wMode``
        are matched to labels in sorted-key order.
    colors, zorder : dict
        Per-label line color and drawing order. When ``stimbias`` is truthy,
        ``colors`` must also contain 'sR' and 'sL'.
    START, END : int
        Trial range shown on the x-axis. A negative START counts back from N.
        END <= 0 is relative to N (END=0 means N). END > N is clipped to N.
    errorbar : np.ndarray, optional
        Same shape as ``wMode``. When given, a +/- 2 * errorbar band is drawn.
    stimbias : bool, optional
        If truthy, shade runs of constant ``probL``: 0.2 uses colors['sR'] and
        0.8 uses colors['sL']. 0.5, or any other value, is left unshaded.
    maxval : float
        Symmetric y-limit. The default 7.9 is the largest magnitude of any weight
        across all 3 training periods.
    figsize : tuple
        Figure size in inches.

    Raises
    ------
    Exception
        If START > N, or START >= END after the range is normalized.
    ValueError
        If the number of labels implied by ``weights`` differs from ``wMode.shape[0]``.
    """
    ### Initialization
    K, N = wMode.shape

    if START < 0: START = N + START
    if START > N: raise Exception("START > N : " + str(START) + ", " + str(N))
    if END <= 0: END = N + END
    if END > N: END = N
    if START >= END: raise Exception("START >= END : " + str(START) + ", " + str(END))

    # One label per row of wMode, in sorted-key order
    labels = [label for label in sorted(weights) for _ in range(weights[label])]
    if len(labels) != K:
        # Otherwise rows would be silently dropped or wMode indexed out of range
        raise ValueError("weights describe " + str(len(labels)) +
                         " rows but wMode has " + str(K))

    cumdays = np.cumsum(outData['dayLength'])
    trials = np.arange(N)

    ##### Plotting |
    #####----------+
    fig, ax = plt.subplots(figsize=figsize)

    for i, w in enumerate(labels):
        ax.plot(trials, wMode[i], lw=1.5, alpha=0.8, linestyle='-',
                c=colors[w], zorder=zorder[w])

        # Optional +/- 2 errorbar band around each weight
        if errorbar is not None:
            ax.fill_between(trials, wMode[i] - 2*errorbar[i], wMode[i] + 2*errorbar[i],
                            facecolor=colors[w], zorder=zorder[w], alpha=0.2)

    # Vertical line at the first trial of every session (0, then each cumulative boundary)
    for i in range(len(cumdays)):
        session_start = cumdays[i-1] if i > 0 else 0
        ax.axvline(session_start, color='black', linestyle='-', lw=0.5, alpha=0.5, zorder=0)

    if stimbias:
        # probL value -> span face color; values not listed here are left unshaded
        fc = {0.5 : 'None', 0.2 : colors['sR'], 0.8 : colors['sL']}
        probL = outData['probL']
        i = START
        while i < END:
            run_start = i
            # Extend the run while consecutive probL values are (numerically) equal
            while i + 1 < END and abs(probL[i] - probL[i+1]) < 0.0001:
                i += 1
            ax.axvspan(run_start, i + 1, facecolor=fc.get(probL[run_start], 'None'),
                       alpha=0.2, edgecolor=None)
            i += 1

    ax.axhline(0, color="black", linestyle="--", alpha=0.5, zorder=0)
    ax.set_ylim(-maxval, maxval)
    ax.set_xlim(START, END)
    ax.set_yticks(np.arange(-int(maxval)+1, int(maxval)+1, 2))


### Early Blocks Plot

Subject CSHL_003, sessions from 2019-03-21 to 2019-03-23 (p = 3.5).

In [None]:
from psytrack.hyperOpt import hyperOpt
from psytrack.helper.helperFunctions import trim
from psytrack.helper.invBlkTriDiag import invDiagHess

### Collect data from manually determined training period
def _first_trial_on_or_after(dates, day):
    """Return the index of the first trial whose date is >= ``day``.

    Raises ValueError with a readable message instead of a bare IndexError
    when no trial qualifies.
    """
    idx = np.where(dates >= day)[0]
    if len(idx) == 0:
        raise ValueError("no trials dated on or after " + str(day))
    return idx[0]

_start  = _first_trial_on_or_after(outData['date'], '2019-03-21')
_end    = _first_trial_on_or_after(outData['date'], '2019-03-23')
new_dat = trim(outData, START=_start, END=_end)

# Before the biased blocks begin (2019-03-22), some trials have probL != 0.5.
# Hard-code those to 0.5.
# Copy first: trim() may return views into outData, and writing through a view
# would silently alter outData['probL'] for later cells and on re-runs.
new_dat['probL'] = np.array(new_dat['probL'])
new_dat['probL'][:_first_trial_on_or_after(new_dat['date'], '2019-03-22')] = 0.5

### Compute
weights = {'bias' : 1, 'sL' : 1, 'sR' : 1}
K = int(sum(weights.values()))  # total number of weight rows
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-5]*K
  }
optList = ['sigma','sigDay']

hyp, evd, wMode, hess = hyperOpt(new_dat, hyper_guess, weights, optList)

# sqrt of invDiagHess(hess), reshaped to K rows to line up with wMode.
# Used as the +/-2 band (errorbar) in the plot below.
W_std = np.sqrt(invDiagHess(hess)).reshape(K, -1)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : W_std,
       'weights' : weights, 'new_dat' : new_dat}

In [None]:
# Set default colormap (used by IBL).
# 'vlag' is a seaborn colormap; it is available because seaborn is imported above.
cmap = plt.get_cmap('vlag') #vlag0 = #2369bd; vlag1 = #a9373b
colors = {'bias' : '#FAA61A',
          's1' : cmap(1.0), 's2' : cmap(0.0),
          'sR' : cmap(1.0), 'sL' : cmap(0.0),
          'c' : '#59C3C3', 'h' : '#9593D9', 's_avg' : '#99CC66'}
# Every label in `colors` also needs a zorder. weightPlot_IBL looks up both
# dicts per weight, so a missing key raises KeyError. 's1'/'s2' share the
# stimulus colors, so they also share the stimulus zorder.
zorder = {'bias' : 2,
          's1' : 3, 's2' : 3,
          'sR' : 3, 'sL' : 3,
          'c' : 1, 'h' : 1, 's_avg' : 1}

weightPlot_IBL(dat['wMode'], dat['new_dat'], dat['weights'], colors, zorder,
               START=0, END=0, errorbar=dat['W_std'], stimbias=True)