Imports
=======

``` ipython
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
from scipy.stats import binned_statistic
```

``` ipython
import sys
sys.path.insert(0, '../../../')

import pandas as pd
import torch.nn as nn
from time import perf_counter
from scipy.stats import circmean

from src.network import Network
from src.plot_utils import plot_con
from src.decode import *
from src.lr_utils import masked_normalize, clamp_tensor, normalize_tensor
from src.utils import clear_cache
```

``` ipython
import pickle as pkl

#+RESULTS:

import os

def pkl_save(obj, name, path="."):
    """Pickle *obj* to ``<path>/<name>.pkl``, creating *path* if needed.

    Parameters
    ----------
    obj : object
        Any picklable object.
    name : str
        File stem; ``.pkl`` is appended.
    path : str, optional
        Target directory (default: current directory).
    """
    os.makedirs(path, exist_ok=True)
    destination = path + "/" + name + ".pkl"
    print("saving to", destination)
    # Context manager: the file is flushed and closed even if pickling fails.
    with open(destination, "wb") as f:
        pkl.dump(obj, f)


def pkl_load(name, path="."):
    """Unpickle and return the object stored in ``<path>/<name>.pkl``.

    Only load files you created yourself: unpickling can execute arbitrary code.
    """
    source = path + "/" + name + ".pkl"
    with open(source, "rb") as f:
        return pkl.load(f)
```

Notebook Settings
=================

``` ipython
# Auto-reload edited project modules (src.*) without restarting the kernel.
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

# Shared notebook setup. Names used below without an import here (np, plt, sns,
# width, height) are presumably defined by this script — TODO confirm.
%run ../../../notebooks/setup.py
%matplotlib inline
%config InlineBackend.figure_format = 'png'

REPO_ROOT = "/home/leon/models/NeuroFlame"
pal = sns.color_palette("tab10")
```

Helpers
=======

``` ipython
def map2center(angles):
    """Map angles from [0, 2π] to [-π, π] using NumPy.

    Values greater than π are shifted down by 2π; all others are returned
    unchanged. Accepts NumPy arrays or scalars and returns a NumPy array
    (``maptens2center`` is the PyTorch counterpart).
    """
    return np.where(angles > np.pi, angles - 2 * np.pi, angles)


def map2pos(angles):
    """Map angles from [-π, π] to [0, 2π] using NumPy.

    Negative values are shifted up by 2π; all others are returned unchanged.
    Accepts NumPy arrays or scalars and returns a NumPy array
    (``maptens2pos`` is the PyTorch counterpart).
    """
    return np.where(angles < 0, angles + 2 * np.pi, angles)
```

``` ipython
def maptens2center(angles):
    """Wrap a tensor of angles from [0, 2π] into [-π, π], element-wise."""
    too_large = angles > torch.pi
    return torch.where(too_large, angles - 2 * torch.pi, angles)


def maptens2pos(angles):
    """Wrap a tensor of angles from [-π, π] into [0, 2π], element-wise."""
    negative = angles < 0
    return torch.where(negative, angles + 2 * torch.pi, angles)
```

``` ipython
def add_vlines(model, ax=None):
    """Shade every stimulus window [T_STIM_ON[i], T_STIM_OFF[i]] of *model*.

    Draws on *ax* when given, otherwise on the current pyplot axes.
    """
    target = plt if ax is None else ax
    for i, t_on in enumerate(model.T_STIM_ON):
        target.axvspan(t_on, model.T_STIM_OFF[i], alpha=0.25)
```

Model
=====

``` ipython
# Run flags:
#   IF_LOAD      - re-plot the saved pickles of all groups in the delay cells
#   IF_SAVE      - pickle this run's outputs under ./models/nih/
#   IF_SAVE_POPT - disable the injected bias and fit/save the DoG kernel (popt)
IF_LOAD = 0
IF_SAVE = 1
IF_SAVE_POPT = 0

name = 'controls'
# name = 'patients_B'
if IF_LOAD:
    groups = ['controls', 'patients_A', 'patients_B']
else:
    groups = [name]
```

``` ipython
# Parameter overrides passed to Network(...) together with the YAML config
# (presumably taking precedence over it — TODO confirm in src.network).
# Some entries are also read directly by later cells: GAIN_SZ / GAIN_NMDA by the
# connectivity-rescaling cell, I0 / RANDOM_DELAY by the simulation cells and
# REP_BIAS / REP_SIG by the serial-bias (shifted_phase) cells.
kwargs = {
    'N_TRIALS': 100,
    'GAIN_NMDA': [1.0, 1.0],
    'GAIN_SZ': [1., 1.],

    'GAIN': 1.2,
    'DURATION': 9.0,
    'T_STEADY': 1.0,

    'T_STIM_ON': [1.0, 2.0],
    'T_STIM_OFF': [2.0, 3.0],

    'I0': [1.0, -10.0],
    'PHI0': [180.0, 180],
    'SIGMA0': [4.0, 0.0],
    'M0': 1.0,
    'VAR_FF': [0.3, 0.3],

    'RANDOM_DELAY': 1,
    'MIN_DELAY': 0,
    'MAX_DELAY': 6,

    'RANDOM_ITI': 0,
    'MAX_ITI': 6,
    'MIN_ITI': 0,
    'ITI_LIST': [0, 2, 4, 6],

    'TAU': [0.2, 0.1],

    'SYN_DYN': 0,

    'IF_NMDA': 1,
    'R_NMDA': 1.0,
    'TAU_NMDA': [0.5, 0.5],

    'IF_FF_STP': 0,
    'FF_USE': 0.5,
    'TAU_FF_FAC': 0.0,
    'TAU_FF_REC': 0.5,

    'Jab': [1.0, -1.4, 1.0, -1],

    'IS_STP': [1, 0, 0, 0],
    'USE': [0.06, 0.03, 0.03, 0.1],
    'TAU_FAC': [4.0, 2.0, 2.0, 0.0],
    'TAU_REC': [0.4, 0.2, 0.2, 0.1],
    'W_STP': [1.0, 3.0, 4.0, 1.0],

    'IF_FF_ADAPT': 0,
    'A_FF_ADAPT': 1.0,
    'TAU_FF_ADAPT': 150.0,

    'IF_ADAPT': 1,
    'A_ADAPT': 1.5, # 4 or 0.5 0.15 or 1.5
    'TAU_ADAPT': 10.0,

    'REP_BIAS': 2.0, # 1.25
    'REP_VAR': 0.0,
    'REP_SIG': 90,
}
```

