<a href="https://colab.research.google.com/github/motorlearner/neuromatch/blob/main/plot_errorseries.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Start from a clean working directory: delete any earlier copy of the loader
# script (the glob also catches wget's numbered duplicates such as
# getdata.py.1), Colab's default sample_data folder, and a previously
# downloaded data01_direction4priors.csv (presumably fetched by getdata.py —
# confirm).
!rm -rf getdata.py* sample_data data01_direction4priors.csv
# Quietly (-q) download the data-loading script from the project's repository.
!wget -q https://raw.githubusercontent.com/motorlearner/neuromatch/refs/heads/main/getdata.py
# Execute the script; IPython's %run copies the globals it defines into the
# notebook namespace. Later cells use `df` and `colormap` without defining
# them, so presumably both come from here — confirm.
%run getdata.py

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

# `df` is not created in this cell. It is expected to come from getdata.py,
# which the previous cell executes with %run. This self-assignment is a no-op
# that makes the dependency explicit, and it fails fast with NameError if the
# previous cell has not been run.
df = df

# plot function
def plot_err(subject_id: int, prior_sds: tuple = (10, 20, 40, 80), towards_prior: bool = False):
  """Plot one subject's trial-by-trial errors, one panel per stimulus coherence.

  Within each panel, trials are numbered 1..N in session/run/trial order. Each
  selected prior SD gets two traces: a thin line of raw errors and a thick
  centred 12-trial rolling mean. Trials belonging to other prior SDs are set to
  NaN, so each prior's line breaks wherever the other priors' trials are.

  Args:
    subject_id: value of `df.subject_id` to plot.
    prior_sds: prior SD values (`df.prior_sd`) to include. Any collection
      accepted by `Series.isin` works; the widget passes a tuple.
    towards_prior: if True, plot column `err_toprior` instead of `err`. By its
      name this is the error signed relative to the prior mean; that is not
      verifiable from this cell.

  Relies on the module-level `df` and on `colormap` (indexed by prior SD to get
  a matplotlib colour). Neither is defined in this cell; presumably both come
  from getdata.py — confirm.
  """
  dat = df[(df.subject_id == subject_id) & df.prior_sd.isin(prior_sds)]
  # sort_values returns a new frame, so the result must be assigned for the
  # trials to be plotted in chronological (session/run/trial) order.
  dat = dat.sort_values(by=['session_id', 'run_id', 'trial_id']).reset_index(drop=True)

  if dat.empty:
    # E.g. every prior SD deselected in the widget; plt.subplots(0, 1) would raise.
    print('No trials match this selection.')
    return

  # Symmetric y-limit shared by all panels and all prior selections: the
  # largest |err| over every trial of this subject. NOTE(review): this uses
  # `err` even when `err_toprior` is plotted; that assumes both columns have the
  # same magnitudes — confirm.
  max_err = df[df.subject_id == subject_id].err.abs().max()

  prior_sd_set = sorted(dat.prior_sd.unique())
  stim_coh_set = sorted(dat.stim_coh.unique())
  n_rows = len(stim_coh_set)
  # Common x-range for every panel: the trial count of the largest coherence group.
  n_trials_max = dat.stim_coh.value_counts().max()

  # squeeze=False keeps `axes` 2-D even when only one coherence level is
  # present; otherwise a single Axes object is returned and cannot be indexed.
  fig, axes = plt.subplots(
    n_rows, 1,
    figsize=(12, 6),
    sharex=True, sharey=True,
    squeeze=False,
  )

  for i, stim_coh in enumerate(stim_coh_set):
    ax = axes[i, 0]
    # prep data: trials of this coherence, renumbered 1..N in sorted order
    thisdat = dat[dat.stim_coh == stim_coh].reset_index(drop=True)
    xall = np.arange(len(thisdat)) + 1
    yall = thisdat.err_toprior if towards_prior else thisdat.err
    # plot data
    for prior_sd in prior_sd_set:
      # NaN-out other priors' trials so this prior's line breaks across them.
      mask = thisdat.prior_sd == prior_sd
      x = np.where(mask, xall, np.nan)
      y = np.where(mask, yall, np.nan)
      # Centred 12-trial mean. With the default min_periods (= window), any
      # window that touches a NaN (another prior's trial) yields NaN.
      y_avg = pd.Series(y).rolling(window=12, center=True).mean()
      ax.plot(x, y, ls='-', lw=0.5, alpha=0.5, color=colormap[prior_sd])
      # Wide white underlay so the coloured mean line stands out from the raw trace.
      ax.plot(x, y_avg, ls='-', lw=3.0, alpha=0.8, color='white')
      ax.plot(x, y_avg, ls='-', lw=1.0, alpha=1.0, color=colormap[prior_sd])
    # Zero-error reference, drawn once per panel on top of the data lines.
    ax.axhline(y=0, color='white', ls='--', lw=1)
    # plot text
    ax.set_title(f'Coherence={stim_coh:.0%}', fontsize=10)
    # plot axes
    ax.set_xlim(1, n_trials_max)
    ax.set_ylim(-max_err, max_err)
    ax.set_xlabel('Trial #' if i == n_rows - 1 else '')
    # Label only the middle panel (row 1 with three panels, as before).
    ax.set_ylabel('Error [deg]' if i == n_rows // 2 else '')

# interactive controls
# Choices offered by the widgets: every subject and every prior SD in `df`.
subject_ids, prior_sds = (sorted(df[column].unique()) for column in ('subject_id', 'prior_sd'))

# Single-subject picker, preselecting the smallest subject id.
dropdown = widgets.Dropdown(
  options=subject_ids,
  value=subject_ids[0],
  description='Subject: ',
)

# Multi-select of prior SDs, with every prior selected initially.
all_priors_selected = tuple(prior_sds)
checklist = widgets.SelectMultiple(
  options=prior_sds,
  value=all_priors_selected,
  description='Prior SDs:',
)

# Unchecked by default; its value is forwarded to plot_err as `towards_prior`.
tickbox = widgets.Checkbox(
  value=False,
  description='Error Positive if towards Prior Mean',
  indent=False,
)

# Re-render plot_err whenever any control changes; keys are plot_err's parameter names.
control_map = {'subject_id': dropdown, 'prior_sds': checklist, 'towards_prior': tickbox}
plotoutput = widgets.interactive_output(plot_err, control_map)

# One row of controls stacked above the plot area.
controls = widgets.HBox([dropdown, checklist, tickbox])
container = widgets.VBox([controls, plotoutput])

# Sizing and spacing: each control's margins, plus a gap below the control row.
dropdown.layout = widgets.Layout(margin='0 20px 0 0', width='150px')
checklist.layout = widgets.Layout(margin='0 20px 0 0', width='200px', height='77px')
tickbox.layout = widgets.Layout(margin='0 20px 0 20px')
controls.layout = widgets.Layout(margin='0 0 20px 0')

# Render the assembled UI in the notebook output.
display(container)