# _Extracting the Dynamics of Behavior in Decision-Making Experiments_
## Figure Generator

by Nicholas A. Roy $\quad$  _(v1.1, last updated November 23, 2020)_

---

This notebook precisely recreates all figures (Figures 1-8 and Supplementary Figures S1-S8) from our manuscript _Extracting the Dynamics of Behavior in Decision-Making Experiments_. All figures require the `PsyTrack` Python package, as well as several other standard Python libraries. Figures that use experimental data also require the corresponding dataset to be downloaded and pre-processed. The requirements for each figure are listed below, followed by instructions for downloading and preparing each of the three datasets:
 
 - Only the `PsyTrack` package is needed to produce the simulated data required for Figures 1, 2, S1, and S2
 
 - The IBL mouse dataset is required (as well as the `ONE Light` Python library) for Figures 3, 4, and S3-6
 
 - The Akrami rat dataset is required for Figures 5, 6, 8, and S7
 
 - The Akrami human subject dataset is required for Figures 7 and S8

A section with preliminary setup code is below, followed by code and instructions to load each dataset. There is then a section for each figure, with subsections for each subfigure. A few things to note:

 - **ALTERATIONS** | Many subfigures in the paper include some superficial additions made in Adobe Illustrator. Subfigures created purely in Adobe Illustrator (e.g. schematic figures) are noted.
 - **COMPUTE TIME** | While most individual `PsyTrack` models can be fit quickly, some figures require fitting dozens of models and so can take much longer to compute. Subfigures that take longer than 90 seconds to produce are marked with their approximate run time.
 - **LOCAL STORAGE** | Many figures save the results of model fits to local storage, so figures can be retrieved and modified without refitting the model each time. All temporary files produced by the notebook are saved to the directory specified by the `SPATH` variable in the Preliminary setup section below. Together, the temporary files and saved subfigures should use less than 500 MB. Note that if you are using a Colab hosted runtime, anything saved to Colab local storage will disappear once the runtime expires (Colab runtimes last at most 12 hours). There is code to download all figures from Colab at the end of the notebook.
 - **SUBFIGURE DEPENDENCIES** | Occasionally, subfigures will depend upon the results of an earlier subfigure (usually part of the same figure) — a cell which fails to run may simply need an earlier cell to be run first (these instances should be clearly marked).
 - **SUBJECT-SPECIFIC DETAILS** | Many analyses run on an example subject should allow other subjects to be swapped in easily, but some analyses contain subject-specific code that may prevent this (e.g. hardcoded dates used to extract certain sessions for analysis).
 - **VERSIONING** | Any additions, fixes, or changes made to this notebook will be noted in the versioning section at the very end of the notebook.
 
---

# Preliminary setup and data retrieval

Users will need to install the `PsyTrack` package (version 2.0) by running the cell below. We also define a variable `SPATH`, the directory where all data files and figures produced by the notebook will be saved.

Several standard Python packages are used: `numpy`, `scipy`, `matplotlib`, and `pandas`. We import all these libraries before proceeding, as well as setting several parameters in `matplotlib` to standardize the figures produced.

In [1]:
import os
import re
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# Install then import PsyTrack
!pip install psytrack==2.0
import psytrack as psy

# Set the save path for all data files and figures
SPATH = "Figures/"
!mkdir -p "{SPATH}"

# Set matplotlib defaults for making files consistent in Illustrator
colors = psy.COLORS
zorder = psy.ZORDER
plt.rcParams['figure.dpi'] = 140
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['savefig.facecolor'] = (1,1,1,0)
plt.rcParams['savefig.bbox'] = "tight"
plt.rcParams['font.size'] = 10
# plt.rcParams['font.family'] = 'sans-serif'     # not available in Colab
# plt.rcParams['font.sans-serif'] = 'Helvetica'  # not available in Colab
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['axes.labelsize'] = 12

clear_output()

---

## Download and pre-process IBL mouse data

1) Use the commands below to install the IBL's [ONE Light](https://github.com/int-brain-lab/ibllib/tree/master/oneibl) Python library, download the [IBL mouse behavior dataset](https://doi.org/10.6084/m9.figshare.11636748.v7) _(version 7, uploaded February 7, 2020)_ to our `SPATH` directory as `ibl-behavior-data-Dec2019.zip`, and unzip the file.

In [7]:
!pip install ibllib
!wget -nc -O "{SPATH}ibl-behavior-data-Dec2019.zip" "https://ndownloader.figshare.com/files/21623715"
!unzip -d "{SPATH}" -n "{SPATH}ibl-behavior-data-Dec2019.zip"
clear_output()

2) Use the [ONE Light](https://github.com/int-brain-lab/ibllib/tree/master/oneibl) library to build a table of all the subject and session data contained within the dataset.

In [3]:
from oneibl.onelight import ONE

ibl_data_path = SPATH + 'ibl-behavioral-data-Dec2019'
current_cwd = os.getcwd()
os.chdir(ibl_data_path)

try:
    # Search all sessions that have these dataset types.
    required_vars = ['_ibl_trials.choice', '_ibl_trials.contrastLeft',
                     '_ibl_trials.contrastRight', '_ibl_trials.feedbackType']
    one = ONE()
    eids = one.search(required_vars)

    # Each eid is a "/"-separated path; its second component is not needed here
    sessions = []
    for eid in eids:
        lab, _, subject, date, session = eid.split("/")
        sessions.append({"eid": eid, "lab": lab, "subject": subject,
                         "date": date, "session": session})

    # DataFrame.append was removed in pandas 2.0, so build the table in one step
    mouseData = pd.DataFrame(sessions, columns=["date", "eid", "lab", "session", "subject"])
finally:
    os.chdir(current_cwd)  # always restore the working directory, even on error

In [None]:
mouseData

3) Next, we use the table of session data to process the raw trial data into a single CSV file, `ibl_processed.csv`, saved to our `SPATH` directory.

There are several known anomalies in the raw data:
 - CSHL_002 codes left contrasts as negative right contrasts on 81 trials (these trials are corrected)
 - ZM_1084 has `feedbackType` of 0 for 3 trials (these trials are omitted)
 - DY_009, DY_010, DY_011 each have <5000 trials total (no adjustment)
 - ZM_1367, ZM_1369, ZM_1371, ZM_1372, and ZM_1743 are shown non-standard contrast values of 0.04 and 0.08 (no adjustment)

In [None]:
all_vars = ["contrastLeft", "contrastRight", "choice", "feedbackType", "probabilityLeft"]
df = pd.DataFrame()

all_mice = []
for j, s in enumerate(mouseData["subject"].unique()):
    print("\rProcessing " + str(j+1) + " of " + str(len(mouseData["subject"].unique())), end="")
    mouse = mouseData[mouseData["subject"]==s].sort_values(['date', 'session']).reset_index()
    for i, row in mouse.iterrows():
        myVars = {}
        for v in all_vars:
            filename = "_ibl_trials." + v + ".npy"
            var_file = os.path.join(ibl_data_path, row.eid, "alf", filename)
            myVars[v] = list(np.load(var_file).flatten())

        num_trials = len(myVars[v])
        myVars['lab'] = [row.lab]*num_trials
        myVars['subject'] = [row.subject]*num_trials
        myVars['date'] = [row.date]*num_trials
        myVars['session'] = [row.session]*num_trials

        all_mice += [pd.DataFrame(myVars, columns=myVars.keys())]
        
df = pd.concat(all_mice, ignore_index=True)

df = df[df['choice'] != 0]        # dump mistrials
df = df[df['feedbackType'] != 0]  # 3 anomalous trials from ZM_1084, omit
df.loc[np.isnan(df['contrastLeft']), "contrastLeft"] = 0
df.loc[np.isnan(df['contrastRight']), "contrastRight"] = 0
df.loc[df["contrastRight"] < 0, "contrastLeft"] = np.abs(df.loc[df["contrastRight"] < 0, "contrastRight"])
df.loc[df["contrastRight"] < 0, "contrastRight"] = 0  # 81 anomalous trials in CSHL_002, correct
df["answer"] = df["feedbackType"] * df["choice"]      # new column to indicate correct answer
df.loc[df["answer"]==1, "answer"] = 0
df.loc[df["answer"]==-1, "answer"] = 1
df.loc[df["feedbackType"]==-1, "feedbackType"] = 0
df.loc[df["choice"]==1, "choice"] = 0
df.loc[df["choice"]==-1, "choice"] = 1
df.to_csv(SPATH+"ibl_processed.csv", index=False)
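
The recoding at the end of the cell above is easiest to follow on a few trials. The sketch below uses three hypothetical trials (not taken from the dataset) in the raw IBL encoding, where `choice` and `feedbackType` are both ±1 once mistrials and zero-feedback trials have been dropped, and applies the same mapping with `Series.map` for brevity instead of `df.loc`. After recoding, a trial is correct exactly when `choice` equals `answer`.

```python
import pandas as pd

# Three hypothetical trials in the raw encoding: choice and feedbackType in {-1, 1}
toy = pd.DataFrame({"choice": [1, -1, 1], "feedbackType": [1, 1, -1]})

# feedbackType * choice gives the rewarded side: the chosen side if correct, else the other side
toy["answer"] = toy["feedbackType"] * toy["choice"]

# Same remapping as the cell above: 1 -> 0 and -1 -> 1 for choice and answer,
# and -1 -> 0 (incorrect) for feedbackType
side = {1: 0, -1: 1}
toy["answer"] = toy["answer"].map(side)
toy["choice"] = toy["choice"].map(side)
toy["feedbackType"] = toy["feedbackType"].map({1: 1, -1: 0})
print(toy)
```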

4) Next we run a few sanity checks on our data, to make sure everything processed correctly.

In [None]:
print("contrastLeft: ", np.unique(df['contrastLeft']))   # [0, 0.0625, 0.125, 0.25, 0.5, 1.0] and [0.04, 0.08]
print("contrastRight: ", np.unique(df['contrastRight'])) # [0, 0.0625, 0.125, 0.25, 0.5, 1.0] and [0.04, 0.08]
print("choice: ", np.unique(df['choice']))               # [0, 1]
print("feedbackType: ", np.unique(df['feedbackType']))   # [0, 1]
print("answer: ", np.unique(df['answer']))               # [0, 1]

5) Finally, we define a function `getMouse` that extracts the data for a single mouse from our CSV file and returns it as a PsyTrack-compatible `dict`. We will use this function to access IBL mouse data in the figures below. Note the keyword argument `p` (default $p=5$), which controls the strength of the $\tanh$ transformation applied to the contrast values. See Figure S4 and the STAR Methods of the accompanying paper for more details.

**Note:** After steps 1-5 have been run once, only step 5 needs to be rerun in subsequent sessions.

In [None]:
ibl_mouse_data_path = SPATH + "ibl_processed.csv"

MOUSE_DF = pd.read_csv(ibl_mouse_data_path)
def getMouse(subject, p=5):
    df = MOUSE_DF[MOUSE_DF['subject']==subject]   # Restrict data to the subject specified
    
    cL = np.tanh(p*df['contrastLeft'])/np.tanh(p)   # tanh transformation of left contrasts
    cR = np.tanh(p*df['contrastRight'])/np.tanh(p)  # tanh transformation of right contrasts
    inputs = dict(cL = np.array(cL)[:, None], cR = np.array(cR)[:, None])

    dat = dict(
        subject=subject,
        lab=np.unique(df["lab"])[0],
        contrastLeft=np.array(df['contrastLeft']),
        contrastRight=np.array(df['contrastRight']),
        date=np.array(df['date']),
        dayLength=np.array(df.groupby(['date','session']).size()),
        correct=np.array(df['feedbackType']),
        answer=np.array(df['answer']),
        probL=np.array(df['probabilityLeft']),
        inputs = inputs,
        y = np.array(df['choice'])
    )
    
    return dat
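
To see what `p` does, the contrast transformation used in `getMouse` can be evaluated on its own. The helper below (`tanh_transform`, a hypothetical name) computes $\tanh(pc)/\tanh(p)$ exactly as the two lines in `getMouse` do. It maps $0 \mapsto 0$ and $1 \mapsto 1$ for any `p`, stretches low contrasts upward as `p` grows, and leaves contrasts nearly unchanged when `p` is small.

```python
import numpy as np

def tanh_transform(contrast, p=5):
    """Hypothetical helper: the same rescaling as getMouse, tanh(p*c) / tanh(p)."""
    contrast = np.asarray(contrast, dtype=float)
    return np.tanh(p * contrast) / np.tanh(p)

# Standard IBL contrast levels (see the sanity checks in step 4)
levels = np.array([0, 0.0625, 0.125, 0.25, 0.5, 1.0])
for p in [1, 5, 10]:
    print("p =", p, np.round(tanh_transform(levels, p), 3))
```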

---

## Download and pre-process Akrami rat data

1) Download the [Akrami rat behavior dataset](https://doi.org/10.6084/m9.figshare.12213671.v1) _(version 1, uploaded May 18, 2020)_ to the `SPATH` directory as `rat_behavior.csv`.

In [None]:
!wget -nc -O "{SPATH}rat_behavior.csv" "https://ndownloader.figshare.com/files/22461707"
clear_output()

2) Sessions corresponding to early shaping stages will be omitted, as will all mistrials (see the dataset's README for more info). The `getRat` function then loads a particular rat into a PsyTrack-compatible `dict`.

`getRat` has two optional parameters: `first`, which restricts the data to the subject's first `first` trials (the default of 20,000 suffices for all analyses), and `cutoff`, which excludes sessions with fewer than `cutoff` valid trials (default 50). Note that `first` is applied before `cutoff`. We will use this function to access Akrami rat data in the figures below.

In [None]:
akrami_rat_data_path = SPATH + "rat_behavior.csv"

RAT_DF = pd.read_csv(akrami_rat_data_path)
RAT_DF = RAT_DF[RAT_DF["training_stage"] > 2]  # Remove trials from early training
RAT_DF = RAT_DF[~np.isnan(RAT_DF["choice"])]   # Remove mistrials
def getRat(subject, first=20000, cutoff=50):

    df = RAT_DF[RAT_DF['subject_id']==subject]  # restrict dataset to single subject
    df = df[:first]  # restrict to "first" trials of data
    # remove sessions with fewer than "cutoff" valid trials
    df = df.groupby('session').filter(lambda x: len(x) >= cutoff)   

    # Normalize the stimuli to standard normal
    s_a = (df["s_a"] - np.mean(df["s_a"]))/np.std(df["s_a"])
    s_b = (df["s_b"] - np.mean(df["s_b"]))/np.std(df["s_b"])
    
    # Determine which trials do not have a valid previous trial (mistrial or session boundary)
    t = np.array(df["trial"])
    prior = ((t[1:] - t[:-1]) == 1).astype(int)
    prior = np.hstack(([0], prior))

    # Calculate previous average tone value
    s_avg = (df["s_a"][:-1] + df["s_b"][:-1])/2
    s_avg = (s_avg - np.mean(s_avg))/np.std(s_avg)
    s_avg = np.hstack(([0], s_avg))
    s_avg = s_avg * prior  # for trials without a valid previous trial, set to 0

    # Calculate previous correct answer
    h = (df["correct_side"][:-1] * 2 - 1).astype(int)   # map from (0,1) to (-1,1)
    h = np.hstack(([0], h))
    h = h * prior  # for trials without a valid previous trial, set to 0
    
    # Calculate previous choice
    c = (df["choice"][:-1] * 2 - 1).astype(int)   # map from (0,1) to (-1,1)
    c = np.hstack(([0], c))
    c = c * prior  # for trials without a valid previous trial, set to 0
    
    inputs = dict(s_a = np.array(s_a)[:, None],
                  s_b = np.array(s_b)[:, None],
                  s_avg = np.array(s_avg)[:, None],
                  h = np.array(h)[:, None],
                  c = np.array(c)[:, None])

    dat = dict(
        subject = subject,
        inputs = inputs,
        s_a = np.array(df['s_a']),
        s_b = np.array(df['s_b']),
        correct = np.array(df['hit']),
        answer = np.array(df['correct_side']),
        y = np.array(df['choice']),
        dayLength=np.array(df.groupby(['session']).size()),
    )
    return dat
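
The `prior` mask in `getRat` zeroes the history regressors whenever the preceding row is not the immediately preceding trial, either because a mistrial was removed or because a new session began. The toy illustration below uses made-up trial numbers and assumes, as in this example, that trial numbering restarts at each session, so the jump between sessions is not a difference of 1.

```python
import numpy as np

# Made-up trial numbers: trial 3 was a mistrial (removed), then a new session
# starts and, in this example, trial numbering restarts at 1
trial_num = np.array([1, 2, 4, 5, 1, 2])
choice = np.array([0, 1, 1, 0, 1, 1])

# 1 where the previous row is exactly the previous trial, as in getRat
prior = np.hstack(([0], (np.diff(trial_num) == 1).astype(int)))

# Previous choice mapped from (0, 1) to (-1, 1), then zeroed where there is no valid previous trial
prev_choice = np.hstack(([0], choice[:-1] * 2 - 1)) * prior
print(prior.tolist())        # [0, 1, 0, 1, 0, 1]
print(prev_choice.tolist())  # [0, -1, 0, 1, 0, 1]
```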

---

## Download and pre-process Akrami human subject data

1) Download the [Akrami human subject behavior dataset](https://doi.org/10.6084/m9.figshare.12213671.v1) _(version 1, uploaded May 18, 2020)_ to the `SPATH` directory as `human_auditory.csv`. See the dataset's README for more info.

In [None]:
!wget -nc -O "{SPATH}human_auditory.csv" "https://ndownloader.figshare.com/files/22461695"
clear_output()

2) We define a function `getHuman` that extracts the data for a single human subject from the downloaded CSV file and returns it as a PsyTrack-compatible `dict`. We will use this function to access Akrami human subject data in the figures below.

In [None]:
akrami_human_data_path = SPATH + "human_auditory.csv"

HUMAN_DF = pd.read_csv(akrami_human_data_path)
def getHuman(subject):
    
    df = HUMAN_DF[HUMAN_DF['subject_id']==subject]
    
    s_a = (df["s_a"] - np.mean(df["s_a"]))/np.std(df["s_a"])
    s_b = (df["s_b"] - np.mean(df["s_b"]))/np.std(df["s_b"])
    
    s_avg = (df["s_a"][:-1] + df["s_b"][:-1])/2
    s_avg = (s_avg - np.mean(s_avg))/np.std(s_avg)
    s_avg = np.hstack(([0], s_avg))
    
    inputs = dict(s_a = np.array(s_a)[:, None],
                  s_b = np.array(s_b)[:, None],
                  s_avg = np.array(s_avg)[:, None])

    dat = dict(
        subject = subject,
        inputs = inputs,
        s_a = np.array(df['s_a']),
        s_b = np.array(df['s_b']),
        correct = np.array(df['reward']),
        answer = np.array(df['correct_side']),
        y = np.array(df['choice'])
    )
    return dat

# Figure 1 | Schematic of Psychometric Weight Model

**(A)** IBL task schematic (Illustrator only)

**(B)** Example inputs (Illustrator only)

**(C)** Schematic weight trajectories using regressors in (B)

**(D)** Psychometric curves produced from weights from (C) at different points in training

## Figure 1c

In [None]:
# Fig 1c: generate schematic weight trajectories
def sigmoid(lenx, bias, slope):
    x = np.arange(lenx)
    return 1.0/(1.0 + np.exp(-(x-bias)/slope))

x = np.arange(10000)
bias_w = 0.8*sigmoid(10000, 6000, 1500)[::-1] - 0.08
sL_w = -sigmoid(10000, 5000, 700) + 0.05
sR_w = sigmoid(10000, 6500, 800) - 0.1

gain = 4
w = gain*np.vstack((bias_w,sL_w,sR_w))

# Plotting
plt.figure(figsize=(3.5,1.2))
plt.plot(x, w[0], c=colors['bias'], lw=2)
plt.plot(x, w[1], c=colors['sL'], lw=2)
plt.plot(x, w[2], c=colors['sR'], lw=2)

plt.axhline(0, color="black", linestyle="--", alpha=0.5, zorder=0)

plt.xticks([]); plt.yticks([0])
plt.gca().set_yticklabels([0])
plt.xlim(0,10000); plt.ylim(-1.02*gain,1.02*gain)
# plt.xlabel("Trials"); plt.ylabel("Weights")

# hand pick divider lines to make the Illustrator plot look nice
xs = [1270,4975,8690]
for x in xs:
    plt.axvline(x, color="gray", lw=2, alpha=0.0)
    
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False) 

# this makes the plot itself reflect the figsize, excluding the axis labels and ticks
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig1c.pdf")

## Figure 1d

In [None]:
# Fig 1d: generate psychometric curves corresponding to the weights at various times
def generate_psych(w,x):
    xL = x.copy(); xR = x.copy()
    xL[xL>0] = 0; xR[xR<0] = 0
    xL = np.abs(xL)
    
    wx = w[0] + xL*w[1] + xR*w[2]
    pR = 1/(1+np.exp(-wx))
    return pR

# Generate psychometric curve for each time point in xs
for i,cut in enumerate(xs):

    x = np.arange(-1,1.01,.01)
    pR = generate_psych(w[:,cut],x)
    
    x_dot = np.array([-1.0,-0.5,0.0,0.5,1.0])
    pR_dot = generate_psych(w[:,cut],x_dot)

    plt.figure(figsize=(1.25,1))
    plt.plot(x*100, pR*100, color="black", lw=1.5)
    plt.plot(x_dot*100, pR_dot*100, color="black", marker='o', lw=0, markersize=4)

    # Grid lines
    plt.axvline(  0, color="black", linestyle="-", alpha=0.1)
    plt.axhline( 50, color="black", linestyle="-", alpha=0.1)
    
    plt.xticks([-100,-50,0,50,100]); plt.yticks([0,50,100])
    plt.gca().set_xticklabels([]); plt.gca().set_yticklabels([])
    plt.xlim(-110,110); plt.ylim(0,100)
#     plt.xlabel("Right - Left Contrast (%)"); plt.ylabel("Prob. Left (%)")
    
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)

    plt.savefig(SPATH + "Fig1d_"+str(i)+".pdf")
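
`generate_psych` evaluates a logistic psychometric function at one point in training. Writing a signed contrast $x \in [-1, 1]$ (negative values for left-side stimuli) as $x_L = |\min(x, 0)|$ and $x_R = \max(x, 0)$, the curve drawn for each trial index $t$ in `xs` is

$$p_R(x) = \frac{1}{1 + \exp\left(-\left(w_{0,t} + w_{1,t}\,x_L + w_{2,t}\,x_R\right)\right)},$$

where $w_{0,t}$, $w_{1,t}$, and $w_{2,t}$ are the bias, left-contrast, and right-contrast weights (the rows of `w`) at trial $t$.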


# Figure 2 | Recovering Psychometric Weights from Simulated Data

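**(A)** $K=4$ simulated weights with different values of $\sigma$ over $N=5000$ trials, with the recovered weights shown with 95% credible intervals

**(B)** The recovered value of each $\sigma$ in (A), with 95% credible intervals

**(C)** $K=3$ simulated weights as in (A), but with an additional $\sigma_{\text{Day}}$ hyperparameter for each weight, governing weight changes at session boundaries

**(D)** The recovered hyperparameters ($\sigma$ and $\sigma_{\text{Day}}$) for (C), shown as in (B)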

## Figure 2a

In [None]:
# Fig 2a — generate simulated weights and recover with errorbars
# Simulate
seed = 31  # paper uses 31
num_weights = 4
num_trials = 5000
hyper = {'sigma'   : 2**np.array([-4.0,-5.0,-6.0,-7.0]),
         'sigInit' : 2**np.array([ 0.0, 0.0, 0.0, 0.0])}

# Compute
gen = psy.generateSim(K=num_weights, N=num_trials, hyper=hyper,
                      boundary=6.0, iterations=1, seed=seed, savePath=None)
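
The PsyTrack model treats each weight as a Gaussian random walk across trials: $\sigma$ sets the size of the trial-to-trial steps and `sigInit` the spread of the initial weight. The sketch below simulates that generative process directly with NumPy so the role of each hyperparameter is visible. It is a hypothetical illustration, not `psy.generateSim`: the function names are made up, the inputs are assumed to be standard normal with a constant bias column, the `boundary` argument used above is not reproduced, and the draws will not match the paper's seed.

```python
import numpy as np

def simulate_random_walk_weights(K, N, sigma, sig_init, seed=0):
    """Hypothetical sketch (not psy.generateSim): K independent Gaussian random walks.

    w_0 ~ N(0, sig_init^2) and w_t = w_{t-1} + eta_t with eta_t ~ N(0, sigma^2),
    separately for each weight. No boundary on the weights is imposed here.
    """
    rng = np.random.default_rng(seed)
    sigma = np.asarray(sigma, dtype=float)
    sig_init = np.asarray(sig_init, dtype=float)
    W = np.empty((N, K))
    W[0] = rng.normal(0.0, sig_init)                    # initial weights, shape (K,)
    steps = rng.normal(0.0, sigma, size=(N - 1, K))     # per-trial increments
    W[1:] = W[0] + np.cumsum(steps, axis=0)             # accumulate the random walk
    return W

def simulate_choices(W, seed=1):
    """Assumed inputs: standard normal, with column 0 fixed at 1 to act as a bias.
    Choices are Bernoulli with P(y=1) = logistic(x_t . w_t)."""
    rng = np.random.default_rng(seed)
    N, K = W.shape
    X = rng.normal(size=(N, K))
    X[:, 0] = 1.0
    p = 1.0 / (1.0 + np.exp(-np.sum(X * W, axis=1)))
    y = (rng.random(N) < p).astype(int)
    return X, y

# Same sigma and sigInit values as the Fig 2a cell above
sigma_sim = 2**np.array([-4.0, -5.0, -6.0, -7.0])
W_sim = simulate_random_walk_weights(K=4, N=5000, sigma=sigma_sim, sig_init=np.ones(4), seed=31)
X_sim, y_sim = simulate_choices(W_sim)
print(W_sim.shape, X_sim.shape, y_sim.mean())
```

Smaller $\sigma$ values produce visibly smoother trajectories, which is what panel (A) contrasts across its four weights.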

In [None]:
# Recovery
rec = psy.recoverSim(gen)

# Save interim result
np.savez_compressed(SPATH+'fig2a_data.npz', rec=rec, gen=gen)

In [None]:
# Reload data
rec = np.load(SPATH+'fig2a_data.npz', allow_pickle=True)['rec'].item()
gen = np.load(SPATH+'fig2a_data.npz', allow_pickle=True)['gen'].item()

# Plotting
sim_colors = [colors['bias'], colors['s1'], colors['s2'], colors['s_avg']]
fig = plt.figure(figsize=(3.75,1.4))
for i, c in enumerate(sim_colors):
    plt.plot(gen['W'][:,i], c=c, lw=0.5, zorder=2*i)
    plt.plot(rec['wMode'][i], c=c, lw=1, linestyle='--', alpha=0.5, zorder=2*i+1)
    plt.fill_between(np.arange(num_trials),
                     rec['wMode'][i] - 2 * rec['hess_info']['W_std'][i],
                     rec['wMode'][i] + 2 * rec['hess_info']['W_std'][i],
                     facecolor=c, alpha=0.2, zorder=2*i+1)

plt.axhline(0, color="black", linestyle="--", lw=0.5, alpha=0.5, zorder=0)

plt.xticks(1000*np.arange(0,6))
plt.gca().set_xticklabels([0,1000,2000,3000,4000,5000])
plt.yticks(np.arange(-4,5,2))

plt.xlim(0,5000); plt.ylim(-4.3,4.3)
# plt.xlabel("Trials"); plt.ylabel("Weights")

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig2a.pdf")

## Figure 2b

In [None]:
# Reload data
rec = np.load(SPATH+'fig2a_data.npz', allow_pickle=True)['rec'].item()

# Plotting
sim_colors = [colors['bias'], colors['s1'], colors['s2'], colors['s_avg']]
plt.figure(figsize=(1.4,1.4))

true_sigma = np.log2(rec['input']['sigma'])
avg_sigma = np.log2(rec['hyp']['sigma'])
err_sigma = rec['hess_info']['hyp_std']

for i, c in enumerate(sim_colors):
    plt.plot([i-0.3, i+0.3], [true_sigma[i]]*2, color="black", linestyle="-", lw=1.2, zorder=0)
    plt.errorbar([i], avg_sigma[i], yerr=1.96*err_sigma[i], c=c, lw=1, marker='o', markersize=5)

plt.xticks([0,1,2,3]); plt.yticks(np.arange(-8,-2))
plt.xlim(-0.5,3.5); plt.ylim(-7.5,-3.5)

plt.gca().set_xticklabels([r"$\sigma_1$", r"$\sigma_2$", r"$\sigma_3$", r"$\sigma_4$"])

# plt.ylabel(r"$\log_2(\sigma)$")

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig2b.pdf")

## Figure 2c

_2 min_

In [None]:
# Fig 2c — generate simulated weights and recover with errorbars
# Simulate
seed = 102  # paper uses 102
num_weights = 3
num_trials = 5000
hyper = {'sigma'   : 2**np.array([-4.5, -5.0,-16.0]),
         'sigInit' : 2**np.array([ 0.0,  0.0,  0.0]),
         'sigDay'  : 2**np.array([ 0.5,-16.0,  1.0])
        }
days = [500]*9

# Compute
gen = psy.generateSim(K=num_weights, N=num_trials, hyper=hyper, days=days,
                      boundary=10.0, iterations=1, seed=seed, savePath=None)
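
The Fig 2c cell adds `sigDay` and `days`. The sketch below shows one way to picture their effect for a single weight, under a loudly stated assumption: the step that enters the first trial of each new session has standard deviation $\sigma_{\text{Day}}$, while every other step has standard deviation $\sigma$. This is a hypothetical helper, not PsyTrack's implementation, and its parameters mimic the third weight in the cell above ($\sigma = 2^{-16}$, $\sigma_{\text{Day}} = 2^{1}$).

```python
import numpy as np

def simulate_with_session_jumps(N, sigma, sig_day, days, sig_init=1.0, seed=0):
    """Hypothetical sketch for a single weight (not psy.generateSim).

    Assumption: the step into the first trial of each new session has s.d. sig_day,
    and every other step has s.d. sigma. Requires sum(days[:-1]) < N.
    """
    rng = np.random.default_rng(seed)
    step_sd = np.full(N - 1, float(sigma))
    session_starts = np.cumsum(days)[:-1]    # first trial index of sessions 2, 3, ...
    step_sd[session_starts - 1] = sig_day    # w[b] - w[b-1] is drawn with step_sd[b-1]
    w = np.empty(N)
    w[0] = rng.normal(0.0, sig_init)
    w[1:] = w[0] + np.cumsum(rng.normal(0.0, step_sd))
    return w, step_sd

# Nearly constant within sessions, with large jumps between sessions
w_sim, step_sd = simulate_with_session_jumps(N=5000, sigma=2**-16, sig_day=2**1,
                                             days=[500]*9, seed=102)
print(np.round(w_sim[[0, 499, 500, 999, 1000]], 3))
```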

In [None]:
# Recovery
rec = psy.recoverSim(gen)

# Save interim result
np.savez_compressed(SPATH+'fig2c_data.npz', rec=rec, gen=gen)

In [None]:
# Reload data
rec = np.load(SPATH+'fig2c_data.npz', allow_pickle=True)['rec'].item()
gen = np.load(SPATH+'fig2c_data.npz', allow_pickle=True)['gen'].item()

# Plotting
sim_colors = [colors['bias'], colors['s1'], colors['s2']]
fig = plt.figure(figsize=(3.75,1.4))
for i, c in enumerate(sim_colors):
    plt.plot(gen['W'][:,i], c=c, lw=0.5, zorder=5-i)
    plt.plot(rec['wMode'][i], c=c, lw=1, linestyle='--', alpha=0.5, zorder=5-i)
    plt.fill_between(np.arange(num_trials),
                     rec['wMode'][i] - 2 * rec['hess_info']['W_std'][i],
                     rec['wMode'][i] + 2 * rec['hess_info']['W_std'][i],
                     facecolor=c, alpha=0.2, zorder=5-i)

for i in np.cumsum(days):
    plt.axvline(i, color="black", lw=0.5, alpha=0.5, zorder=0)
    
plt.axhline(0, color="black", linestyle="--", lw=0.5, alpha=0.5, zorder=0)
plt.xticks(1000*np.arange(0,6))
plt.gca().set_xticklabels([0,1000,2000,3000,4000,5000])
plt.yticks(np.arange(-4,5,2))

plt.xlim(0,5000); plt.ylim(-4.3,4.3)
# plt.xlabel("Trials"); plt.ylabel("Weights")

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig2c.pdf")

## Figure 2d

In [None]:
# Reload data
rec = np.load(SPATH+'fig2c_data.npz', allow_pickle=True)['rec'].item()

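# Plotting (colors redefined here so this cell does not depend on the Fig 2c plotting cell)
sim_colors = [colors['bias'], colors['s1'], colors['s2']]
plt.figure(figsize=(1.4,1.4))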

true_sigma = np.log2(rec['input']['sigma'])
avg_sigma = np.log2(rec['hyp']['sigma'])
err_sigma = rec['hess_info']['hyp_std'][:3]
for i, c in enumerate(sim_colors):
    plt.plot([2*i-0.3, 2*i+0.3], [true_sigma[i]]*2, color="black", linestyle="-", lw=1.2, zorder=0)
    plt.errorbar([2*i], avg_sigma[i], yerr=1.96*err_sigma[i], c=c, lw=1, marker='o', markersize=5)

true_sigma = np.log2(rec['input']['sigDay'])
avg_sigma = np.log2(rec['hyp']['sigDay'])
err_sigma = rec['hess_info']['hyp_std'][3:]
for i, c in enumerate(sim_colors):
    plt.plot([2*i-0.3+1, 2*i+0.3+1], [true_sigma[i]]*2, color="black", linestyle="-", lw=1.2, zorder=0)
    plt.errorbar([2*i+1], avg_sigma[i], yerr=1.96*err_sigma[i], c=c, lw=1, marker='s', markersize=5)

plt.axvspan(2.6,4.4, facecolor="black", edgecolor="none", alpha=0.1)
plt.xticks(np.arange(6))
plt.yticks([-8,-6,-4,-2,0,2])
plt.gca().set_xticklabels([r"$\sigma_1$", r"$_{day}$",
                           r"$\sigma_2$", r"$_{day}$",
                           r"$\sigma_3$", r"$_{day}$",])
plt.xlim(-0.5,5.5); plt.ylim(-8.5,2.5)
# plt.ylabel(r"$\log_2(\sigma)$")

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig2d.pdf")

# Figure 3 | Visualization of Early Learning in IBL Mice

**(A)** A performance curve of an example mouse (`CSHL_003`) on easy trials during early training

**(B)** Psychometric weights for the mouse and sessions shown in (A)

**(C)** The performance curves of a subset (1 in 8) of the full population of mice on easy trials in early training (first 16 sessions)

**(D)** Psychometric weights for all the mice shown in (C), plus average weights calculated from all mice in the population

## Figure 3a

In [None]:
from datetime import date, datetime, timedelta

outData = getMouse('CSHL_003', 5)
easy_trials = (outData['contrastLeft'] > 0.45).astype(int) | (outData['contrastRight'] > 0.45).astype(int)

perf = []
for d in np.unique(outData['date']):
    date_trials = (outData['date'] == d).astype(int)
    inds = (date_trials * easy_trials).astype(bool)
    perf += [np.average(outData['correct'][inds])]

dates = np.unique([datetime.strptime(i, "%Y-%m-%d") for i in outData['date']])
dates = np.arange(len(dates)) + 1

# Plotting
fig = plt.figure(figsize=(2.75,0.9))

plt.plot(dates[:16], perf[:16], color="black", linewidth=1.5, zorder=2)
plt.scatter(dates[9], perf[9], c="white", s=30, edgecolors="black", linestyle="--", lw=0.75, zorder=5, alpha=1)

plt.axhline(0.5, color="black", linestyle="--", lw=1, alpha=0.5, zorder=0)

plt.xticks(np.arange(0,16,5))
plt.yticks([0.4,0.6,0.8,1.0])
plt.ylim(0.25,1.0)
plt.xlim(1, 15.5)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig3a.pdf")
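
The per-session easy-trial performance computed above can also be written as a single pandas `groupby`. The sketch below is an alternative formulation, not the notebook's code: `easy_trial_accuracy` is a hypothetical helper, and unlike the loop above, sessions with no easy trials are dropped rather than returned as `NaN`. Because the dates are `YYYY-MM-DD` strings, the sorted group order is chronological, matching `np.unique`.

```python
import pandas as pd

def easy_trial_accuracy(date, contrast_left, contrast_right, correct, threshold=0.45):
    """Fraction correct per session date, using only trials where either contrast
    exceeds `threshold` (0.45 keeps the 50% and 100% contrast trials)."""
    df = pd.DataFrame({"date": date, "cL": contrast_left,
                       "cR": contrast_right, "correct": correct})
    easy = (df["cL"] > threshold) | (df["cR"] > threshold)
    return df[easy].groupby("date")["correct"].mean()   # groupby sorts the dates

# Toy data (hypothetical): two sessions, five trials
acc = easy_trial_accuracy(date=["d1", "d1", "d1", "d2", "d2"],
                          contrast_left=[1.0, 0.0, 0.0, 0.5, 0.0],
                          contrast_right=[0.0, 0.25, 1.0, 0.0, 0.0625],
                          correct=[1, 1, 0, 1, 0])
print(acc.to_dict())  # {'d1': 0.5, 'd2': 1.0}
```

On real data, the call would be `easy_trial_accuracy(outData['date'], outData['contrastLeft'], outData['contrastRight'], outData['correct'])`.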

## Figure 3b

In [None]:
# Collect data from manually determined training period
new_dat = psy.trim(outData, END=7000)

# Compute
weights = {'bias' : 0, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : None
  }
optList = ['sigma']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig3b_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'fig3b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))

plt.axvline(np.cumsum(dat['new_dat']['dayLength'])[8], c="black", lw=1.5, ls="--", zorder=15)
plt.ylim(-5.3,5.3)
plt.xlim(0, 6950)
plt.yticks([-4,-2,0,2,4])
plt.xlabel(None); plt.ylabel(None)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig3b.pdf")

## Figure 3c

_2 min_

In [None]:
from datetime import date, datetime, timedelta

all_dates = []
all_perf = []
for s in np.unique(MOUSE_DF['subject']):
    outData = getMouse(s, 5)
    easy_trials = (outData['contrastLeft'] > 0.45).astype(int) | (outData['contrastRight'] > 0.45).astype(int)

    perf = []
    for d in np.unique(outData['date']):
        date_trials = (outData['date'] == d).astype(int)
        inds = (date_trials * easy_trials).astype(bool)
        perf += [np.average(outData['correct'][inds])]

    dates = np.unique([datetime.strptime(i, "%Y-%m-%d") for i in outData['date']])
    dates = np.arange(len(dates))
    
    all_dates += [dates]
    all_perf += [perf]
    
x = [[] for i in range(25)]
for dates, perf in zip(all_dates, all_perf):
    for ind, d in enumerate(dates):
        if d < 25:
            x[d] += [perf[ind]]
perf_avg = [np.average(i) for i in x] 

In [None]:
fig = plt.figure(figsize=(2.75,0.9))

for dates, perf in zip(all_dates[::8], all_perf[::8]):
    plt.plot(dates[:25], perf[:25], color="black", linewidth=1, alpha=0.2, zorder=1)

plt.plot(perf_avg[:25], color="black", lw=2.5, alpha=0.8, zorder=6)

plt.axhline(0.5, color="black", linestyle="--", lw=1, alpha=0.5, zorder=0)

plt.xticks(np.arange(0,16,5))
plt.yticks([0.4,0.6,0.8,1.0])
plt.ylim(0.25,1.0)
plt.xlim(1, 15.5)
plt.gca().set_yticklabels([])

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig3c.pdf")

## Figure 3d
_20 min_

In [None]:
subjects = np.unique(MOUSE_DF['subject'])   # iterate over subjects, not over every trial row
for i, s in enumerate(subjects):

    print("\rProcessing " + str(i+1) + " of " + str(len(subjects)), end="")

    outData = getMouse(s, 5)
    
    # Collect data from manually determined training period
    new_dat = psy.trim(outData, END=7000)

    # Compute
    weights = {'bias' : 0, 'cL' : 1, 'cR' : 1}
    K = np.sum([weights[i] for i in weights.keys()])
    hyper_guess = {
     'sigma'   : [2**-5]*K,
     'sigInit' : 2**5,
     'sigDay'  : None
      }
    optList = ['sigma']

    hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList, hess_calc=None)

    dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'hess_info' : hess_info,
           'weights' : weights, 'new_dat' : new_dat}

    # Save interim result
    np.savez_compressed(SPATH+'fig3c_'+s+'_data.npz', dat=dat)


In [None]:
plt.figure(figsize=(2.75,1.3))
w0 = []
w1 = []
for i, s in enumerate(np.unique(MOUSE_DF['subject'])):

    dat = np.load(SPATH+'fig3c_'+s+'_data.npz', allow_pickle=True)['dat'].item()

    w0 += [np.hstack((dat['wMode'][0][:7000], [np.nan]*(7000 - len(dat['wMode'][0][:7000]))))]
    w1 += [np.hstack((dat['wMode'][1][:7000], [np.nan]*(7000 - len(dat['wMode'][1][:7000]))))]

    if not i%8:
        plt.plot(dat['wMode'][0], color=colors['cL'], lw=1, alpha=0.2, zorder=4)
        plt.plot(dat['wMode'][1], color=colors['cR'], lw=1, alpha=0.2, zorder=2)

    
plt.plot(np.nanmean(w0, axis=0), color=colors['cL'], lw=2.5, alpha=0.8, zorder=6)
plt.plot(np.nanmean(w1, axis=0), color=colors['cR'], lw=2.5, alpha=0.8, zorder=6)
plt.axhline(0, linestyle='--', color="black", lw=1, alpha=0.5, zorder=0)
plt.ylim(-5.3,5.3)
plt.xlim(0, 6950)
plt.yticks([-4,-2,0,2,4])
plt.gca().set_yticklabels([])

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig3d.pdf")
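
Some mice have fewer than 7,000 trials in the trimmed window, so the average curves above pad each weight trajectory with `NaN` to a common length and average with `np.nanmean`, which ignores missing trials rather than treating them as zero. A minimal sketch of that pattern, with a hypothetical helper `pad_to_length` and toy trajectories:

```python
import numpy as np

def pad_to_length(w, length):
    """Truncate or right-pad a 1-D trajectory with NaN to exactly `length` samples."""
    w = np.asarray(w, dtype=float)[:length]
    return np.concatenate([w, np.full(length - len(w), np.nan)])

# Two toy weight trajectories of different lengths
trajectories = [np.array([1.0, 2.0, 3.0]), np.array([3.0, 4.0])]
padded = np.vstack([pad_to_length(w, 3) for w in trajectories])
print(np.nanmean(padded, axis=0))  # last column averages only the first trajectory
```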

# Figure 4 | Adaptation to Bias Blocks in an Example IBL Mouse

**(A)** Show the performance curve of the example mouse on easy trials, highlighting different training periods

**(B)** Show data for early bias blocks of the example mouse

**(C)** Show data for late bias blocks of the example mouse

**(D)** For the early bias blocks in (B), chunk the bias weight by block and plot how the weight changes from the start to the end of each block

**(E)** Same as (D), but for the late bias blocks in (C)

**(F)** Overlay the optimal bias weight on the second session shown in (C)

## Figure 4a

In [None]:
from datetime import date, datetime, timedelta

outData = getMouse("CSHL_003", 5)
easy_trials = (outData['contrastLeft'] > 0.45).astype(int) | (outData['contrastRight'] > 0.45).astype(int)

perf = []
for d in np.unique(outData['date']):
    date_trials = (outData['date'] == d).astype(int)
    inds = (date_trials * easy_trials).astype(bool)
    perf += [np.average(outData['correct'][inds])]

dates = np.unique([datetime.strptime(i, "%Y-%m-%d") for i in outData['date']])
dates = np.arange(len(dates)) + 1   # one 1-indexed session number per unique date

# Plotting
plt.figure(figsize=(3.5,0.9))
plt.plot(dates[:52], perf[:52], color="black", linewidth=1.5, zorder=2)

plt.axhline(0.5, linestyle='--', color="black", lw=1, alpha=0.5, zorder=1)
plt.yticks([0.4,0.6,0.8,1.0])
plt.ylim(0.25,1)
plt.xlim(1,47)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
#plt.ylabel("Performance\n(on easy trials)")
#plt.xlabel("Weeks of Training")

plt.axvspan(0,17.5, ymax=1,
            edgecolor='None', alpha=0.1, facecolor="black", zorder=0)
plt.axvspan(16.5,19.5, linestyle="-", lw=2.5, ymin=0.03, ymax=0.98,
            edgecolor='#E32D91', alpha=.8, facecolor="None", zorder=8)
plt.axvspan(43.5,45.5, linestyle="-", lw=2.5, ymin=0.03, ymax=0.98,
            edgecolor='#9252AB', alpha=.8, facecolor="None", zorder=9)

plt.subplots_adjust(0,0,1,1) 
# plt.savefig(SPATH + "Fig4a.pdf")
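The per-session accuracy loop above can also be written as a standalone helper. This is a minimal sketch that assumes each trial has an ISO-format date string and contrasts on a 0 to 1 scale. `easy_accuracy_by_session` and the toy arrays are hypothetical and are not part of PsyTrack or the IBL data format.

```python
import numpy as np

def easy_accuracy_by_session(dates, contrast_left, contrast_right, correct, threshold=0.45):
    """Return (session_dates, accuracy), using only trials with a contrast above `threshold`."""
    dates = np.asarray(dates)
    easy = (np.asarray(contrast_left) > threshold) | (np.asarray(contrast_right) > threshold)
    correct = np.asarray(correct, dtype=float)
    session_dates = np.unique(dates)  # sorted; ISO date strings sort chronologically
    accuracy = np.array([correct[(dates == d) & easy].mean() for d in session_dates])
    return session_dates, accuracy

# Toy example: hard trials (contrast <= 0.45) are ignored
dates = ["2019-03-01"] * 3 + ["2019-03-02"] * 3
cL = [1.0, 0.0, 0.0625, 0.0, 0.25, 0.0]
cR = [0.0, 1.0, 0.0, 1.0, 0.0, 0.0625]
correct = [1, 0, 1, 1, 1, 0]
sessions, acc = easy_accuracy_by_session(dates, cL, cR, correct)
print(sessions, acc)
```

A session with no easy trials would yield NaN, and NumPy warns about the empty mean.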

## Figure 4b

In [None]:
# Collect data from manually determined training period
outData = getMouse("CSHL_003", 5)
_start  = np.where(outData['date'] >= '2019-03-21')[0][0]
_end    = np.where(outData['date'] >= '2019-03-23')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Before bias blocks begin, set probL to 0.5 for every trial; a few early trials
# have probL != 0.5 because of anti-biasing in the IBL early training protocol
new_dat['probL'][:np.where(new_dat['date'] >= '2019-03-22')[0][0]] = 0.5

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-5]*K
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig4b_data.npz', dat=dat)
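The trial window above is found with `np.where(outData['date'] >= ...)[0][0]`. That works because ISO date strings sort chronologically. The sketch below performs the same lookup with `np.searchsorted` and assumes the per-trial dates are already sorted. `date_window` is a hypothetical helper that returns a half-open `[start, end)` range, as Python slicing does. This sketch does not establish whether `psy.trim` treats `END` the same way.

```python
import numpy as np

def date_window(dates, start_date, end_date):
    """Return (start, end) trial indices covering start_date <= date < end_date.

    Assumes `dates` holds sorted "YYYY-MM-DD" strings, which compare chronologically.
    """
    dates = np.asarray(dates)
    start = int(np.searchsorted(dates, start_date, side="left"))
    end = int(np.searchsorted(dates, end_date, side="left"))
    return start, end

dates = ["2019-03-20"] * 2 + ["2019-03-21"] * 3 + ["2019-03-22"] * 4 + ["2019-03-23"] * 2
print(date_window(dates, "2019-03-21", "2019-03-23"))  # (2, 9)
```

`np.where(...)[0][0]` raises an `IndexError` when no date reaches the requested one. In that case, `np.searchsorted` returns `len(dates)` instead.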

In [None]:
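# Shading color for each block type, keyed by 100 * probL
BIAS_COLORS = {50 : 'None', 20 : psy.COLORS['sR'], 80 : psy.COLORS['sL']}
def addBiasBlocks(fig, pL):
    """Shade each contiguous block of constant probL on the current axes of `fig`."""
    plt.sca(fig.gca())
    i = 0
    while i < len(pL):
        start = i
        # Advance to the last trial of the current block
        while i+1 < len(pL) and abs(pL[i] - pL[i+1]) < 0.0001:
            i += 1
        # round() guards against float error (e.g. 100 * 0.29 == 28.999...)
        fc = BIAS_COLORS[int(round(100 * pL[start]))]
        plt.axvspan(start, i+1, facecolor=fc, alpha=0.2, edgecolor=None)
        i += 1
    return fig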
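`addBiasBlocks` walks the `probL` sequence one trial at a time to find runs of constant value. The sketch below performs the same run-length split in vectorized form and returns `(start, stop, value)` tuples. Each exclusive `stop` matches the `axvspan(start, i+1)` call. `bias_blocks` is a hypothetical name.

```python
import numpy as np

def bias_blocks(pL, tol=1e-4):
    """Split a per-trial probL sequence into contiguous blocks.

    Returns a list of (start, stop, value) with `stop` exclusive.
    """
    pL = np.asarray(pL, dtype=float)
    if pL.size == 0:
        return []
    # A new block starts wherever consecutive values differ by at least `tol`
    change = np.flatnonzero(np.abs(np.diff(pL)) >= tol) + 1
    starts = np.concatenate(([0], change))
    stops = np.concatenate((change, [pL.size]))
    return [(int(s), int(e), float(pL[s])) for s, e in zip(starts, stops)]

print(bias_blocks([0.5, 0.5, 0.2, 0.2, 0.2, 0.8]))
# [(0, 2, 0.5), (2, 5, 0.2), (5, 6, 0.8)]
```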

In [None]:
dat = np.load(SPATH+'fig4b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_yticks(np.arange(-6, 7,2))
plt.ylim(-5.3,5.3)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig4b.pdf")

## Figure 4c

In [None]:
# Collect data from manually determined training period
outData = getMouse("CSHL_003", 5)
_start  = np.where(outData['date'] >= '2019-04-30')[0][0]
_end    = np.where(outData['date'] >= '2019-05-02')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-5]*K
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig4c_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'fig4c_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_yticks(np.arange(-6, 7,2))
plt.gca().set_yticklabels([])
plt.ylim(-5.3,5.3)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig4c.pdf")

## Figure 4d

In [None]:
outData = getMouse("CSHL_003", 5)

# Collect data from manually determined training period
_start  = np.where(outData['date'] >= '2019-03-22')[0][0]
_end    = np.where(outData['date'] >= '2019-03-26')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Before bias blocks begin, set probL to 0.5 (a few early trials differ because of
# anti-biasing in the IBL early training protocol). This window starts on 2019-03-22,
# so the slice below is expected to be empty; the line mirrors Figure 4b.
new_dat['probL'][:np.where(new_dat['date'] >= '2019-03-22')[0][0]] = 0.5

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-5]*K
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig4d_data.npz', dat=dat)

In [None]:
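def bias_diff(dat_load, figsize=(1.5,1.5)):
    """Plot the change in the bias weight (row 0 of wMode) over each biased block.

    Each trace is relative to the weight on the block's first trial; only 20:80 and
    80:20 blocks longer than 20 trials are shown.
    """
    dat = np.load(dat_load, allow_pickle=True)['dat'].item()
    pL = dat['new_dat']['probL']
    # Block boundaries are the trials where probL changes
    pL_diff = pL[1:] - pL[:-1]
    inds = np.where(pL_diff)[0]
    # First and last trial of each block, keeping only biased blocks (probL = 0.2 or 0.8)
    start_inds = [0] + list(inds+1)
    start_inds = [i for i in start_inds if (np.isclose(pL[i], 0.2) or np.isclose(pL[i], 0.8))]
    end_inds = list(inds) + [len(pL)-1]
    end_inds = [i for i in end_inds if (np.isclose(pL[i], 0.2) or np.isclose(pL[i], 0.8))]

    fig = plt.figure(figsize=figsize)
    for s, e in zip(start_inds, end_inds):
        if e-s < 20: continue
        block_inds = np.arange(s, e+1)
        block = dat['wMode'][0, block_inds] - dat['wMode'][0, s]
        if np.isclose(pL[s], 0.2):
            # probL = 0.2: right-biased block
            plt.plot(block, color=colors['cR'], alpha=0.8, zorder=2, lw=1)
        else:
            # probL = 0.8: left-biased block
            plt.plot(block, color=colors['cL'], alpha=0.8, zorder=4, lw=1)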
    
    plt.axhline(0, linestyle='--', color="black", lw=1, alpha=0.5, zorder=0)
    plt.ylim(-5.5,5.5)
    plt.xlim(0, 75)

    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.subplots_adjust(0,0,1,1)
    return fig

fig = bias_diff(SPATH+'fig4d_data.npz', figsize=(1.3,1.3));
plt.gca().set_yticks([-4,-2,0,2,4])
plt.savefig(SPATH + "Fig4d.pdf")
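`bias_diff` draws one trace per block. To average those traces, they first need a common length. This is a minimal sketch that assumes blocks are given as `(start, stop, value)` tuples with an exclusive `stop`. `align_to_block_start` is hypothetical and mirrors the block-length cutoff in `bias_diff`.

```python
import numpy as np

def align_to_block_start(w, blocks, min_len=20, max_len=75):
    """Stack weight changes relative to each block's first trial.

    w      : 1-D weight trajectory (e.g. the bias weight)
    blocks : list of (start, stop, value) with `stop` exclusive
    Returns an array of shape (n_kept_blocks, max_len), NaN-padded.
    """
    w = np.asarray(w, dtype=float)
    rows = []
    for start, stop, _ in blocks:
        if stop - start <= min_len:
            continue  # same as `e-s < 20` in bias_diff, whose end index is inclusive
        seg = w[start:stop][:max_len] - w[start]
        rows.append(np.pad(seg, (0, max_len - seg.size), constant_values=np.nan))
    return np.array(rows).reshape(-1, max_len)

w = np.arange(100) * 0.1
blocks = [(0, 10, 0.5), (10, 60, 0.2), (60, 100, 0.8)]
aligned = align_to_block_start(w, blocks)
print(aligned.shape)  # (2, 75): the 10-trial block is skipped
```

`np.nanmean(aligned, axis=0)` then gives the average change in the weight at each trial since block onset.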

## Figure 4e

In [None]:
outData = getMouse("CSHL_003", 5)

# Collect data from manually determined training period
_start  = np.where(outData['date'] >= '2019-04-30')[0][0]
_end    = np.where(outData['date'] >= '2019-05-03')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-5]*K
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig4e_data.npz', dat=dat)

In [None]:
fig = bias_diff(SPATH+'fig4e_data.npz', figsize=(1.3,1.3));
plt.gca().set_yticks([-4,-2,0,2,4])
plt.gca().set_yticklabels([])
plt.savefig(SPATH + "Fig4e.pdf")

## Figure 4f

In [None]:
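def max_bias(bias, side, wL, wR):
    """Negative expected reward of a logistic observer (minimized over `bias` below).

    bias   : bias weight; positive values favor rightward choices
    side   : block type, 'L' (probL = 0.8), 'R' (probL = 0.2), or 'M' (probL = 0.5)
    wL, wR : weights on the transformed left and right contrasts
    """
    contrasts = np.array([-1., -0.25, -0.125, -0.0625, 0., 0.0625, 0.125, 0.25, 1.])
    
    # tanh transformation of contrast magnitude (steepness p = 5)
    p=5
    transformed_con = np.tanh(p*np.abs(contrasts))/np.tanh(p)

    # Probability of each contrast in left-biased, right-biased, and unbiased blocks
    p_biasL = [.8/4.5]*4 + [1/9] + [.2/4.5]*4    
    p_biasR = [.2/4.5]*4 + [1/9] + [.8/4.5]*4
    p_biasM = [1/9]*9

    # Left contrasts use wL, right contrasts use wR; left is correct for negative contrasts
    w = [wL]*4 + [0] + [wR]*4
    correct = [0]*4 + [0] + [1]*4

    # P(left choice); the logistic term itself is P(right)
    pL = 1 - (1/(1+np.exp(-(transformed_con*w + bias))))
    pCorrect = np.abs(correct - pL)
    
    # On zero-contrast trials the rewarded side follows the block prior
    if side=="L":
        pCorrect[4] = pL[4]*0.8 + (1-pL[4])*0.2
        expval = np.sum(p_biasL * pCorrect)
    
    elif side=="R":
        pCorrect[4] = pL[4]*0.2 + (1-pL[4])*0.8
        expval = np.sum(p_biasR * pCorrect)
    
    elif side=="M":
        pCorrect[4] = 0.5
        expval = np.sum(p_biasM * pCorrect)

    else:
        raise ValueError("side must be 'L', 'R', or 'M'")
    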
    return -expval
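The objective in `max_bias` can be checked in isolation. The sketch below re-implements it as a positive expected reward and replaces `scipy.optimize.minimize` with a brute-force grid search. `expected_reward` and `optimal_bias` are hypothetical names. The sketch keeps the sign convention of `max_bias`: the logistic output is P(right), so a negative bias favors leftward choices. With sensible sensitivity, `wL < 0 < wR`.

```python
import numpy as np

CONTRASTS = np.array([-1., -0.25, -0.125, -0.0625, 0., 0.0625, 0.125, 0.25, 1.])

def expected_reward(bias, side, wL, wR, p=5):
    """Expected fraction of rewarded trials (the negation of max_bias's output)."""
    t = np.tanh(p * np.abs(CONTRASTS)) / np.tanh(p)
    w = np.array([wL] * 4 + [0.0] + [wR] * 4)
    p_right = 1.0 / (1.0 + np.exp(-(t * w + bias)))  # model output is P(right)
    p_left = 1.0 - p_right
    # Left is correct for negative contrasts, right for positive ones
    p_correct = np.concatenate((p_left[:4], [0.0], p_right[5:]))
    # Zero contrast: the rewarded side follows the block prior
    if side == "L":
        prior = np.array([.8 / 4.5] * 4 + [1 / 9] + [.2 / 4.5] * 4)
        p_correct[4] = 0.8 * p_left[4] + 0.2 * p_right[4]
    elif side == "R":
        prior = np.array([.2 / 4.5] * 4 + [1 / 9] + [.8 / 4.5] * 4)
        p_correct[4] = 0.2 * p_left[4] + 0.8 * p_right[4]
    elif side == "M":
        prior = np.full(9, 1 / 9)
        p_correct[4] = 0.5
    else:
        raise ValueError("side must be 'L', 'R', or 'M'")
    return float(np.sum(prior * p_correct))


def optimal_bias(side, wL, wR, lo=-10.0, hi=10.0, n=4001):
    """Grid-search stand-in for scipy.optimize.minimize(max_bias, ...)."""
    grid = np.linspace(lo, hi, n)
    rewards = np.array([expected_reward(b, side, wL, wR) for b in grid])
    return float(grid[np.argmax(rewards)])


# Equal sensitivity to both sides: negative left weight, positive right weight
opt = {side: optimal_bias(side, wL=-3.0, wR=3.0) for side in "LMR"}
print(opt)
```

By symmetry, the best bias is zero in unbiased blocks, negative in left blocks, and positive in right blocks. The values in left and right blocks mirror each other.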

In [None]:
from scipy.optimize import minimize

dat = np.load(SPATH+'fig4c_data.npz', allow_pickle=True)['dat'].item()
# Analyze only the second session in (C), which begins after the first day's trials
start = dat['new_dat']['dayLength'][0]

optBias = []
optReward = []
for i in np.arange(start, dat['wMode'].shape[1]):
    
    if dat['new_dat']['probL'][i] < 0.21: side = 'R'
    elif dat['new_dat']['probL'][i] > 0.79: side = 'L'
    else: side = 'M'
        
    res = minimize(max_bias,[0], args=(side, dat['wMode'][1,i], dat['wMode'][2,i]))
    optBias += [res.x]
    optReward += [-res.fun]

print("Avg. Reward:", np.mean(optReward))

In [None]:
fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"],
                       errorbar=dat['W_std'], figsize=(2.75,1.3))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.plot(np.arange(start, dat['wMode'].shape[1]), optBias, 'k-', lw=2, zorder=10)
plt.gca().set_yticks(np.arange(-6, 7,2))
plt.gca().set_yticklabels([])
plt.gca().set_xticks([750, 1000, 1250])
plt.xlim(start, None); plt.ylim(-5.3,5.3)
plt.xlabel(None); plt.ylabel(None)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig4f.pdf")

In [None]:
# Compare the expected reward under the fitted bias weight and under a zero bias weight
# with the empirical reward rate (uses max_bias and `start` from the cells above)

optReward_pred = []
optReward_0bias = []
for i in np.arange(start, dat['wMode'].shape[1]):
    
    if dat['new_dat']['probL'][i] < 0.21: side = 'R'
    elif dat['new_dat']['probL'][i] > 0.79: side = 'L'
    else: side = 'M'
        
    optReward_pred += [-max_bias(dat['wMode'][0,i], side, dat['wMode'][1,i], dat['wMode'][2,i])]
    optReward_0bias += [-max_bias(0.0, side, dat['wMode'][1,i], dat['wMode'][2,i])]

print("Predicted Avg. Reward:", np.mean(optReward_pred))
print("No Bias Avg. Reward:", np.mean(optReward_0bias))
print("Empirical Avg. Reward:", np.mean(dat['new_dat']['correct'][start:]))


# Figure 5 | Visualization of Learning in an Example Akrami Rat

**(A)** Akrami rat task schematic (Illustrator only)

**(B)** Psychometric weights for an example rat (`W080`)
 
**(C)** Compare model predictions to empirical choice behavior under various trial conditions, for a 500-trial window starting at trial 2000.

**(D)** As in (C), starting at trial 6500

**(E)** As in (C), starting at trial 11000

## Figure 5b

_15 min_

In [None]:
outData = getRat("W080")
new_dat = psy.trim(outData, START=0, END=12500)

weights = {'bias': 1, 's_a': 1, 's_b': 1, 'h': 1, 'c': 1, "s_avg": 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-4]*K,
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig5b_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'fig5b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(4.75,1.4))

selected_days = [[2000,2500], [6500,7000], [11000,11500]]
for d in selected_days:
    plt.plot(d, [-1.3]*2, lw=2, color="k")

plt.xlabel(None); plt.ylabel(None)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig5b.pdf")

## Figure 5c-e

_2.5 hours_

In [None]:
outData = getRat("W080")
new_dat = psy.trim(outData, END=12500)

FOLDS = 10  # number of cross-validation folds
SEED = 42   # seeds the random assignment of trials to the FOLDS folds

weights = {'bias': 1, 's_a': 1, 's_b': 1, 'h': 1, 'c': 1, "s_avg": 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-4]*K,
  }
optList = ['sigma', 'sigDay']

_, xval_pL = psy.crossValidate(new_dat, hyper_guess, weights, optList, F=FOLDS, seed=SEED)
np.savez_compressed(SPATH+'fig5c_data.npz', new_dat=new_dat, xval_pL=xval_pL)
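`FOLDS` and `SEED` control how trials are split for cross-validation. PsyTrack's `crossValidate` does this internally and may split trials differently. The sketch below only illustrates seeded, balanced fold assignment in general. `assign_folds` is hypothetical.

```python
import numpy as np

def assign_folds(n_trials, n_folds=10, seed=42):
    """Return a fold label (0 .. n_folds-1) for every trial, in random order."""
    rng = np.random.default_rng(seed)
    folds = np.arange(n_trials) % n_folds   # fold sizes as equal as possible
    rng.shuffle(folds)                      # reproducible for a fixed seed
    return folds

folds = assign_folds(12500, n_folds=10, seed=42)
held_out = folds == 0   # test trials for the first fold; the rest would be used for fitting
print(np.bincount(folds))  # 1250 trials in each of the 10 folds
```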

In [None]:
from scipy.stats import sem

outData = np.load(SPATH+'fig5c_data.npz', allow_pickle=True)['new_dat'].item()
xval_pL = np.load(SPATH+'fig5c_data.npz', allow_pickle=True)['xval_pL'] 
outData['xval_pR'] = 1 - xval_pL

all_hists = []
all_ys = []
all_pRs = []

selected_days = [[2000,2500], [6500,7000], [11000,11500]]
for d in selected_days:
    new_dat = psy.trim(outData, START=d[0], END=d[1])

    hists = []
    ys = []
    pRs = []
    for h in [-1,1]:
        for c in [-1,1]:
            for a in [-1,1]:
                ind_h = (new_dat['inputs']['h'][:,0] == h)
                ind_c = (new_dat['inputs']['c'][:,0] == c)
                ind_a = (np.sign(new_dat['s_a'] - new_dat['s_b']) == a)
                inds = ind_h * ind_c * ind_a
                hists += [[h,c,a]]
                ys += [new_dat['y'][inds]]
                pRs += [new_dat['xval_pR'][inds]]
    
    all_hists += [hists]
    all_ys += [ys]
    all_pRs += [pRs]
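The nested loops above split trials into eight conditions: previous answer `h`, previous choice `c`, and the sign of `s_a - s_b`. The sketch below performs the same grouping compactly with `itertools.product`, which visits the conditions in the same order as the loops. `condition_table` and the toy data are hypothetical. The sketch assumes `y` is coded 0/1, so that its mean is comparable to a predicted probability.

```python
import itertools
import numpy as np

def condition_table(h, c, a, y, p_right):
    """Map each (h, c, a) combination of -1/+1 values to (mean of y, mean of p_right)."""
    h, c, a, y, p_right = (np.asarray(v, dtype=float) for v in (h, c, a, y, p_right))
    table = {}
    for key in itertools.product([-1, 1], repeat=3):   # same order as the h, c, a loops
        mask = (h == key[0]) & (c == key[1]) & (a == key[2])
        if mask.any():
            table[key] = (y[mask].mean(), p_right[mask].mean())
        else:
            table[key] = (np.nan, np.nan)              # condition absent from this window
    return table

# Toy data: four trials
h = [1, 1, -1, -1]
c = [1, 1, 1, -1]
a = [1, 1, -1, -1]
y = [1, 0, 0, 1]
p_right = [0.8, 0.6, 0.3, 0.5]
table = condition_table(h, c, a, y, p_right)
print(table[(1, 1, 1)])
```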

In [None]:
import matplotlib as mpl

def colorFader(c1,c2,mix=0):
    """Diverging ramp: c1 at mix=0, white at mix=0.5, c2 at mix=1 (returns a hex string)."""
    c1=np.array(mpl.colors.to_rgb(c1))
    c2=np.array(mpl.colors.to_rgb(c2))
    w =np.array(mpl.colors.to_rgb("white"))
    if mix <= 0.5:
        return mpl.colors.to_hex((1-mix*2)*c1 + mix*2*w)
    else:
        return mpl.colors.to_hex((1-(mix-0.5)*2)*w + (mix-0.5)*2*c2)

def cF(mix):
    return colorFader(colors['s2'],colors['s1'],mix)
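`colorFader` blends linearly from `c1` to white over the first half of `mix`, then from white to `c2` over the second half. The sketch below shows the same ramp on plain RGB tuples, without matplotlib's color parsing. `fade_through_white` is a hypothetical name.

```python
def fade_through_white(c1, c2, mix):
    """Diverging ramp on (r, g, b) tuples in [0, 1]: c1 at 0, white at 0.5, c2 at 1."""
    white = (1.0, 1.0, 1.0)
    if mix <= 0.5:
        frac = mix * 2            # 0 -> c1, 1 -> white
        start, end = c1, white
    else:
        frac = (mix - 0.5) * 2    # 0 -> white, 1 -> c2
        start, end = white, c2
    return tuple((1 - frac) * s + frac * e for s, e in zip(start, end))

red, blue = (1.0, 0.0, 0.0), (0.0, 0.0, 1.0)
print(fade_through_white(red, blue, 0.25))  # (1.0, 0.5, 0.5)
```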

In [None]:
diff = 0.19
rad = 0.45 
cm = plt.get_cmap('RdBu_r')

for d in range(len(selected_days)):
    plt.figure(figsize=(0.75,1.5))
    avg = [np.average(i) for i in all_ys[d]]
    avg_pR = [np.average(i) for i in all_pRs[d]]

    std = [sem(i) for i in all_ys[d]]
    std_pR = [sem(i) for i in all_pRs[d]]

    for i in range(len(avg)):
        h = all_hists[d][i][0]
        c = all_hists[d][i][1]
        a = all_hists[d][i][2]
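        # Cell position: stimulus sign (a) on x; previous answer (h) and previous choice (c) on y
        x = a/2
        y = h + c/2

        # Upper-left triangle: cross-validated model prediction, P(right) in %
        plt.text(x-diff, y+diff, int(np.round(avg_pR[i]*100)),
                 ha="center", va="center", fontsize=10, zorder=i+1)
        t1 = plt.Polygon([[x-rad,y-rad],[x-rad,y+rad],[x+rad,y+rad]], 
                         facecolor=cF(avg_pR[i]), edgecolor="k", lw=0, zorder=i)
        plt.gca().add_patch(t1)

        # Lower-right triangle: empirical choice rate (mean of y) in %
        plt.text(x+diff, y-1.5*diff, int(np.round(avg[i]*100)),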
                 ha="center", va="center", fontsize=10, zorder = i+11)
        t2 = plt.Polygon([[x-rad,y-rad],[x+rad,y-rad],[x+rad,y+rad]], 
                         facecolor=cF(avg[i]), edgecolor="k", lw=0.5, zorder = i+10)
        plt.gca().add_patch(t2)
        
    plt.ylim(-2,2)
    plt.xlim(-1,1)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)
    plt.gca().spines['bottom'].set_visible(False)
    plt.gca().set_xticks([])
    plt.gca().set_yticks([])

    plt.subplots_adjust(0,0,1,1) 
    plt.savefig(SPATH + "Fig5cde_" + str(d) + ".pdf")

In [None]:
# Make colorbar
n=500
fig, ax = plt.subplots(figsize=(.2, 1.5))
for x in range(n+1):
    ax.axhline(1 - x/n, color=cF(x/n), linewidth=4) 

plt.gca().set_yticks([0.005,.25,.5,.75,0.995])
plt.gca().set_xticks([])
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['left'].set_visible(False)
plt.gca().spines['bottom'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig5_colorbar.pdf")

# Figure 6 | Population Psychometric Weights from Akrami Rats

**(A)** Show overlay of population weights (including the average weights) for Tones A + B

**(B)** For the Bias weight

**(C)** For the Previous Tones weight

**(D)** For the Previous (Correct) Answer weight

**(E)** For the Previous Choice weight

**(F)** Show the average fitted hyperparameters ($\sigma$ and $\sigma_\text{day}$) for each weight across all rats ($\pm1$ SD)

## Figure 6a

_6 hours_

In [None]:
all_rats = RAT_DF["subject_id"].unique()
for i, subject in enumerate(all_rats):

    print("\rProcessing " + str(i+1) + " of " + str(len(all_rats)), end="")
        
    outData = getRat(subject)

    # Restrict each rat to its first 20,000 trials
    new_dat = psy.trim(outData, END=20000)

    # Compute
    weights = {'bias': 1, 's_a': 1, 's_b': 1, 'h': 1, 'c': 1, "s_avg": 1}
    K = np.sum([weights[i] for i in weights.keys()])
    hyper_guess = {
     'sigma'   : [2**-5]*K,
     'sigInit' : 2**5,
     'sigDay'  : [2**-4]*K,
      }
    optList = ['sigma', 'sigDay']

    hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList, hess_calc=None)

    dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'hess_info' : hess_info,
           'weights' : weights, 'new_dat' : new_dat}

    # Save interim result
    np.savez_compressed(SPATH+'fig6a_'+subject+'_data.npz', dat=dat)

In [None]:
all_labels = []
all_w = []
for subject in RAT_DF["subject_id"].unique():
    rat = np.load(SPATH+'fig6a_'+subject+'_data.npz', allow_pickle=True)['dat'].item()
    
    labels = []
    for j in sorted(rat['weights'].keys()):
        labels += [j]*rat['weights'][j]
        
    all_labels += [np.array(labels)]
    all_w += [rat['wMode']] 
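Several cells rebuild per-row labels from a weights dictionary. They sort its keys and repeat each key by its count. The indexing into `wMode` relies on this order matching the rows of the fitted weights. The sketch below isolates that expansion. `weight_labels` is a hypothetical name.

```python
def weight_labels(weights):
    """Expand a weights dict into one label per weight row (keys sorted, repeated by count)."""
    labels = []
    for key in sorted(weights):
        labels += [key] * weights[key]
    return labels

rat_weights = {'bias': 1, 's_a': 1, 's_b': 1, 'h': 1, 'c': 1, "s_avg": 1}
print(weight_labels(rat_weights))
# ['bias', 'c', 'h', 's_a', 's_avg', 's_b']
```

The length of the result equals the `K` computed with `np.sum` in the fitting cells. Weights with a count of 0 (as in Figure 8a) drop out.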

In [None]:
def plot_all(all_labels, all_w, Weights, figsize):
    """Overlay every rat's trajectory (thin) and the population mean (thick) for each weight."""
    fig = plt.figure(figsize=figsize)
    Weights = [Weights] if isinstance(Weights, str) else Weights
    avg_len=20000
    for k, W in enumerate(Weights):
        avg = []
        for r in range(len(all_w)):
            # Row of rat r's wMode that holds weight W
            w_ind = np.where(all_labels[r] == W)[0][-1]
            w_traj = all_w[r][w_ind]
            # Truncate / NaN-pad to avg_len so rats with different trial counts can be averaged
            avg += [list(w_traj[:avg_len]) + [np.nan]*(avg_len - len(w_traj[:avg_len]))]
            plt.plot(w_traj, color=colors[W], alpha=0.2, lw=1, zorder=2+k)
        plt.plot(np.nanmean(avg, axis=0), color=colors[W], alpha=0.8, lw=2.5, zorder=5+k)

    plt.axhline(0, color="black", linestyle="--", lw=1, alpha=0.5, zorder=1)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.xlim(0, 19000)
    plt.ylim(-2.5, 2.5)
    return fig
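`plot_all` averages rats with different numbers of trials. It truncates or NaN-pads each trajectory to `avg_len`, then calls `np.nanmean`. The sketch below applies that step to plain lists. `nanpad_mean` is hypothetical.

```python
import numpy as np

def nanpad_mean(trajectories, length):
    """Average trajectories of unequal length over a common trial axis.

    Each trajectory is truncated or NaN-padded to `length`; np.nanmean then
    averages only over the subjects that reached each trial.
    """
    padded = np.full((len(trajectories), length), np.nan)
    for row, traj in zip(padded, trajectories):   # each `row` is a view into `padded`
        traj = np.asarray(traj, dtype=float)[:length]
        row[:traj.size] = traj
    return np.nanmean(padded, axis=0)

print(nanpad_mean([[1, 2, 3], [3, 4]], 3))  # [2. 3. 3.]
```

Trials that no subject reached would form all-NaN columns. For those, `np.nanmean` returns NaN and emits a `RuntimeWarning`.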

In [None]:
plot_all(all_labels, all_w, ["s_a", "s_b"], (1.85, 0.8))
plt.subplots_adjust(0,0,1,1) 
plt.gca().set_yticks([-2,0,2])
plt.gca().set_xticklabels([])
plt.savefig(SPATH + "Fig6a.pdf")

## Figure 6b

In [None]:
plot_all(all_labels, all_w, ["bias"], (1.85, 0.8))
plt.gca().set_yticks([-2,0,2])
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig6b.pdf")

## Figure 6c

In [None]:
plot_all(all_labels, all_w, ["s_avg"], (1.85, 0.8))
plt.gca().set_yticks([-2,0,2])
plt.gca().set_yticklabels([])
plt.gca().set_xticklabels([])
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig6c.pdf")


## Figure 6d

In [None]:
plot_all(all_labels, all_w, ["h"], (1.85, 0.8))
plt.ylim(-0.25, 2.25)
# plt.gca().set_yticklabels([])
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig6d.pdf")

## Figure 6e

In [None]:
plot_all(all_labels, all_w, ["c"], (1.85, 0.8))
plt.ylim(-0.25, 2.25)
plt.gca().set_yticklabels([])
plt.gca().set_xticklabels([])
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig6e.pdf")

## Figure 6f

In [None]:
all_sigma = []
all_sigDay = []
for subject in RAT_DF["subject_id"].unique():
    rat = np.load(SPATH+'fig6a_'+subject+'_data.npz', allow_pickle=True)['dat'].item()
    
    labels = []
    for j in sorted(rat['weights'].keys()):
        labels += [j]*rat['weights'][j]
        
    all_sigma += [rat['hyp']['sigma']]
    all_sigDay += [rat['hyp']['sigDay']]

all_sigma = np.array(all_sigma)
all_sigDay = np.array(all_sigDay)

In [None]:
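# x-position of each weight (labels sorted: bias, c, h, s_a, s_avg, s_b), ordered to
# match panels A-E; sigma markers (circles) sit at 0-5, sigDay markers (squares) at 8-13
pos_map = {0: 2, 1: 5, 2: 4, 3: 0, 4: 3, 5: 1}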
plt.figure(figsize=(1.55, 0.8))
for i, j in enumerate(labels):
    plt.errorbar([pos_map[i]], np.average(np.log2(all_sigma[:,i])),
                 yerr=np.std(np.log2(all_sigma[:,i])),
                 color=colors[j], marker="o", ms=4, elinewidth=1.5)
    plt.errorbar([pos_map[i]+8], np.average(np.log2(all_sigDay[:,i])),
                 yerr=np.std(np.log2(all_sigDay[:,i])),
                 color=colors[j], marker="s", ms=4, elinewidth=1.5)
    
plt.ylim(-12.5,-2.6)
# plt.xlim(-1,1)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.gca().set_xticks([])
plt.gca().set_yticks([-12, -10, -8, -6, -4])
plt.gca().set_yticklabels([-12, None, -8, None, -4])

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig6f.pdf", transparent=True)

# Figure 7 | Population Psychometric Weights from Akrami Human Subjects

**(A)** Athena human subject task schematic (Illustrator only)

**(B)** Psychometric weights for an example human subject (`subject_id=6`)

**(C)** Show psychometric weights for all human subjects together

## Figure 7b

In [None]:
new_dat = getHuman(6)

# Compute
weights = {'bias': 1, 's_a': 1, 's_b': 1, 's_avg': 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : None
  }
optList = ['sigma']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'fig7b_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'fig7b_data.npz', allow_pickle=True)['dat'].item()
fig = psy.plot_weights(dat['wMode'], dat['weights'], errorbar=dat['W_std'], figsize=(4.75,1))

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_xticks([0,500,1000,1500,2000])
plt.gca().set_yticks(np.arange(-2, 3,2))
plt.xlim(0, 1900); plt.ylim(-3.4, 3.4)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig7b.pdf")

## Figure 7c

_3 min_

In [None]:
all_dat = []
all_subjects = HUMAN_DF["subject_id"].unique()
for i, subject in enumerate(all_subjects):
    
    print("\rProcessing " + str(i+1) + " of " + str(len(all_subjects)), end="")

    new_dat = getHuman(subject)

    # Compute
    weights = {'bias': 1, 's_a': 1, 's_b': 1, 's_avg': 1}
    K = np.sum([weights[i] for i in weights.keys()])
    hyper_guess = {
     'sigma'   : [2**-5]*K,
     'sigInit' : 2**5,
     'sigDay'  : None
      }
    optList = ['sigma']

    hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

    dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
           'weights' : weights, 'new_dat' : new_dat}
    all_dat += [dat]

# Save interim result
np.savez_compressed(SPATH+'fig7c_data.npz', all_dat=all_dat)

In [None]:
all_dat = np.load(SPATH+'fig7c_data.npz', allow_pickle=True)['all_dat']

plt.figure(figsize=(4.75,1))
for dat in all_dat:

    weights = dat['weights']
    wMode = dat['wMode']
    labels = []
    for j in sorted(weights.keys()):
        labels += [j]*weights[j]

    for i, w in enumerate(labels):
        plt.plot(wMode[i], lw=1.5, alpha=0.5, linestyle='-', c=colors[w], zorder=zorder[w])

plt.axhline(0, color="black", linestyle="--", lw=1, alpha=0.5, zorder=0)
plt.gca().set_xticks([0,500,1000,1500,2000])
plt.gca().set_yticks(np.arange(-2, 3,2))
plt.xlim(0, 1900); plt.ylim(-3.4, 3.4)

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig7c.pdf")

# Figure 8 | History Regressors Improve Model Accuracy for an Example Akrami Rat

**(A)** Show plot for the model without history weights: predicted accuracy on the x-axis, and empirical accuracy on the y-axis

**(B)** Show histogram of predicted accuracy in trials from (A)

**(C)** Show plot for model with history weights: predicted accuracy on x-axis, and empirical accuracy on y-axis

**(D)** Show histogram of predicted accuracy in trials from (C)

## Figure 8a

_30 min_

In [None]:
outData = getRat("W080")
new_dat = psy.trim(outData, END=12500)

FOLDS = 10  # number of cross-validation folds
SEED = 42   # seeds the random assignment of trials to the FOLDS folds

weights = {'bias': 1, 's_a': 1, 's_b': 1, 'h': 0, 'c': 0, "s_avg": 0}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-4]*K,
  }
optList = ['sigma', 'sigDay']

_, xval_pL = psy.crossValidate(new_dat, hyper_guess, weights, optList, F=FOLDS, seed=SEED)
np.savez_compressed(SPATH+'fig8a_data.npz', new_dat=new_dat, xval_pL=xval_pL)

In [None]:
from scipy.stats import sem
xval_pL = np.load(SPATH+'fig8a_data.npz', allow_pickle=True)['xval_pL'] 
new_dat = np.load(SPATH+'fig8a_data.npz', allow_pickle=True)['new_dat'].item()

step = 0.02
edges = np.arange(0.5,1.0+step,step)

est_correct = np.abs(xval_pL - 0.5) + 0.5
match = ((-np.sign(xval_pL - 0.5) + 1)/2).astype(int) == new_dat["y"].astype(int)

print("Average Empirical Accuracy:", np.round(np.average(match), 3))
print("Average Predicted Accuracy:", np.round(np.average(est_correct), 3))

choices = []
centers = []
for i in edges[:-1]:
    mask = (est_correct >= i) & (est_correct < i+step)
    choices += [match[mask]]
    centers += [np.average(est_correct[mask])];

avg_correct = np.array([np.average(i) if len(i) > 40 else np.nan for i in choices])
sem_correct = np.array([sem(i) if len(i) > 40 else np.nan for i in choices])

plt.figure(figsize=(2,1.5))
plt.errorbar(centers, avg_correct, yerr=1.96*sem_correct,
             alpha=1, color=colors['bias'], linestyle="None", marker="o", markersize=2)
plt.plot(np.average(est_correct), np.average(match), marker="*", markersize=10, alpha=0.75,
         markeredgecolor="None", markerfacecolor="black", zorder=10)

plt.plot([0.4,1.1], [0.4,1.1], color="black", linestyle="--", lw=1, alpha=0.5, zorder=0)

plt.xlim(0.5, 1)
plt.ylim(0.5, 1)
plt.xticks([0.5,0.6,0.7,0.8,0.9,1.0])
plt.gca().set_xticklabels([])

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig8a.pdf")
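Figures 8A and 8C apply the same binning code to two model fits. The sketch below is a minimal helper that could serve both. It assumes `p_left` is the cross-validated `xval_pL` and that `y` uses the same 0/1 coding as the `match` computation above. `calibration` is a hypothetical name. Unlike the cells above, it drops sparse bins instead of setting them to NaN. It also builds the bin edges with `np.linspace` rather than `np.arange`.

```python
import numpy as np

def calibration(p_left, y, step=0.02, min_count=40):
    """Compare predicted and empirical accuracy in bins of predicted accuracy."""
    p_left = np.asarray(p_left, dtype=float)
    y = np.asarray(y).astype(int)
    pred_acc = np.maximum(p_left, 1.0 - p_left)   # equals |p_left - 0.5| + 0.5
    pred_choice = (p_left < 0.5).astype(int)      # model favors y == 1 when p_left < 0.5
    hit = (pred_choice == y).astype(float)
    edges = np.linspace(0.5, 1.0, int(round(0.5 / step)) + 1)  # avoids np.arange float drift
    centers, accuracy = [], []
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (pred_acc >= lo) & (pred_acc < hi)
        if mask.sum() > min_count:                # skip sparse bins
            centers.append(pred_acc[mask].mean())
            accuracy.append(hit[mask].mean())
    return np.array(centers), np.array(accuracy)


# Synthetic, perfectly calibrated predictions: y == 1 with probability 1 - p_left
rng = np.random.default_rng(0)
p_left = rng.uniform(0.0, 1.0, size=200_000)
y = (rng.uniform(size=p_left.size) < 1.0 - p_left).astype(int)
centers, accuracy = calibration(p_left, y)
print(np.round(accuracy - centers, 3))  # all close to 0
```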

## Figure 8b

In [None]:
plt.figure(figsize=(2,1.5))
plt.hist(est_correct, bins=edges, alpha=1, lw=0.5, color=colors['bias'], edgecolor="black")

plt.xlim(0.5, 1)
plt.ylim(0, 1700)
plt.xticks([0.5,0.6,0.7,0.8,0.9,1.0])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig8b.pdf")

## Figure 8c

These subfigures reuse the data from Figures 5C-E. Run the cell above that creates `fig5c_data.npz` before producing Figures 8C and 8D.

In [None]:
from scipy.stats import sem
xval_pL = np.load(SPATH+'fig5c_data.npz', allow_pickle=True)['xval_pL'] 
new_dat = np.load(SPATH+'fig5c_data.npz', allow_pickle=True)['new_dat'].item()

step = 0.02
edges = np.arange(0.5,1.0+step,step)

est_correct = np.abs(xval_pL - 0.5) + 0.5
match = ((-np.sign(xval_pL - 0.5) + 1)/2).astype(int) == new_dat["y"].astype(int)

print("Average Empirical Accuracy:", np.round(np.average(match), 3))
print("Average Predicted Accuracy:", np.round(np.average(est_correct), 3))

choices = []
centers = []
for i in edges[:-1]:
    mask = (est_correct >= i) & (est_correct < i+step)
    choices += [match[mask]]
    centers += [np.average(est_correct[mask])];

avg_correct = np.array([np.average(i) if len(i) > 40 else np.nan for i in choices])
sem_correct = np.array([sem(i) if len(i) > 40 else np.nan for i in choices])

plt.figure(figsize=(2,1.5))
plt.errorbar(centers, avg_correct, yerr=1.96*sem_correct,
             alpha=1, color=colors['h'], linestyle="None", marker="o", markersize=2)
plt.plot(np.average(est_correct), np.average(match), marker="*", markersize=10, alpha=0.75,
         markeredgecolor="None", markerfacecolor="black", zorder=10)

plt.plot([0.4,1.1], [0.4,1.1], color="black", linestyle="--", lw=1, alpha=0.5, zorder=0)

plt.xlim(0.5, 1)
plt.ylim(0.5, 1)
plt.xticks([0.5,0.6,0.7,0.8,0.9,1.0])
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig8c.pdf")

## Figure 8d

In [None]:
plt.figure(figsize=(2,1.5))
plt.hist(est_correct, bins=edges, alpha=1, lw=0.5, color=colors['h'], edgecolor="black")

plt.xlim(0.5, 1)
plt.ylim(0, 1700)
plt.xticks([0.5,0.6,0.7,0.8,0.9,1.0])
plt.gca().set_yticklabels([])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "Fig8d.pdf")


---


# Supplementary Figures

## Figure S1 | Compute time and model accuracy

**(A)** Show compute time

**(B)** Show weight recovery accuracy

### Figure S1a

_5 hours_

In [None]:
from psytrack.runSim import generateSim, recoverSim

num_weights = [2,4,6]
num_trials = 1000*np.array([1,2,4,8,16])
num_simulations = 20

results = []

for N in num_trials:
    for K in num_weights:
        for i in range(num_simulations):
            print("K =", K, "  N =", N, "  iter = ", i)
            # Simulate data
            seed = N+100*K+i
            np.random.seed(seed)
            hyper = {'sigma': 2**np.random.uniform(-7.5, -3.5, size=K), 'sigInit': 1.0}
            dat = generateSim(K=K, N=N, hyper=hyper, boundary=5.0, iterations=1, seed=seed)
            
            # Recover data
            try:
                rec = recoverSim(dat, hess_calc=None)
            except Exception as err:
                print("Recovery failed:", err)
                # None marks a failed fit; the plotting cells below convert it to NaN
                results += [[N, K, i, None, None]]
                continue
            
            # Save all data, mainly duration and mean squared error in weight recovery
            mse = np.average((rec['wMode'] - rec['input']['W'].T)**2)
            print("      " + str(rec['duration'].seconds) +"s   mse =", np.round(mse, 4))
            results += [[N, K, i, rec['duration'], mse]]
            
# Save the results of all simulations (runs once, after all loops finish)
np.savez(SPATH + "FigS1_dat.npz", results=results)
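To summarize the saved `results`, each duration must be converted to minutes and failed fits must be skipped. The sketch below makes the following assumptions:

- Each row is `[N, K, iteration, duration, mse]`, and `duration` is a `datetime.timedelta`.
- Rows are ordered by N, then K, then iteration, as in the loop above.
- Failed fits are marked by a non-timedelta value such as `None` or NaN.

`summarize_runs` is hypothetical.

```python
import numpy as np
from datetime import timedelta

def summarize_runs(results, K, n_sims):
    """Mean and SD of fit duration (minutes) for each trial count N, for one K."""
    minutes = np.array([
        r[3].total_seconds() / 60 if isinstance(r[3], timedelta) else np.nan
        for r in results if r[1] == K
    ]).reshape(-1, n_sims)            # one row per N, one column per simulation
    return np.nanmean(minutes, axis=1), np.nanstd(minutes, axis=1)

demo = [[1000, 2, 0, timedelta(seconds=60), 0.01],
        [1000, 2, 1, timedelta(seconds=180), 0.02],
        [2000, 2, 0, None, None],                  # a failed fit
        [2000, 2, 1, timedelta(seconds=120), 0.01]]
mean, sd = summarize_runs(demo, K=2, n_sims=2)
print(mean, sd)  # [2. 2.] [1. 0.]
```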


In [None]:
res = np.load(SPATH + "FigS1_dat.npz", allow_pickle=True)['results']

plt.figure(figsize=(2.5,2.5))
COLORS = [colors['bias'],colors['s1'],colors['s2'],]
adjust = [-0.3, 0, 0.3]
for i, K in enumerate(num_weights):
    all_duration = [i[3] for i in res if i[1]==K]
    all_duration = np.array([i.total_seconds()/60
                             if i is not None else np.nan
                             for i in all_duration]).reshape(-1,num_simulations)
    plt.errorbar(num_trials/1000 + adjust[i], np.nanmean(all_duration, axis=1),
                 yerr=np.nanstd(all_duration, axis=1),
                 color=COLORS[i], marker="o", markersize=3, lw=1)


plt.xlim(0.25, 16.5)
plt.ylim(0, 8.2)
plt.xticks([1,2,4,8,16])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS1a.pdf")

### Figure S1b

In [None]:
res = np.load(SPATH + "FigS1_dat.npz", allow_pickle=True)['results']

plt.figure(figsize=(2.5,2.5))
COLORS = [colors['bias'],colors['s1'],colors['s2']]
adjust = [-0.3, 0, 0.3]
for i, K in enumerate(num_weights):
    all_mse = [i[4] for i in res if i[1]==K]
    all_mse = np.array([i if i is not None else np.nan
                        for i in all_mse]).reshape(-1,num_simulations)
    plt.errorbar(num_trials/1000 + adjust[i], np.nanmean(all_mse, axis=1),
                 yerr=np.nanstd(all_mse, axis=1),
                 color=COLORS[i], linestyle="None", marker="o", markersize=3, lw=1)

plt.xlim(0.25, 16.5); plt.ylim(0, 0.152)
plt.xticks([1,2,4,8,16]); plt.yticks([0,0.05,0.1,0.15])
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS1b.pdf")

## Figure S2 | Recovering sudden changes in behavior with smooth weight trajectories

**(A)** Same as Figure 2C, but without $\sigma_\text{day}$ in the recovery model

**(B)** Same as Figure 2D, but without $\sigma_\text{day}$ in the recovery model

### Figure S2a

In [None]:
# Simulate
seed = 102  # paper uses 102
num_weights = 3
num_trials = 5000
hyper = {'sigma'   : 2**np.array([-4.5, -5.0,-16.0]),
         'sigInit' : 2**np.array([ 0.0,  0.0,  0.0]),
         'sigDay'  : 2**np.array([ 0.5,-16.0,  1.0])
        }
days = [500]*9

# Compute
gen = psy.generateSim(K=num_weights, N=num_trials, hyper=hyper, days=days,
                      boundary=10.0, iterations=1, seed=seed, savePath=None)
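The simulated weights follow a random walk. Its step size is set by `sigma`, with a separate step size at session boundaries set by `sigDay`. The sketch below draws such trajectories with NumPy. It assumes that the first trial of each new session uses `sig_day` in place of `sigma`. It also assumes that boundaries fall at `np.cumsum(days)`, the positions marked in the Figure S2a plot. `simulate_weights` is hypothetical. It is not `psy.generateSim` and will not reproduce that function's draws for the same seed.

```python
import numpy as np

def simulate_weights(sigma, sig_init, sig_day, n_trials, day_lengths, seed=0):
    """Draw K weight trajectories from a Gaussian random walk; returns shape (n_trials, K).

    Trial 0 has SD sig_init; later trials add a step with SD sigma, except the first
    trial of each new session, whose step uses SD sig_day (an assumption of this sketch).
    """
    rng = np.random.default_rng(seed)
    sigma = np.asarray(sigma, dtype=float)
    boundaries = np.cumsum(day_lengths)
    boundaries = boundaries[boundaries < n_trials]   # first trial of sessions 2, 3, ...
    scales = np.tile(sigma, (n_trials, 1))           # per-trial step SD, shape (N, K)
    scales[0] = sig_init
    scales[boundaries] = sig_day
    steps = rng.normal(size=(n_trials, sigma.size)) * scales
    return np.cumsum(steps, axis=0)

W = simulate_weights(sigma=2**np.array([-4.5, -5.0, -16.0]), sig_init=1.0,
                     sig_day=2**np.array([0.5, -16.0, 1.0]),
                     n_trials=5000, day_lengths=[500] * 9, seed=102)
print(W.shape)  # (5000, 3)
```

With these hyperparameters, the third weight stays nearly flat within a session and jumps between sessions. The second weight drifts slowly with no visible jumps.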

In [None]:
# Recovery: drop the session boundaries so that the recovery model omits sigDay
gen['dayLength'] = None
rec = psy.recoverSim(gen)

# Save interim result
np.savez_compressed(SPATH+'figS2_data.npz', rec=rec, gen=gen)

In [None]:
# Reload data
rec = np.load(SPATH+'figS2_data.npz', allow_pickle=True)['rec'].item()
gen = np.load(SPATH+'figS2_data.npz', allow_pickle=True)['gen'].item()

# Plotting
sim_colors = [colors['bias'], colors['s1'], colors['s2']]
fig = plt.figure(figsize=(3.75,1.4))
for i, c in enumerate(sim_colors):
    plt.plot(gen['W'][:,i], c=c, lw=0.5, zorder=5-i)
    plt.plot(rec['wMode'][i], c=c, lw=1, linestyle='--', alpha=0.5, zorder=5-i)
    plt.fill_between(np.arange(num_trials),
                     rec['wMode'][i] - 2 * rec['hess_info']['W_std'][i],
                     rec['wMode'][i] + 2 * rec['hess_info']['W_std'][i],
                     facecolor=c, alpha=0.2, zorder=5-i)

for i in np.cumsum(days):
    plt.axvline(i, color="black", lw=0.5, alpha=0.5, zorder=0)
    
plt.axhline(0, color="black", linestyle="--", lw=0.5, alpha=0.5, zorder=0)
plt.xticks(1000*np.arange(0,6))
plt.gca().set_xticklabels([0,1000,2000,3000,4000,5000])
plt.yticks(np.arange(-4,5,2))

plt.xlim(0,5000); plt.ylim(-4.3,4.3)
# plt.xlabel("Trials"); plt.ylabel("Weights")

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS2a.pdf")

### Figure S2b

In [None]:
# Reload data
rec = np.load(SPATH+'figS2_data.npz', allow_pickle=True)['rec'].item()

# Plotting
plt.figure(figsize=(1.4,1.4))

true_sigma = np.log2(rec['input']['sigma'])
avg_sigma = np.log2(rec['hyp']['sigma'])
err_sigma = rec['hess_info']['hyp_std'][:3]
for i, c in enumerate(sim_colors):
    plt.plot([i-0.3, i+0.3], [true_sigma[i]]*2, color="black", linestyle="-", lw=1.2, zorder=0)
    plt.errorbar([i], avg_sigma[i], yerr=1.96*err_sigma[i], c=c, lw=1, marker='o', markersize=5)

plt.axvspan(1.6,2.4, facecolor="black", edgecolor="none", alpha=0.1)
plt.xticks(np.arange(3))
# plt.yticks([-8,-6,-4,-2,0,2])
plt.gca().set_xticklabels([r"$\sigma_1$",
                           r"$\sigma_2$", 
                           r"$\sigma_3$"])
plt.xlim(-0.5,2.5); plt.ylim(-7.5,-2.5)
# plt.ylabel(r"$\log_2(\sigma)$")

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS2b.pdf")

## Figure S3 | Adding weights to early training sessions in IBL mice

**(A)** Refit model from Figure 3b, with history regressor weights

**(B)** Refit model from Figure 3b, with bias weight
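The history regressors added in (A) are built from the previous trial, as in the Figure S3a cell: choices and correct answers are shifted forward by one trial and mapped to $\pm 1$, with 0 on the first trial. A standalone sketch, assuming choices and answers are coded 1/2 as the `*2 - 3` mapping in that cell implies (the human-subject cells use `*2 - 1`, which implies 0/1 coding instead); `history_regressors` is a hypothetical helper:

```python
import numpy as np

def history_regressors(y, answer):
    """Previous-choice (c) and previous-answer (h) regressors.

    Assumes `y` and `answer` are coded 1/2; each is mapped to -1/+1,
    shifted forward one trial, and the first trial gets 0 (no history).
    Returns two (N, 1) column arrays, the same shape the notebook stores
    in `inputs`.
    """
    y = np.asarray(y)
    answer = np.asarray(answer)
    prev_choice = np.hstack(([0], y[:-1] * 2 - 3)).reshape(-1, 1)
    prev_answer = np.hstack(([0], answer[:-1] * 2 - 3)).reshape(-1, 1)
    return prev_choice, prev_answer

c, h = history_regressors([1, 2, 2, 1], [2, 2, 1, 1])
# c[:, 0] -> [0, -1, 1, 1]; h[:, 0] -> [0, 1, 1, -1]
```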

### Figure S3a

In [None]:
# Collect data from manually determined training period
outData = getMouse('CSHL_003', 5)

prev_choice = np.hstack(([0], outData['y'][:-1]*2 - 3)).reshape(-1,1)
prev_answer = np.hstack(([0], outData['answer'][:-1]*2 - 3)).reshape(-1,1)
outData['inputs']['c'] = prev_choice
outData['inputs']['h'] = prev_answer

new_dat = psy.trim(outData, END=7000)

# Compute
weights = {'bias' : 0, 'cL' : 1, 'cR' : 1, 'h' : 1, 'c' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : None
  }
optList = ['sigma']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'figS3a_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'figS3a_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))

plt.axvline(np.cumsum(dat['new_dat']['dayLength'])[8], c="black", lw=1.5, ls="--", zorder=15)
plt.ylim(-5.3,5.3)
plt.xlim(0, 6950)
plt.yticks([-4,-2,0,2,4])
plt.xlabel(None); plt.ylabel(None)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS3a.pdf")

### Figure S3b

In [None]:
# Collect data from manually determined training period
outData = getMouse('CSHL_003', 5)

new_dat = psy.trim(outData, END=7000)

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : None
  }
optList = ['sigma']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'figS3b_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'figS3b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))

plt.axvline(np.cumsum(dat['new_dat']['dayLength'])[8], c="black", lw=1.5, ls="--", zorder=15)
plt.ylim(-5.3,5.3)
plt.xlim(0, 6950)
plt.yticks([-4,-2,0,2,4])
plt.xlabel(None); plt.ylabel(None)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS3b.pdf")

## Figure S4 | The impact of the $\tanh$ transformation of IBL contrasts on model weights

**(A)** $\tanh$ transformation of IBL contrasts

**(B)** Refit model from Figure 4b, without $\tanh$ transformation
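The transformation in (A) maps each signed contrast $c$ to $\tanh(pc)/\tanh(p)$: it keeps $\pm 1$ fixed and preserves sign while boosting low contrasts relative to high ones, and as $p \to 0$ it approaches the identity. That limit is presumably why the Figure S4b cell passes the tiny value 0.00001 as the second argument of `getMouse` (an assumption about `getMouse`, which is defined earlier in the notebook). A quick numerical check, reusing the `tanh_transform` definition from the Figure S4a cell:

```python
import numpy as np

def tanh_transform(c, p):
    # Same definition as in the Figure S4a cell
    return np.tanh(p * np.array(c)) / np.tanh(p)

contrasts = np.array([-1, -0.5, -0.25, -0.125, -0.0625, 0,
                      0.0625, 0.125, 0.25, 0.5, 1.0])

for p in [1, 3, 5]:
    print(p, np.round(tanh_transform(contrasts, p), 3))

# For p close to 0 the transform is essentially the identity
near_identity = tanh_transform(contrasts, 1e-5)
print(np.max(np.abs(near_identity - contrasts)))
```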

### Figure S4a

In [None]:
contrasts = [-1, -0.5, -0.25, -0.125, -0.0625, 0, 0.0625, 0.125, 0.25, 0.5, 1.0]
def tanh_transform(c, p):
    return np.tanh(p*np.array(c))/np.tanh(p)


COLORS = [colors['s_avg'], colors['c'], colors['h']]
plt.figure(figsize=(2.25, 2.25))
plt.plot(contrasts, contrasts, "ko-", markersize=3, lw=1, label="Original")
for i, j in enumerate([1,3,5]):
    plt.plot(contrasts, tanh_transform(contrasts, j),
             "o-", markersize=3, lw=1, color=COLORS[i], label=r"$p = $" +str(j))

plt.axhline(0, color="black", linestyle="--", lw=0.5, zorder=0)#, alpha=0.5)
plt.axvline(0, color="black", linestyle="--", lw=0.5, zorder=0)#, alpha=0.5)
plt.legend(fontsize=10)

# plt.xlabel("Original Contrasts"); plt.ylabel("Transformed Contrasts")
plt.xlim(-1.05,1.05); plt.ylim(-1.05,1.05)
plt.xticks(contrasts, va="top", ha="center")
plt.yticks(contrasts, rotation=90, va="center", ha="right", ma="center")
plt.gca().set_xticklabels(["100%\nLeft",None,None,None,None,0,None,None,None,None,"100%\nRight"])
plt.gca().set_yticklabels(["100%\nLeft",None,None,None,None,0,None,None,None,None,"100%\nRight"])

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS4a.pdf")

### Figure S4b

In [None]:
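# Collect data from manually determined training period
# (second argument is presumably the tanh parameter p; with p = 0.00001,
#  tanh(p*c)/tanh(p) is approximately c, i.e. effectively no tanh transformation)
outData = getMouse("CSHL_003", 0.00001)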
_start  = np.where(outData['date'] >= '2019-03-21')[0][0]
_end    = np.where(outData['date'] >= '2019-03-23')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Before bias blocks begin, reset any trials with probL != 0.5 to 0.5
new_dat['probL'][:np.where(new_dat['date'] >= '2019-03-22')[0][0]] = 0.5

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-5]*K
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'figS4b_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'figS4b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(3,1.5))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_yticks(np.arange(-15,16,5))
plt.ylim(-16.3,16.3)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS4b.pdf")

## Figure S5 | Validating the model with a comparison to empirical psychometric curves

**(A)** Using the mouse from (4), show the weights recovered from Session 10

**(B)** As in (A), for Session 20 (now with a bias weight and bias blocks)

**(C)** As in (A), for Session 40

**(D)** For Session 10, a psychometric curve generated from predictions derived from the model weights (in pink) alongside a curve calculated from the empirical choice behavior (in black)

**(E)** As in (D), for Session 20

**(F)** As in (D), for Session 40
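Panels (D)-(F) compare two psychometric curves computed on the same trials: the fraction of rightward choices at each contrast, and the average cross-validated model prediction of $P(\text{right})$ at that contrast. A self-contained sketch of that grouping step on toy data; the `psychometric` helper and the simulated logistic observer are illustrative assumptions, not IBL data:

```python
import numpy as np
from scipy.stats import sem

def psychometric(contrast, choice_right, p_right):
    """Group trials by signed contrast.

    Returns the unique contrasts, the empirical fraction of rightward
    choices (with SEM), and the mean model-predicted P(right) per contrast.
    """
    contrast = np.asarray(contrast)
    choice_right = np.asarray(choice_right, dtype=float)
    p_right = np.asarray(p_right, dtype=float)
    levels = np.unique(contrast)
    emp = np.array([np.mean(choice_right[contrast == c]) for c in levels])
    emp_sem = np.array([sem(choice_right[contrast == c]) for c in levels])
    pred = np.array([np.mean(p_right[contrast == c]) for c in levels])
    return levels, emp, emp_sem, pred

# Toy data: a logistic observer with a known weight stands in for a fitted model
rng = np.random.default_rng(0)
contrast = rng.choice([-1, -0.25, 0, 0.25, 1], size=2000)
p_right = 1 / (1 + np.exp(-3 * contrast))
choice_right = (rng.random(2000) < p_right).astype(float)

levels, emp, emp_sem, pred = psychometric(contrast, choice_right, p_right)
```

When the model describes the behavior well, `emp` and `pred` should agree to within a few standard errors at every contrast.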

### Figure S5a

_(3 min)_

In [None]:
from scipy.stats import sem

FOLDS = 10  # number of cross-validation folds
SEED = 42   # controls random divide of trials into FOLDS bins

outData = getMouse("CSHL_003", 5)
outData['contrast'] = outData['contrastRight'] - outData['contrastLeft']

all_cs = []
all_ys = []
all_pRs = []

all_days = np.unique(outData['date'])
selected_days = [10,20,40]
for d in selected_days:
    _start  = np.where(outData['date'] >= all_days[d])[0][0]
    _end    = np.where(outData['date'] > all_days[d])[0][0] + 1
    _end = _start + ((_end-_start)//FOLDS * FOLDS)
    new_dat = psy.trim(outData, START=_start, END=_end)

    # Session 10 precedes the bias blocks, so it is fit without a bias weight
    if d < 15:
        weights = {'bias' : 0, 'cL' : 1, 'cR' : 1}
    else:
        weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
    K = np.sum([weights[i] for i in weights.keys()])
    hyper_guess = {
    'sigma'   : [2**-5]*K,
    'sigInit' : 2**5,
    'sigDay'  : None
    }
    optList = ['sigma']

    _, _, wMode, hess_info = psy.hyperOpt(new_dat.copy(), hyper_guess, weights, optList)
    _, xval_pL = psy.crossValidate(new_dat.copy(), hyper_guess, weights, optList,
                                   F=FOLDS, seed=SEED)
    pR = 1 - xval_pL

    cs = []
    ys = []
    pRs = []
    for c in np.unique(new_dat['contrast']):
        cs += [c]
        inds = (new_dat['contrast'] == c)
        ys += [new_dat['y'][inds] - 1]
        pRs += [pR[inds]]
    all_cs += [cs]
    all_ys += [ys]
    all_pRs += [pRs]

    # Plot weights
    fig = psy.plot_weights(wMode, weights, days=None, 
                           errorbar=hess_info['W_std'], figsize=(1.75,1.5))

    plt.xlabel(None); plt.ylabel(None)
    plt.gca().set_yticks(np.arange(-4,6,2))
    plt.ylim(-5.3,5.3)
    if d > 15:
        fig = addBiasBlocks(fig, new_dat['probL'])
        plt.gca().set_yticklabels([])

    plt.subplots_adjust(0,0,1,1) 
    plt.savefig(SPATH + "FigS5abc_" + str(d) + ".pdf")

# Save interim result
np.savez_compressed(SPATH+'figS5_data.npz', 
                    all_cs=all_cs, all_ys=all_ys, all_pRs=all_pRs,
                    selected_days=selected_days)

### Figure S5d-f

In [None]:
all_cs = np.load(SPATH+'figS5_data.npz', allow_pickle=True)['all_cs']
all_ys = np.load(SPATH+'figS5_data.npz', allow_pickle=True)['all_ys']
all_pRs = np.load(SPATH+'figS5_data.npz', allow_pickle=True)['all_pRs']
selected_days = np.load(SPATH+'figS5_data.npz', allow_pickle=True)['selected_days']

# Plotting
diff = 0.01
for d, ind in enumerate(selected_days):
    plt.figure(figsize=(1.75,1.5))
    avg = [np.average(i) for i in all_ys[d]]
    std = [sem(i) for i in all_ys[d]]
    plt.errorbar(np.array(all_cs[d])-diff, avg, yerr=std, color="black", 
                 alpha=1.0, ls="-", lw=0.4, marker='_', markersize=3, elinewidth=1.3)

    avg_pR = [np.average(i) for i in all_pRs[d]]
    std_pR = [sem(i) for i in all_ys[d]]
    plt.errorbar(np.array(all_cs[d])+diff, avg_pR, yerr=std_pR, color=colors['emp_perf'],
                 alpha=1.0, ls="none", marker='_', markersize=3, elinewidth=1.3)
    
    plt.ylim(-0.01,1.01)
    plt.xlim(-1-4*diff,1+4*diff)
    plt.axvline(0, linestyle='--', color="black", lw=0.5, alpha=0.5, zorder=1)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().set_yticks([0,0.5,1])
    if not d:
        plt.gca().set_yticklabels([0,None,1])
    else:
        plt.gca().set_yticklabels([])
    plt.gca().set_xticks([-1,-0.5,-0.25,-.125,-.0625,0,0.0625,0.125,0.25,0.5,1])
    plt.gca().set_xticklabels([])

    # plt.title(d)
    plt.subplots_adjust(0,0,1,1) 
    plt.savefig(SPATH + "FigS5def_" + str(d) + ".pdf")


## Figure S6 | Allowing the bias weight to reset between bias blocks with $\sigma_\text{day}$

Replica of Figure 4, except that bias block boundaries are treated as session boundaries and $\sigma_\text{day}$ for the bias weight is fixed to a large value, allowing the bias weight to "reset" between bias blocks.
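The cells below implement this by overwriting `dayLength` with the lengths of the bias blocks, so that the session-boundary variance $\sigma_\text{day}$ applies at every change in `probL` rather than at every new session. A minimal run-length sketch of that idea; `block_lengths` is a hypothetical helper, and unlike the cells below it also counts the final block, so its lengths always sum to the number of trials:

```python
import numpy as np

def block_lengths(probL, tol=1e-4):
    """Lengths of consecutive runs of (approximately) equal probL values."""
    probL = np.asarray(probL, dtype=float)
    # Index of the first trial of every block after the first
    starts = np.where(np.abs(np.diff(probL)) > tol)[0] + 1
    edges = np.concatenate(([0], starts, [len(probL)]))
    return np.diff(edges)

probL = [0.5] * 90 + [0.8] * 45 + [0.2] * 60 + [0.8] * 30
lengths = block_lengths(probL)
# lengths -> [90, 45, 60, 30]; np.cumsum(lengths)[:-1] gives the block boundaries
```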

### Figure S6b

In [None]:
# Collect data from manually determined training period
outData = getMouse("CSHL_003", 5)
_start  = np.where(outData['date'] >= '2019-03-21')[0][0]
_end    = np.where(outData['date'] >= '2019-03-23')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Before bias blocks begin, reset any trials with probL != 0.5 to 0.5
# (these arise from anti-biasing in the IBL early training protocol)
new_dat['probL'][:np.where(new_dat['date'] >= '2019-03-22')[0][0]] = 0.5
probL_bound = np.where(new_dat['probL'][1:] - new_dat['probL'][:-1] != 0)[0] + 1
old_dayLength = new_dat['dayLength']
new_dat['dayLength'] = np.hstack((probL_bound[:1], np.diff(probL_bound)))

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**5, 2**-5., 2**-5.]
  }
optList = ['sigma']  # sigDay is held fixed, not optimized

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat, 'old_dayLength': old_dayLength}

# Save interim result
np.savez_compressed(SPATH+'figS6b_data.npz', dat=dat)

In [None]:
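BIAS_COLORS = {50 : 'None', 20 : psy.COLORS['sR'], 80 : psy.COLORS['sL']}
def addBiasBlocks(fig, pL):
    """Shade each bias block on the figure's axes according to its probL value."""
    plt.sca(fig.gca())
    i = 0
    while i < len(pL):
        start = i
        # Advance to the last trial of the current block
        while i+1 < len(pL) and abs(pL[i] - pL[i+1]) < 0.0001:
            i += 1
        # Round before casting so floating-point error cannot map 0.2 to 19
        fc = BIAS_COLORS[int(round(100 * pL[start]))]
        plt.axvspan(start, i+1, facecolor=fc, alpha=0.2, edgecolor=None)
        i += 1
    return fig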

In [None]:
dat = np.load(SPATH+'figS6b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat["old_dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_yticks(np.arange(-6, 7,2))
plt.ylim(-5.3,5.3)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS6b.pdf")

### Figure S6c

In [None]:
# Collect data from manually determined training period
outData = getMouse("CSHL_003", 5)
_start  = np.where(outData['date'] >= '2019-04-30')[0][0]
_end    = np.where(outData['date'] >= '2019-05-02')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Treat bias-block boundaries as session boundaries
probL_bound = np.where(new_dat['probL'][1:] - new_dat['probL'][:-1] != 0)[0] + 1
old_dayLength = new_dat['dayLength']
new_dat['dayLength'] = np.hstack((probL_bound[:1], np.diff(probL_bound)))

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**5, 2**-5., 2**-5.]
  }
optList = ['sigma']  # sigDay is held fixed, not optimized

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat, 'old_dayLength': old_dayLength}

# Save interim result
np.savez_compressed(SPATH+'figS6c_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'figS6c_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat["old_dayLength"], 
                       errorbar=dat['W_std'], figsize=(2.75,1.3))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_yticks(np.arange(-6, 7,2))
plt.gca().set_yticklabels([])
plt.ylim(-5.3,5.3)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS6c.pdf")

### Figure S6d

In [None]:
outData = getMouse("CSHL_003", 5)

# Collect data from manually determined training period
_start  = np.where(outData['date'] >= '2019-03-22')[0][0]
_end    = np.where(outData['date'] >= '2019-03-26')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Before bias blocks begin, reset any trials with probL != 0.5 to 0.5
# (these arise from anti-biasing in the IBL early training protocol)
new_dat['probL'][:np.where(new_dat['date'] >= '2019-03-22')[0][0]] = 0.5
probL_bound = np.where(new_dat['probL'][1:] - new_dat['probL'][:-1] != 0)[0] + 1
old_dayLength = new_dat['dayLength']
new_dat['dayLength'] = np.hstack((probL_bound[:1], np.diff(probL_bound)))

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**5, 2**-5., 2**-5.]
  }
optList = ['sigma']  # sigDay is held fixed, not optimized

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat, 'old_dayLength': old_dayLength}

# Save interim result
np.savez_compressed(SPATH+'figS6d_data.npz', dat=dat)

In [None]:
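def bias_diff(dat_load, figsize=(1.5,1.5)):
    """Plot the bias weight within each biased block (skipping blocks shorter
    than about 20 trials), relative to its value on the block's first trial."""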
    dat = np.load(dat_load, allow_pickle=True)['dat'].item()
    pL = dat['new_dat']['probL']
    pL_diff = pL[1:] - pL[:-1]
    inds = np.where(pL_diff)[0]
    start_inds = [0] + list(inds+1)
    start_inds = [i for i in start_inds if (np.isclose(pL[i], 0.2) or np.isclose(pL[i], 0.8))]
    end_inds = list(inds) + [len(pL)-1]
    end_inds = [i for i in end_inds if (np.isclose(pL[i], 0.2) or np.isclose(pL[i], 0.8))]

    fig = plt.figure(figsize=figsize)
    for s, e in zip(start_inds, end_inds):
        if e-s < 20: continue
        block_inds = np.arange(s, e+1)
        block = dat['wMode'][0, block_inds] - dat['wMode'][0, s]
        if np.isclose(pL[s], 0.2):
            plt.plot(block, color=colors['cR'], alpha=0.8, zorder=2, lw=1)
        else:
            plt.plot(block, color=colors['cL'], alpha=0.8, zorder=4, lw=1)
    
    plt.axhline(0, linestyle='--', color="black", lw=1, alpha=0.5, zorder=0)
    plt.ylim(-5.5,5.5)
    plt.xlim(0, 75)

    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.subplots_adjust(0,0,1,1)
    return fig

fig = bias_diff(SPATH+'figS6d_data.npz', figsize=(1.3,1.3));
plt.gca().set_yticks([-4,-2,0,2,4])
plt.savefig(SPATH + "FigS6d.pdf")

### Figure S6e

In [None]:
outData = getMouse("CSHL_003", 5)

# Collect data from manually determined training period
_start  = np.where(outData['date'] >= '2019-04-30')[0][0]
_end    = np.where(outData['date'] >= '2019-05-03')[0][0]
new_dat = psy.trim(outData, START=_start, END=_end)

# Treat bias-block boundaries as session boundaries
probL_bound = np.where(new_dat['probL'][1:] - new_dat['probL'][:-1] != 0)[0] + 1
old_dayLength = new_dat['dayLength']
new_dat['dayLength'] = np.hstack((probL_bound[:1], np.diff(probL_bound)))

# Compute
weights = {'bias' : 1, 'cL' : 1, 'cR' : 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**5, 2**-5., 2**-5.]
  }
optList = ['sigma']  # sigDay is held fixed, not optimized

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat, 'old_dayLength': old_dayLength}

# Save interim result
np.savez_compressed(SPATH+'figS6e_data.npz', dat=dat)

In [None]:
fig = bias_diff(SPATH+'figS6e_data.npz', figsize=(1.3,1.3));
plt.gca().set_yticks([-4,-2,0,2,4])
plt.gca().set_yticklabels([])
plt.savefig(SPATH + "FigS6e.pdf")

### Figure S6f
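The `max_bias` function in the next cell returns the negative of the expected fraction of rewarded trials for a given bias $b$, contrast weights $w_L$ and $w_R$, and block type, so minimizing it over $b$ gives the reward-maximizing bias plotted below. Read off the code, with $\tilde{c} = \tanh(5|c|)/\tanh(5)$, $\ell(x) = 1/(1+e^{-x})$, and $q$ equal to `probL` for the block ($q = 0.8$ for `'L'`, $0.2$ for `'R'`, $0.5$ for `'M'`), the quantity being maximized is:

$$
\begin{aligned}
P_L(c) &= 1 - \ell\left(w_c\,\tilde{c} + b\right), \qquad w_c = w_L \text{ if } c<0,\; 0 \text{ if } c=0,\; w_R \text{ if } c>0 \\
P(c \mid q) &= \begin{cases} q/4.5 & c < 0 \\ 1/9 & c = 0 \\ (1-q)/4.5 & c > 0 \end{cases} \\
P(\text{correct} \mid c) &= \begin{cases} P_L(c) & c < 0 \\ q\,P_L(0) + (1-q)\,\big(1 - P_L(0)\big) & c = 0 \\ 1 - P_L(c) & c > 0 \end{cases} \\
\mathbb{E}[\text{reward} \mid b] &= \sum_{c} P(c \mid q)\, P(\text{correct} \mid c)
\end{aligned}
$$

The sum runs over the nine contrasts listed in the function (four on each side plus zero). With $q = 0.5$ the zero-contrast term is exactly 0.5 regardless of $b$, which matches the `'M'` branch of the code.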

In [None]:
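def max_bias(bias, side, wL, wR):
    """Negative expected reward for a given bias and left/right contrast weights.

    `side` is 'L' (probL = 0.8), 'R' (probL = 0.2), or 'M' (unbiased). The value
    is negated so that minimizing it maximizes the expected reward.
    """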
    contrasts = np.array([-1., -0.25, -0.125, -0.0625, 0., 0.0625, 0.125, 0.25, 1.])
    
    p=5
    transformed_con = np.tanh(p*np.abs(contrasts))/np.tanh(p)

    p_biasL = [.8/4.5]*4 + [1/9] + [.2/4.5]*4    
    p_biasR = [.2/4.5]*4 + [1/9] + [.8/4.5]*4
    p_biasM = [1/9]*9

    w = [wL]*4 + [0] + [wR]*4
    correct = [0]*4 + [0] + [1]*4

    pL = 1 - (1/(1+np.exp(-(transformed_con*w + bias))))
    pCorrect = np.abs(correct - pL)
    
    if side=="L":
        pCorrect[4] = pL[4]*0.8 + (1-pL[4])*0.2
        expval = np.sum(p_biasL * pCorrect)
    
    elif side=="R":
        pCorrect[4] = pL[4]*0.2 + (1-pL[4])*0.8
        expval = np.sum(p_biasR * pCorrect)
    
    elif side=="M":
        pCorrect[4] = 0.5
        expval = np.sum(p_biasM * pCorrect)
    
    return -expval

In [None]:
from scipy.optimize import minimize

dat = np.load(SPATH+'figS6c_data.npz', allow_pickle=True)['dat'].item()
start = dat['old_dayLength'][0]

optBias = []
optReward = []
for i in np.arange(start, dat['wMode'].shape[1]):
    
    if dat['new_dat']['probL'][i] < 0.21: side = 'R'
    elif dat['new_dat']['probL'][i] > 0.79: side = 'L'
    else: side = 'M'
        
    res = minimize(max_bias,[0], args=(side, dat['wMode'][1,i], dat['wMode'][2,i]))
    optBias += [res.x]
    optReward += [-res.fun]

print("Avg. Reward:", np.mean(optReward))

In [None]:
fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat["old_dayLength"],
                       errorbar=dat['W_std'], figsize=(2.75,1.3))
fig = addBiasBlocks(fig, dat['new_dat']['probL'])

plt.plot(np.arange(start, dat['wMode'].shape[1]), optBias, 'k-', lw=2, zorder=10)
plt.gca().set_yticks(np.arange(-6, 7,2))
plt.gca().set_yticklabels([])
plt.gca().set_xticks([750, 1000, 1250])
plt.xlim(start, None); plt.ylim(-5.3,5.3)
plt.xlabel(None); plt.ylabel(None)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS6f.pdf")

In [None]:
# Predicted reward using the fitted bias weight, and with the bias weight set to 0
from scipy.optimize import minimize

optReward_pred = []
optReward_0bias = []
for i in np.arange(start, dat['wMode'].shape[1]):
    
    if dat['new_dat']['probL'][i] < 0.21: side = 'R'
    elif dat['new_dat']['probL'][i] > 0.79: side = 'L'
    else: side = 'M'
        
    optReward_pred += [-max_bias(dat['wMode'][0,i], side, dat['wMode'][1,i], dat['wMode'][2,i])]
    optReward_0bias += [-max_bias(0.0, side, dat['wMode'][1,i], dat['wMode'][2,i])]

print("Predicted Avg. Reward:", np.mean(optReward_pred))
print("No Bias Avg. Reward:", np.mean(optReward_0bias))
print("Empirical Avg. Reward:", np.mean(dat['new_dat']['correct'][start:]))


## Figure S7 | Example Akrami rat without history regressors

Replica of Figure 5, except history regressors are not included in the model.

### Figure S7b

_15 min_

In [None]:
outData = getRat("W080")
new_dat = psy.trim(outData, START=0, END=12500)

weights = {'bias': 1, 's_a': 1, 's_b': 1, 'h': 0, 'c': 0, "s_avg": 0}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : [2**-4]*K,
  }
optList = ['sigma', 'sigDay']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'figS7b_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'figS7b_data.npz', allow_pickle=True)['dat'].item()

fig = psy.plot_weights(dat['wMode'], dat['weights'], days=dat['new_dat']["dayLength"], 
                       errorbar=dat['W_std'], figsize=(4.75,1.4))

selected_days = [[2000,2500], [6500,7000], [11000,11500]]
for d in selected_days:
    plt.plot(d, [-1.3]*2, lw=2, color="k")

plt.xlabel(None); plt.ylabel(None)
plt.ylim(-1.45,1.45)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS7b.pdf")

### Figure S7c-e

These subfigures reuse data from Figure 8a: run the Figure 8a cell above, which creates the file `fig8a_data.npz`, before producing Figures S7c-e.

In [None]:
from scipy.stats import sem

outData = np.load(SPATH+'fig8a_data.npz', allow_pickle=True)['new_dat'].item()
xval_pL = np.load(SPATH+'fig8a_data.npz', allow_pickle=True)['xval_pL'] 
outData['xval_pR'] = 1 - xval_pL

all_hists = []
all_ys = []
all_pRs = []

selected_days = [[2000,2500], [6500,7000], [11000,11500]]
for d in selected_days:
    new_dat = psy.trim(outData, START=d[0], END=d[1])

    hists = []
    ys = []
    pRs = []
    for h in [-1,1]:
        for c in [-1,1]:
            for a in [-1,1]:
                ind_h = (new_dat['inputs']['h'][:,0] == h)
                ind_c = (new_dat['inputs']['c'][:,0] == c)
                ind_a = (np.sign(new_dat['s_a'] - new_dat['s_b']) == a)
                inds = ind_h * ind_c * ind_a
                hists += [[h,c,a]]
                ys += [new_dat['y'][inds]]
                pRs += [new_dat['xval_pR'][inds]]
    
    all_hists += [hists]
    all_ys += [ys]
    all_pRs += [pRs]

In [None]:
import matplotlib as mpl

def colorFader(c1,c2,mix=0):
    """Blend from c1 (mix=0) through white (mix=0.5) to c2 (mix=1)."""
    c1=np.array(mpl.colors.to_rgb(c1))
    c2=np.array(mpl.colors.to_rgb(c2))
    w =np.array(mpl.colors.to_rgb("white"))
    if mix <= 0.5:
        return mpl.colors.to_hex((1-mix*2)*c1 + mix*2*w)
    else:
        return mpl.colors.to_hex((1-(mix-0.5)*2)*w + (mix-0.5)*2*c2)

def cF(mix):
    return colorFader(colors['s2'],colors['s1'],mix)

In [None]:
diff = 0.19  # offset of each percentage label from the cell center
rad = 0.45   # half-width of each square cell

for d in range(len(selected_days)):
    plt.figure(figsize=(0.75,1.5))
    avg = [np.average(i) for i in all_ys[d]]
    avg_pR = [np.average(i) for i in all_pRs[d]]

    for i in range(len(avg)):
        h = all_hists[d][i][0]
        c = all_hists[d][i][1]
        a = all_hists[d][i][2]
        x = a/2
        y = h + c/2

        # Upper-left triangle: cross-validated model P(right), in %
        plt.text(x-diff, y+diff, int(np.round(avg_pR[i]*100)),
                 ha="center", va="center", fontsize=10, zorder=i+1)
        t1 = plt.Polygon([[x-rad,y-rad],[x-rad,y+rad],[x+rad,y+rad]], 
                         facecolor=cF(avg_pR[i]), edgecolor="k", lw=0, zorder=i)
        plt.gca().add_patch(t1)

        # Lower-right triangle: empirical choice rate (mean of y), in %
        plt.text(x+diff, y-1.5*diff, int(np.round(avg[i]*100)),
                 ha="center", va="center", fontsize=10, zorder = i+11)
        t2 = plt.Polygon([[x-rad,y-rad],[x+rad,y-rad],[x+rad,y+rad]], 
                         facecolor=cF(avg[i]), edgecolor="k", lw=0.5, zorder = i+10)
        plt.gca().add_patch(t2)
        
    plt.ylim(-2,2)
    plt.xlim(-1,1)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)
    plt.gca().spines['bottom'].set_visible(False)
    plt.gca().set_xticks([])
    plt.gca().set_yticks([])

    plt.subplots_adjust(0,0,1,1) 
    plt.savefig(SPATH + "FigS7cde_" + str(d) + ".pdf")

## Figure S8 | Modeling the Akrami human subjects with the Previous Choice and Previous Answer weights

**(A)** Refit model from Figure 7b, with history regressor weights

**(B)** Refit models from Figure 7c, with history regressor weights (showing only those weights)

### Figure S8a

In [None]:
new_dat = getHuman(6)

prev_choice = np.hstack(([0], new_dat['y'][:-1]*2 - 1)).reshape(-1,1)
prev_answer = np.hstack(([0], new_dat['answer'][:-1]*2 - 1)).reshape(-1,1)
new_dat['inputs']['c'] = prev_choice
new_dat['inputs']['h'] = prev_answer

# Compute
weights = {'bias': 1, 's_a': 1, 's_b': 1, 's_avg': 1, 'h': 1, 'c': 1}
K = np.sum([weights[i] for i in weights.keys()])
hyper_guess = {
 'sigma'   : [2**-5]*K,
 'sigInit' : 2**5,
 'sigDay'  : None
  }
optList = ['sigma']

hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
       'weights' : weights, 'new_dat' : new_dat}

# Save interim result
np.savez_compressed(SPATH+'figS8a_data.npz', dat=dat)

In [None]:
dat = np.load(SPATH+'figS8a_data.npz', allow_pickle=True)['dat'].item()
fig = psy.plot_weights(dat['wMode'], dat['weights'], errorbar=dat['W_std'], figsize=(4.75,1.4))

plt.xlabel(None); plt.ylabel(None)
plt.gca().set_xticks([0,500,1000,1500,2000])
plt.gca().set_yticks(np.arange(-2, 3,2))
plt.xlim(0, 1900); plt.ylim(-3.4, 3.4)

plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS8a.pdf")

### Figure S8b

_6 min_

In [None]:
all_dat = []
all_subjects = HUMAN_DF["subject_id"].unique()
for i, subject in enumerate(all_subjects):
    
    print("\rProcessing " + str(i+1) + " of " + str(len(all_subjects)), end="")
    new_dat = getHuman(subject)

    prev_choice = np.hstack(([0], new_dat['y'][:-1]*2 - 1)).reshape(-1,1)
    prev_answer = np.hstack(([0], new_dat['answer'][:-1]*2 - 1)).reshape(-1,1)
    new_dat['inputs']['c'] = prev_choice
    new_dat['inputs']['h'] = prev_answer

    # Compute
    weights = {'bias': 1, 's_a': 1, 's_b': 1, 's_avg': 1, 'h': 1, 'c': 1}
    K = np.sum([weights[i] for i in weights.keys()])
    hyper_guess = {
     'sigma'   : [2**-5]*K,
     'sigInit' : 2**5,
     'sigDay'  : None
      }
    optList = ['sigma']

    hyp, evd, wMode, hess_info = psy.hyperOpt(new_dat, hyper_guess, weights, optList)

    dat = {'hyp' : hyp, 'evd' : evd, 'wMode' : wMode, 'W_std' : hess_info['W_std'],
           'weights' : weights, 'new_dat' : new_dat}
    all_dat += [dat]

# Save interim result
np.savez_compressed(SPATH+'figS8b_data.npz', all_dat=all_dat)

In [None]:
all_dat = np.load(SPATH+'figS8b_data.npz', allow_pickle=True)['all_dat']

plt.figure(figsize=(4.75,1.4))
for dat in all_dat:

    weights = dat['weights']
    wMode = dat['wMode']
    # Label each row of wMode; weights are ordered by sorted input name
    labels = []
    for j in sorted(weights.keys()):
        labels += [j]*weights[j]

    for i, w in enumerate(labels):
        if w in ['h', 'c']:
            plt.plot(wMode[i], lw=1.5, alpha=0.5, linestyle='-', c=colors[w], zorder=zorder[w])

plt.axhline(0, color="black", linestyle="--", lw=1, alpha=0.5, zorder=0)
plt.gca().set_xticks([0,500,1000,1500,2000])
plt.gca().set_yticks(np.arange(-2, 3,2))
plt.xlim(0, 1900); plt.ylim(-3.4, 3.4)

plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.subplots_adjust(0,0,1,1) 
plt.savefig(SPATH + "FigS8b.pdf")

---

# Notebook Versioning

**1.1.0** : (November 23, 2020) update following _Neuron_ reviewer feedback
 - add/replace Figure 5C-E
 - add Figure 6F
 - add Figures S2, S5, S6, & S7


**1.0.0** : (May 21, 2020) original release

# Download All Figures
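The cell below depends on a shell `zip` binary (via the `!` escape) and on `google.colab.files`, which is only available inside Colab. Outside Colab, a minimal standard-library sketch can build the same archive of PDFs; `zip_figures` is a hypothetical helper, not part of the notebook:

```python
import glob
import os
import zipfile

def zip_figures(spath, out_path="all_figures.zip"):
    """Bundle every PDF directly inside `spath` into one zip archive."""
    pdfs = sorted(glob.glob(os.path.join(spath, "*.pdf")))
    with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in pdfs:
            # Store files by name only, without the directory prefix
            zf.write(path, arcname=os.path.basename(path))
    return pdfs
```

Calling `zip_figures(SPATH)` writes `all_figures.zip` to the current directory and returns the list of PDFs it added.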

In [None]:
!zip -r "all_figures.zip" . -i "{SPATH}*.pdf"
import time; time.sleep(10)

from google.colab import files
files.download("all_figures.zip")