``` ipython
REPO_ROOT = "/home/leon/models/NeuroFlame"
conf_name = "train_odr_EI.yml"
DEVICE = 'cuda:1'
# NOTE(review): this random draw is immediately overridden by `seed = 1` below;
# it only advances NumPy's global RNG. The seed also selects which trained
# weights file (odr_<seed>.pth) is loaded further down.
seed = np.random.randint(0, 1e6)

seed = 1
print('seed', seed)
```

``` ipython
# Build the network from the YAML config plus the overrides in kwargs;
# N_BATCH independent trials are simulated in parallel.
N_BATCH = 768
model = Network(conf_name, REPO_ROOT, VERBOSE=0, DEVICE=DEVICE, SEED=seed, N_BATCH=N_BATCH, **kwargs)
```

``` ipython
# Load the trained weights for this seed and switch to evaluation mode
# (the trailing `;` only suppresses the cell's output).
model_state_dict = torch.load('../models/odr/odr_%d.pth' % seed)
model.load_state_dict(model_state_dict);
model.eval();
```

``` ipython
# Group-specific rescaling: J_STP is divided by GAIN_SZ[0] while USE[0] is
# multiplied by it; then J_STP is multiplied by GAIN_NMDA[0] and the Wab_T block
# selected by (slices[0], slices[1]) by GAIN_NMDA[1]. With all gains at 1.0
# (as in kwargs above) the model is left unchanged.
model.J_STP = torch.nn.Parameter(model.J_STP.detach() / kwargs['GAIN_SZ'][0])
model.USE[0] = model.USE[0] * kwargs['GAIN_SZ'][0]

model.J_STP = torch.nn.Parameter(kwargs['GAIN_NMDA'][0] * model.J_STP.detach())
model.Wab_T[model.slices[0], model.slices[1]] *= kwargs['GAIN_NMDA'][1]
print(model.J_STP)
```

``` ipython
import torch

def torch_dog(x, A, mu, sigma, C):
    """Derivative-of-Gaussian kernel in PyTorch.

    Returns ``A * (-(x - mu) / sigma**2) * exp(-(x - mu)**2 / (2 * sigma**2)) + C``.
    """
    offset = x - mu
    var = sigma**2
    return A * (-offset / var) * torch.exp(-offset**2 / (2 * var)) + C


def shifted_phase(phase1, phase2, bias_strength, bias_var, direction='repulsive',
                  sigma=50, order=1, popt=None):
    """Bias *phase2* away from (or toward) *phase1*.

    All phases, and ``sigma``, are in degrees.

    Parameters
    ----------
    phase1, phase2 : torch.Tensor
        Reference and target phases (degrees).
    bias_strength : float
        Scale of the shift (degrees).
    bias_var : float
        Standard deviation of additive Gaussian noise (degrees).
    direction : str
        ``'repulsive'`` flips the kernel's sign; any other value keeps it.
    sigma, order : float
        Width and exponent of the default ``sin * exp`` kernel.
    popt : sequence of 4 numbers, optional
        If given, ``(A, mu, sigma, C)`` of a ``torch_dog`` kernel used instead.

    Returns
    -------
    torch.Tensor
        The shifted ``phase2``, taken modulo 360.
    """
    # Signed shortest angular distance phase1 - phase2, in [-180, 180).
    delta = (phase1 - phase2 + 180) % 360 - 180
    sign = -1 if direction == 'repulsive' else 1

    if popt is None:
        falloff = torch.exp(-torch.abs(delta)**order / sigma**order / order)
        kernel = torch.sin(torch.deg2rad(delta)) * falloff
    else:
        amp, mu, width, offset = popt[0], popt[1], popt[2], popt[3]
        kernel = torch_dog(delta, amp, mu, width, offset)

    noise = bias_var * torch.randn_like(phase2)
    return (phase2 + sign * bias_strength * kernel + noise) % 360.0
```

``` ipython
# Choose the serial-bias kernel. With IF_SAVE_POPT the injected bias is switched
# off (REP_BIAS = 0) so that the network's own bias can be fitted and saved in
# the DoG-fit cell; otherwise REP_BIAS comes from kwargs and the previously
# fitted DoG parameters are loaded.
if IF_SAVE_POPT:
    model.REP_BIAS=0
    popt = None
else:
     model.REP_BIAS = kwargs['REP_BIAS']
     popt = pkl_load('popt_%s' % name, './models/nih/rand_delay')

# Sanity check: apply the kernel to 100 random (previous, current) pairs (degrees)
# and plot the resulting shift against their wrapped relative location.
phase1 = torch.randint(low=0, high=360, size=(100,), device=DEVICE, dtype=torch.float)
phase2 = torch.randint(low=0, high=360, size=(100,), device=DEVICE, dtype=torch.float)

delta = phase1 - phase2
rel_loc = (delta + 180) % 360 - 180
shift =  shifted_phase(phase1, phase2, model.REP_BIAS, model.REP_VAR, sigma=kwargs['REP_SIG'], popt=popt)
error =  (shift-phase2).cpu().numpy()
error = (error + 180) %360 - 180

plt.plot(rel_loc.cpu().numpy(), error, 'o')
plt.show()
```

``` ipython
# Tag output file names of random-delay runs (used by the save cell below).
if kwargs['RANDOM_DELAY']:
    name += '_delay_rand'
print(name)
```

Simulating Consecutive Trials
=============================

``` ipython
model.N_BATCH = N_BATCH
# Running a baseline batch with both I0 entries set to 0 and a fixed delay.
# NOTE(review): `model.TASK: 'None'` is a variable annotation, not an assignment,
# so model.TASK is NOT changed here. If the intent was to switch the task off,
# both this line and the next cell's `model.TASK: 'odr'` must become
# assignments together — confirm how Network handles TASK before changing.
model.TASK: 'None'
model.I0 = [0.0, 0.0]
model.RANDOM_DELAY = 0

with torch.no_grad():
    ff_input = model.init_ff_input()
    rates_tensor = model.forward(ff_input=ff_input)
    clear_cache()
print(ff_input.shape, rates_tensor.shape)
```

``` ipython
model.N_BATCH = N_BATCH

model.I0 = kwargs['I0']

# NOTE(review): annotation, not an assignment — model.TASK is not changed here
# (same pattern as `model.TASK: 'None'` in the baseline cell).
model.TASK: 'odr'
model.RANDOM_DELAY = kwargs['RANDOM_DELAY']
# Random integer-valued stimulus locations (degrees), one per batch entry and stimulus.
model.PHI0 = torch.randint(low=0, high=360, size=(N_BATCH, len(model.I0), 1), device=DEVICE, dtype=torch.float)

with torch.no_grad():
    ff_input = model.init_ff_input()
    # IF_INIT=0: presumably continues from the state left by the previous run
    # instead of re-initialising — TODO confirm in Network.forward.
    rates_tensor = model.forward(ff_input=ff_input, IF_INIT=0)
    # del ff_input
    clear_cache()
print(ff_input.shape, rates_tensor.shape)
```

``` ipython
from src.configuration import init_time_const

N_TRIALS = 100

rates_list = []
# Stimulus locations are recorded in radians (PHI0_UNBIASED below), so the
# initial "previous" location must be converted too. It is the PHI0 (degrees)
# of the batch simulated in the cell above. Storing it in degrees corrupted
# trial 0's shift (prev * 180/pi below) and its relative location in the
# error analysis.
prev_list = [torch.deg2rad(model.PHI0[:, 0]).cpu().detach()]
curr_list = []

start_list = torch.ones((N_TRIALS, 2, model.N_BATCH))
end_list = torch.ones((N_TRIALS, 2, model.N_BATCH))
delay_list = torch.ones((N_TRIALS, model.N_BATCH))

iti_list = torch.ones((N_TRIALS, model.N_BATCH))
min_iti = model.DURATION - model.T_STIM_OFF[-1]
interval = N_TRIALS // len(model.ITI_LIST)

for trial in tqdm(range(N_TRIALS)):
    with torch.no_grad():

        # if IF_ITI:
        #     if trial % interval == 0:
        #         model.T_STEADY = model.ITI_LIST[trial // interval]

        # Presumably redraws the random delay / ITI of this batch — see src.configuration.
        if model.RANDOM_ITI or model.RANDOM_DELAY:
            init_time_const(model)

        delay_list[trial] = model.random_shifts
        iti_list[trial] = model.N_STEADY * model.DT + min_iti

        start_list[trial] = model.start_idx
        end_list[trial] = model.end_idx

        # New random stimulus locations (degrees); keep an unbiased copy in radians.
        model.PHI0 = torch.randint(low=0, high=360, size=(N_BATCH, len(model.I0), 1), device=DEVICE, dtype=torch.float)
        model.PHI0_UNBIASED = torch.deg2rad(model.PHI0.clone())

        if model.REP_BIAS>0:
            # Shift the first stimulus relative to the previous trial's (radians ->
            # degrees). sigma matches kwargs['REP_SIG'] used in the kernel
            # sanity-check cell; it only matters when popt is None.
            model.PHI0[:, 0] = shifted_phase(prev_list[-1].to(DEVICE)*180.0 / torch.pi, model.PHI0[:, 0], model.REP_BIAS, model.REP_VAR, sigma=kwargs['REP_SIG'], popt=popt)

        ff_input = model.init_ff_input()

        rates = model.forward(ff_input=ff_input, IF_INIT=0)

        curr_list.append(model.PHI0_UNBIASED[:, 0].cpu().detach())
        prev_list.append(model.PHI0_UNBIASED[:, 0].cpu().detach())

        rates_list.append(rates.cpu().detach())

        del ff_input, rates
        clear_cache()

rates_list = torch.stack(rates_list).cpu().numpy()
# Drop the last entry so that prev_list[t] is the stimulus preceding trial t.
prev_list = torch.stack(prev_list).cpu().numpy()[:-1]
curr_list = torch.stack(curr_list).cpu().numpy()

delay_list = np.hstack(delay_list.cpu().numpy()) * model.DT
iti_list = np.hstack(iti_list.cpu().numpy())
```

``` ipython
# Shapes of the per-trial outputs collected by the loop.
print('rates', rates_list.shape)
print('curr', curr_list.shape, 'prev', prev_list.shape, 'iti', iti_list.shape)
print('start', start_list.shape, 'end', end_list.shape, 'delay', delay_list.shape)
```

``` ipython
# Start/end indices as left on the model by the LAST loop iteration only
# (per-trial values are in start_list / end_list).
start_idx = model.start_idx.cpu().numpy()
end_idx = model.end_idx.cpu().numpy()
print(start_idx.shape, end_idx.shape)
```

``` ipython
# Decode the bump along the neuron axis (axis=-1); presumably returns its mean
# rate, first Fourier amplitude and phase (radians) — see src.decode.
m0_list, m1_list, phi_list = decode_bump_torch(rates_list, axis=-1, RET_TENSOR=0)

print(phi_list.shape)
```

``` ipython
# NOTE(review): assumes 10 stored time points per simulated second — TODO
# confirm against the model's time step / recording window.
DURATION = rates_list.shape[2] / 10
N_NEURONS = rates_list.shape[-1]
N_SESSION = rates_list.shape[1]
```

``` ipython
# Raster of one random batch entry (among the first 100) over n_trials
# consecutive trials; plt.imshow draws on the axes created here.
fig, ax = plt.subplots(1, 1, figsize=[3*width, height])

n_trials = 4
idx = np.random.randint(0, 100)
# Concatenate the trials in time, then transpose: rows = neurons, columns = time.
rates = np.vstack(rates_list[:n_trials, idx]).T
# Colour scale clipped to the 5th-95th percentile of the displayed rates.
vmin, vmax = np.percentile(rates.reshape(-1), [5, 95])
print(vmax)
plt.imshow(rates, aspect='auto', cmap='jet', vmin=vmin, vmax=vmax, origin='lower', extent=[0, n_trials * DURATION, 0, N_NEURONS])
plt.ylabel('Pref. Location (°)')
plt.yticks(np.linspace(0, N_NEURONS, 5), np.linspace(0, 360, 5).astype(int))
plt.xlabel('Time (s)')

plt.show()
```

``` ipython
fig, ax = plt.subplots(1, 3, figsize=[3*width, height])

# One random batch entry (among the first 100), its first n_trials trials
# concatenated in time.
idx = np.random.randint(0, 100)

m0 = np.hstack(m0_list[:n_trials, idx]).T
m1 = np.hstack(m1_list[:n_trials, idx]).T
phi = np.hstack(phi_list[:n_trials, idx]).T

xtime = np.linspace(0, n_trials*DURATION, phi.shape[-1])

# Raw strings: '\m' is an invalid escape sequence in a normal string literal
# (SyntaxWarning since Python 3.12); the rendered label text is unchanged.
ax[0].plot(xtime, m0)
ax[0].set_ylabel(r'$\mathcal{F}_0$ (Hz)')
ax[0].set_xlabel('Time (s)')

ax[1].plot(xtime, m1)
ax[1].set_ylabel(r'$\mathcal{F}_1$ (Hz)')
ax[1].set_xlabel('Time (s)')

# phi is in radians; show it in degrees.
ax[2].plot(xtime, phi * 180 / np.pi)
ax[2].set_yticks(np.linspace(0, 360, 5).astype(int), np.linspace(0, 360, 5).astype(int))
ax[2].set_ylabel('Bump Center (°)')
ax[2].set_xlabel('Time (s)')

plt.show()
```

``` ipython
# Persist this run's outputs under ./models/nih/ (pkl_save creates the directory).
# NOTE(review): iti_list, start_list and end_list are not saved; start_idx /
# end_idx hold only the last trial's indices.
if IF_SAVE:
    pkl_save(rates_list, 'rates_list_%s' % name, './models/nih/')

    pkl_save(curr_list, 'curr_list_%s' % name, './models/nih/')
    pkl_save(prev_list, 'prev_list_%s' % name, './models/nih/')

    pkl_save(delay_list, 'delay_list_%s' % name, './models/nih/')

    pkl_save(start_idx, 'start_idx_%s' % name, './models/nih/')
    pkl_save(end_idx, 'end_idx_%s' % name, './models/nih/')
    pkl_save(phi_list, 'phi_list_%s' % name, './models/nih/')
```

Errors
======

``` ipython
# Split the N_TRIALS consecutive trials into a first and a second half.
n_half = N_TRIALS // 2
```

``` ipython
# Current / previous stimulus locations (radians) of the first and last n_half trials.
curr_ini =  curr_list[:n_half]
curr_last = curr_list[-n_half:]

prev_ini =  prev_list[:n_half]
prev_last = prev_list[-n_half:]
print(curr_ini.shape, prev_ini.shape)
```

``` ipython
# Same split as curr_*/prev_* ([:n_half] / [-n_half:]) so that phi_last stays
# aligned with curr_last / prev_last even when N_TRIALS is odd.
phi_ini = phi_list[:n_half]
phi_last = phi_list[-n_half:]
```

``` ipython
# The leading (trial) dimensions of phi_ini and prev_ini must agree.
print(phi_ini.shape, prev_ini.shape)
```

``` ipython
def get_error_curr_prev(phi, curr, prev):
    """Convert decoded phases and stimulus histories into degree-valued errors.

    Parameters
    ----------
    phi : numpy.ndarray
        Decoded bump phase (radians); must broadcast against ``curr``.
    curr, prev : numpy.ndarray
        Current and previous stimulus locations (radians).

    Returns
    -------
    target_loc : numpy.ndarray
        ``curr`` in degrees, stacked with ``np.vstack``.
    rel_loc : numpy.ndarray
        ``prev - curr`` wrapped to [-180, 180) degrees, stacked with ``np.vstack``.
    error_curr : numpy.ndarray
        ``phi - curr`` wrapped to [-180, 180) degrees, same shape as the broadcast.
    """
    rad2deg = 180 / np.pi

    def wrap(angle):
        # Shortest signed angle in [-pi, pi).
        return (angle + np.pi) % (2 * np.pi) - np.pi

    target_deg = curr * 180.0 / np.pi
    rel_deg = wrap(prev - curr) * rad2deg
    err_deg = wrap(phi - curr) * rad2deg

    return np.vstack(target_deg), np.vstack(rel_deg), np.array(err_deg)
```

``` ipython
def get_end_point(start_idx, errors):
    """Collect, for every session k, the errors at time index ``start_idx[1][k] - 1``.

    Parameters
    ----------
    start_idx : indexable
        ``start_idx[1][k]`` is the start index used for session ``k``; the
        sample just before it is taken.
    errors : numpy.ndarray
        Array indexed as ``errors[trial, session, time]``.

    Returns
    -------
    numpy.ndarray
        Column vector of shape ``(n_trials * n_sessions, 1)`` in trial-major order.
    """
    picked = [
        errors[:, k, start_idx[1][k] - 1]
        for k in range(errors.shape[1])
    ]
    return np.array(picked).T.reshape(-1, 1)
```

``` ipython
# Target, relative location of the previous target and full error time course
# (all in degrees) for each half of the trials.
targ_ini, rel_ini, errors_ini = get_error_curr_prev(phi_ini, curr_ini, prev_ini)
targ_last, rel_last, errors_last = get_error_curr_prev(phi_last, curr_last, prev_last)
print(targ_ini.shape, rel_ini.shape, errors_ini.shape)
```

``` ipython
print(start_list.shape)
# start_list is (N_TRIALS, 2, N_BATCH): take the first n_half trials, shift the
# indices by -1 and move the index-kind axis first -> (2, n_half, N_BATCH).
start_ini = np.swapaxes(start_list[:n_half]-1, 0, 1).cpu().numpy()
start_ini = (start_ini).astype(int)

i, j = np.indices(start_ini[1].shape)

# For every (trial, batch) entry read the error one step before start index 1
# (presumably the end of the delay), flattened trial-major into a column.
# end_point_ini = errors_ini[i, j, start_ini[1]]
end_point_ini = np.hstack(errors_ini[i, j, start_ini[1]])[:, np.newaxis]
print(end_point_ini.shape)
```

``` ipython
print(start_list.shape)
# Use the start indices of the LAST n_half trials so they match errors_last;
# previously the first-half indices were reused, which picks the wrong time
# points whenever delays differ between trials (RANDOM_DELAY).
start_last = np.swapaxes(start_list[-n_half:]-1, 0, 1).cpu().numpy()
start_last = (start_last).astype(int)

i, j = np.indices(start_last[1].shape)

# Error one step before start index 1 (presumably the end of the delay) for
# every (trial, batch) entry, flattened trial-major into a column.
end_point_last = np.hstack(errors_last[i, j, start_last[1]])[:, np.newaxis]
print(end_point_last.shape)
```

``` ipython
# end_point_ini = get_end_point(start_idx, errors_ini)
# end_point_last = get_end_point(start_idx, errors_last)
# print(end_point_ini.shape, end_point_last.shape)
```

``` ipython
# Distribution of end-point errors (degrees) in the first and second half.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
ax[0].hist(end_point_ini[:, 0], bins=50)
ax[1].hist(end_point_last[:, 0], bins=50)
plt.show()
```

``` ipython
# Error time courses of 100 randomly drawn batch entries (with replacement) from
# the first trial of each half, with the stimulus windows shaded.
time_points = np.linspace(0, model.DURATION, errors_ini.shape[-1])
idx = np.random.randint(errors_ini.shape[1], size=100)

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])
ax[0].plot(time_points, errors_ini[0][idx].T, alpha=.4)
add_vlines(model, ax[0])

ax[0].set_xlabel('t')
ax[0].set_ylabel('error first half(°)')

ax[1].plot(time_points, errors_last[0][idx].T, alpha=.4)
add_vlines(model, ax[1])

ax[1].set_xlabel('t')
ax[1].set_ylabel('error last half (°)')
plt.show()
```

``` ipython
```

Biases
======

``` ipython
# All three must have the same number of rows (one per trial x batch entry).
print(targ_ini.shape, rel_ini.shape, end_point_ini.shape)
```

``` ipython
n_bins = 32
# One row per (trial, batch) entry of each half: target location, relative
# location of the previous target (degrees) and end-point error (degrees).
data_ini = pd.DataFrame({'target_loc': targ_ini[:, -1], 'rel_loc': rel_ini[:, -1], 'errors': end_point_ini[:, 0]})
data_last = pd.DataFrame({'target_loc': targ_last[:, -1], 'rel_loc': rel_last[:, -1], 'errors': end_point_last[:, 0]})
```

``` ipython
def get_correct_error(nbins, df, thresh=None):
    """Bias-correct response errors and bin them by relative location.

    Parameters
    ----------
    nbins : int
        Number of bins used by every binning step below.
    df : pandas.DataFrame
        Must contain ``'target_loc'`` (degrees, binned on [0, 360]),
        ``'rel_loc'`` (degrees) and ``'errors'`` columns.
    thresh : float, optional
        If given, only rows with ``-thresh <= errors <= thresh`` are kept.

    Returns
    -------
    centers : pandas.Index
        Centers of the ``nbins`` relative-location bins.
    bin_rel : pandas.DataFrame
        Columns ``bin_rel``, ``mean``, ``sem`` of the corrected error per
        relative-location bin.
    centers_abs : pandas.Index
        Centers of the ``nbins`` bins of ``|rel_loc|``.
    bin_rel_abs : pandas.DataFrame
        Columns ``bin_rel_abs``, ``mean``, ``sem`` of the sign-flipped
        corrected error per ``|rel_loc|`` bin.
    """
    if thresh is not None:
        data = df[(df['errors'] >= -thresh) & (df['errors'] <= thresh)].copy()
    else:
        data = df.copy()

    # 1. Remove the target-dependent bias: subtract each target-location bin's
    #    mean error. Uses the `nbins` argument (the previous version silently
    #    read the notebook-global `n_bins`, ignoring the parameter).
    bin_edges = np.linspace(0, 360, nbins + 1)
    data['bin_target'] = pd.cut(data['target_loc'], bins=bin_edges, include_lowest=True)
    # observed=False states the current pandas default explicitly (it changes in pandas 3).
    mean_errors_per_bin = data.groupby('bin_target', observed=False)['errors'].mean()
    data['adjusted_errors'] = data['errors'] - data['bin_target'].map(mean_errors_per_bin).astype(float)

    # 2. Bin the corrected errors by relative location. With an integer `bins`,
    #    pd.cut spans the observed rel_loc range rather than a fixed [-180, 180].
    data['bin_rel'] = pd.cut(data['rel_loc'], bins=nbins)
    bin_rel = data.groupby('bin_rel', observed=False)['adjusted_errors'].agg(['mean', 'sem']).reset_index()
    edges = bin_rel['bin_rel'].cat.categories
    centers = (edges.left + edges.right) / 2

    # 3. Fold onto |rel_loc|: flip the sign of errors on trials with negative
    #    rel_loc so both sides express the bias in the same direction.
    data['rel_loc_abs'] = np.abs(data['rel_loc'])
    data['bin_rel_abs'] = pd.cut(data['rel_loc_abs'], bins=nbins, include_lowest=True)
    data['adjusted_errors_abs'] = data['adjusted_errors'] * np.sign(data['rel_loc'])

    bin_rel_abs = data.groupby('bin_rel_abs', observed=False)['adjusted_errors_abs'].agg(['mean', 'sem']).reset_index()
    edges_abs = bin_rel_abs['bin_rel_abs'].cat.categories
    centers_abs = (edges_abs.left + edges_abs.right) / 2

    return centers, bin_rel, centers_abs, bin_rel_abs
```

``` ipython
# Serial-bias curves for the first and the second half of the trials.
centers_ini, bin_rel_ini, centers_abs_ini, bin_rel_abs_ini = get_correct_error(n_bins, data_ini)
centers_last, bin_rel_last, centers_abs_last, bin_rel_abs_last = get_correct_error(n_bins, data_last)
```

``` ipython
# Serial bias (mean ± SEM of the corrected error) for the first (red) and the
# second (blue) half: left vs signed rel. location, right vs |rel. location|.
# NOTE(review): savefig needs ./figures/NIH/10_25/ to exist already.
fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

# Panel 2: By Relative Location (Full vs Half session, -180..180)
ax[0].plot(centers_ini, bin_rel_ini['mean'], 'r', label='First half')
ax[0].fill_between(centers_ini, bin_rel_ini['mean'] - bin_rel_ini['sem'], bin_rel_ini['mean'] + bin_rel_ini['sem'], color='r', alpha=0.2)

ax[0].plot(centers_last, bin_rel_last['mean'], 'b', label='Second half')
ax[0].fill_between(centers_last, bin_rel_last['mean'] - bin_rel_last['sem'], bin_rel_last['mean'] + bin_rel_last['sem'], color='b', alpha=0.2)

ax[0].axhline(0, color='k', linestyle=":")
ax[0].set_xlabel('Rel. Loc. (°)')
ax[0].set_ylabel('Corrected Error (°)')


# Panel 3: By |Relative Location| (Full and Half)
ax[1].plot(centers_abs_ini, bin_rel_abs_ini['mean'], 'r', label='First half')
ax[1].fill_between(centers_abs_ini, bin_rel_abs_ini['mean'] - bin_rel_abs_ini['sem'], bin_rel_abs_ini['mean'] + bin_rel_abs_ini['sem'], color='r', alpha=0.2)

ax[1].plot(centers_abs_last, bin_rel_abs_last['mean'], 'b', label='Second half')
ax[1].fill_between(centers_abs_last, bin_rel_abs_last['mean'] - bin_rel_abs_last['sem'], bin_rel_abs_last['mean'] + bin_rel_abs_last['sem'], color='b', alpha=0.2)

ax[1].axhline(0, color='k', linestyle=":")
ax[1].set_xlabel('|Rel. Loc.| (°)')
ax[1].set_ylabel('Corrected Error (°)')
ax[1].legend(fontsize=12)

plt.tight_layout()

plt.savefig('./figures/NIH/10_25/sb_half_%s.svg' % name)

plt.show()
```

``` ipython
```

Delay Dependency
================

``` ipython
def get_delay_points(errors_list, start_idx, end_idx):
    """Gather the errors over the delay period of every session.

    For session ``s`` the delay spans time indices
    ``range(end_idx[0][s], start_idx[1][s])``; every session must have a delay
    of the same length for the result to stack.

    Parameters
    ----------
    errors_list : numpy.ndarray
        Array indexed as ``errors_list[trial, session, time]``.
    start_idx, end_idx : indexable
        Per-session index tables as described above.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_trials * n_sessions, n_delay)``, trial-major rows.
    """
    per_session = []
    for sess in range(errors_list.shape[1]):
        delay_start = end_idx[0][sess]
        delay_stop = start_idx[1][sess]
        # One list per delay time point, holding the value of every trial.
        per_session.append(
            [list(errors_list[:, sess, t]) for t in range(delay_start, delay_stop)]
        )

    stacked = np.array(per_session)  # (session, delay, trial)
    return np.vstack(stacked.T.swapaxes(-1, 1))
```

``` ipython
# Same end-point extraction as for the halves, now over all N_TRIALS trials.
targ_list, rel_list, errors_list = get_error_curr_prev(phi_list, curr_list, prev_list)

print(start_list.shape)

# (N_TRIALS, 2, N_BATCH) -> (2, N_TRIALS, N_BATCH), indices shifted by -1.
end_points = np.swapaxes(start_list-1, 0, 1).cpu().numpy()
end_points = (end_points).astype(int)

i, j = np.indices(end_points[1].shape)

# Error one step before start index 1 per (trial, batch), flattened trial-major.
errors_end_point = np.hstack(errors_list[i, j, end_points[1]])[:, np.newaxis]
print(errors_end_point.shape)
```

``` ipython
# Delay of the previous trial for every (trial, batch) entry, in the same
# trial-major order as delay_list. -1 infers the batch size instead of the
# previously hard-coded 768.
prev_delay = delay_list.reshape(N_TRIALS, -1).astype(float)
prev_delay = np.roll(prev_delay, shift=1, axis=0)
# np.roll wraps the LAST trial's delay onto trial 0, whose real predecessor (the
# warm-up batch) was never recorded; mark it as missing so the prev-delay
# analysis (data.prev_delay == delay) excludes it.
prev_delay[0] = np.nan
prev_delay = np.hstack(prev_delay)

data = pd.DataFrame({'target_loc': targ_list[:, -1], 'rel_loc': rel_list[:, -1], 'errors': errors_end_point[:, -1], 'delay': delay_list, 'prev_delay': prev_delay})
print(data.head())
```

``` ipython
# Which delay values (seconds) occur, and how often.
unique_delay = np.unique(delay_list)
print(len(delay_list), unique_delay)
plt.hist(delay_list)
plt.show()
```

``` ipython
# Serial bias as a function of the CURRENT trial's delay, for each group.
# Left: folded bias curve per delay (more opaque = longer delay). Right: mean of
# the folded curve over |rel_loc| in [0, 90] per delay, ± its standard error
# across those bins.
# NOTE(review): with IF_LOAD the pickles loaded below are never used — `data` is
# still the DataFrame of the current run, so every group plots the same curves.
# `data` would have to be rebuilt per group (which also needs per-trial start
# indices, not saved by the IF_SAVE cell).
# NOTE(review): the loop rebinds the global `name`; `colors` is unused.
n_bins = 16

cmap = plt.get_cmap('Blues')

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

for i_name, name in enumerate(groups):
    serial_list = []

    if IF_LOAD:
        name = name + '_delay_rand'

        phi_list = pkl_load('phi_list_%s' % name, './models/nih/')
        curr_list = pkl_load('curr_list_%s' % name, './models/nih/')
        prev_list = pkl_load('prev_list_%s' % name, './models/nih/')
        start_idx = pkl_load('start_idx_%s' % name, './models/nih/')
        delay_list = pkl_load('delay_list_%s' % name, './models/nih/')

    serial_list = []
    unique_delay = np.unique(delay_list)
    print(unique_delay, len(delay_list))
    n_unique = len(unique_delay)
    colors = [cmap((i+1)/ n_unique) for i in range(n_unique+1)]

    for i, delay in enumerate(unique_delay):
        centers, bin_rel, centers_abs, bin_rel_abs = get_correct_error(n_bins, data[data.delay==delay])

        # Keep the shortest-delay curve for the DoG fit below.
        if i==0:
            centers_0 = centers
            bin_rel_0 = bin_rel

        ax[0].plot(centers_abs, bin_rel_abs['mean'], color=pal[i_name], alpha=(i+1) / len(unique_delay))

        loc = (centers_abs>=0) & (centers_abs<=90)
        serial_max = bin_rel_abs['mean'][loc].mean()
        serial_std = bin_rel_abs['mean'][loc].std(ddof=1) / np.sqrt(len(bin_rel_abs['mean'][loc]))

        serial_list.append([serial_max, serial_std])

    serial_list = np.array(serial_list).T

    # NOTE(review): equals unique_delay only if the delays are evenly spaced.
    xdelay = np.linspace(unique_delay[0], unique_delay[-1], serial_list.shape[1])

    ax[1].plot(xdelay, serial_list[0], '-', color=pal[i_name])
    ax[1].fill_between(xdelay, serial_list[0] - serial_list[1], serial_list[0] + serial_list[1], alpha=0.2, color=pal[i_name])

ax[0].set_xlabel('Rel. Loc. (°)')
ax[0].set_ylabel('Error (°)')
ax[0].axhline(0, ls="--", color='k')
ax[0].set_xticks(np.linspace(0, 180, 5))

ax[1].set_xlabel('Delay (s)')
ax[1].set_ylabel('Serial Bias (°)')
ax[1].axhline(0, ls='--', color='k')
ax[1].set_xticks(model.DELAY_LIST)

if IF_LOAD:
    plt.savefig('./figures/NIH/delay/sb_delay_all.svg')
else:
    plt.savefig('./figures/NIH/delay/sb_delay_%s.svg' % name)

plt.show()
```

``` ipython
import numpy as np
from scipy.optimize import curve_fit

def dog(x, A, mu, sigma, C):
    """Derivative-of-Gaussian (NumPy).

    Returns ``A * (-(x - mu) / sigma**2) * exp(-(x - mu)**2 / (2 * sigma**2)) + C``.
    """
    shift = x - mu
    return A * (-shift / sigma**2) * np.exp(-shift**2 / (2 * sigma**2)) + C

# Binned serial bias of the shortest delay (from the delay-dependency cell).
x = centers_0.values
y = bin_rel_0['mean'].values
sem = bin_rel_0['sem'].values

# Empty bins give NaN means: max/min would propagate NaN into p0, and curve_fit
# rejects non-finite data, so fit only the finite bins. Without NaNs this is
# identical to fitting all bins.
finite = np.isfinite(y)
p0 = [np.nanmax(y) - np.nanmin(y), 0, 40, np.nanmean(y)]
if IF_SAVE_POPT:
    print('save')
    popt, _ = curve_fit(dog, x[finite], y[finite], p0)
    pkl_save(popt, 'popt_%s' % name, './models/nih/rand_delay')

```

``` ipython
import matplotlib.pyplot as plt
import numpy as np

# 1. Binned serial bias (x, y ± sem) from the fitting cell, overlaid with the DoG
#    defined by popt (fitted above when IF_SAVE_POPT, otherwise loaded from
#    ./models/nih/rand_delay).
plt.errorbar(x, y, yerr=sem, fmt='o', label='Data', capsize=3)

# 2. Dense x values for smooth fit curve
xfit = np.linspace(np.min(x), np.max(x), 200)

# 3. The DoG function (same as used for fitting).
# NOTE(review): identical redefinition of `dog` from the fitting cell.
def dog(x, A, mu, sigma, C):
    return A * (-(x - mu)/sigma**2) * np.exp(-(x - mu)**2/(2*sigma**2)) + C

# 4. Plot the fit
plt.plot(xfit, dog(xfit, *popt), 'r-', label='DoG fit', linewidth=2)

# NOTE(review): labels are set but plt.legend() is never called.
plt.xlabel('x')
plt.ylabel('Serial bias')

plt.show()
```

``` ipython
# Serial bias as a function of the PREVIOUS trial's delay, for each group
# (same layout as the current-delay figure above).
# NOTE(review): as in that cell, with IF_LOAD the loaded pickles are never used
# because `data` is not rebuilt per group; the loop also rebinds the global
# `name`, and `colors` is unused.
n_bins = 8

cmap = plt.get_cmap('Blues')

fig, ax = plt.subplots(1, 2, figsize=[2*width, height])

for i_name, name in enumerate(groups):
    serial_list = []

    if IF_LOAD:
        name = name + '_delay_rand'

        phi_list = pkl_load('phi_list_%s' % name, './models/nih/')
        curr_list = pkl_load('curr_list_%s' % name, './models/nih/')
        prev_list = pkl_load('prev_list_%s' % name, './models/nih/')
        start_idx = pkl_load('start_idx_%s' % name, './models/nih/')
        delay_list = pkl_load('delay_list_%s' % name, './models/nih/')

    serial_list = []
    unique_delay = np.unique(delay_list)
    print(unique_delay, len(delay_list))
    n_unique = len(unique_delay)
    colors = [cmap((i+1)/ n_unique) for i in range(n_unique+1)]

    for i, delay in enumerate(unique_delay):
        # idx = np.where(delay_list == delay)[0]

        # Trials whose PREVIOUS trial had this delay.
        centers, bin_rel, centers_abs, bin_rel_abs = get_correct_error(n_bins, data[data.prev_delay==delay])

        ax[0].plot(centers_abs, bin_rel_abs['mean'], color=pal[i_name], alpha=(i+1) / len(unique_delay))

        # Mean folded bias over |rel_loc| in [0, 90] and its SEM across bins.
        loc = (centers_abs>=0) & (centers_abs<=90)
        serial_max = bin_rel_abs['mean'][loc].mean()
        serial_std = bin_rel_abs['mean'][loc].std(ddof=1) / np.sqrt(len(bin_rel_abs['mean'][loc]))

        serial_list.append([serial_max, serial_std])

    serial_list = np.array(serial_list).T

    # NOTE(review): equals unique_delay only if the delays are evenly spaced.
    xdelay = np.linspace(unique_delay[0], unique_delay[-1], serial_list.shape[1])

    ax[1].plot(xdelay, serial_list[0], '-', color=pal[i_name])
    ax[1].fill_between(xdelay, serial_list[0] - serial_list[1], serial_list[0] + serial_list[1], alpha=0.2, color=pal[i_name])

ax[0].set_xlabel('Rel. Loc. (°)')
ax[0].set_ylabel('Error (°)')
ax[0].axhline(0, ls="--", color='k')
ax[0].set_xticks(np.linspace(0, 180, 5))

ax[1].set_xlabel('Prev. Delay (s)')
ax[1].set_ylabel('Serial Bias (°)')
ax[1].axhline(0, ls='--', color='k')
ax[1].set_xticks(model.DELAY_LIST)

if IF_LOAD:
    plt.savefig('./figures/NIH/delay/sb_prev_delay_all.svg')
else:
    plt.savefig('./figures/NIH/delay/sb_prev_delay_%s.svg' % name)

plt.show()
```

``` ipython
```