<a href="https://colab.research.google.com/github/motorlearner/neuromatch/blob/main/laquitaine_wip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
# set true to show widgets
show_widgets = False

In [2]:
# imports
import requests
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

In [3]:
# @title Read Data
# --------------------------------------------------------------------------- #


# fetch data
url = "https://github.com/steevelaquitaine/projInference/raw/gh-pages/data/csv/data01_direction4priors.csv"
try:
  RequestAPI = requests.get(url)
except requests.ConnectionError:
  print("Failed to download data. Please contact steeve.laquitaine@epfl.ch")
else:
  if RequestAPI.status_code != requests.codes.ok:
    print("Failed to download data. Please contact steeve.laquitaine@epfl.ch")
  else:
    with open("data01_direction4priors.csv", "wb") as fid:
      fid.write(RequestAPI.content)

# read data
data = pd.read_csv("data01_direction4priors.csv")

In [4]:
# @title Helpers: Process Data
# --------------------------------------------------------------------------- #


# COORDINATE TRANSFORMS ----------------------------------------------------- #

def cart2pol(x:np.array, y:np.array):
  """
  Convert cartesian `(x,y)` to polar `(deg,mag)`.

  Args:
      `x`: scalar or array of x-coordinates
      `y`: scalar or array of y-coordinates

  Returns:
      Tuple of `(deg,mag)`, where `deg` is in the interval
      [0,360) counterclockwise from the positive x axis.
  """
  # compute
  deg = (np.degrees(np.arctan2(y,x)) + 360) % 360
  mag = np.hypot(x,y)
  return deg,mag

def pol2cart(deg: np.array, mag: np.array):
  """
  Convert polar `(deg,mag)` to cartesian `(x,y)`.

  Args:
      `deg`: scalar or array of angles in degrees [0,360)
      `mag`: scalar or array of magnitudes

  Returns:
      Tuple of `(x, y)` coordinates.
  """
  # compute
  rad = np.radians(deg)
  x = mag * np.cos(rad)
  y = mag * np.sin(rad)
  return x, y


# CIRCULAR DISTANCE --------------------------------------------------------- #

def circdiff(angle:np.array, reference:np.array):
  """
  Compute the signed minimal circular distance from `angle` to `reference`.

  Args:
      `angle`: scalar or array of angles in degrees [0, 360)
      `reference`: scalar reference angle in degrees [0, 360)

  Returns:
      Signed circular distance in degrees (-180, 180].
  """
  return ((angle - reference + 180) % 360) - 180


# PROCESS DATA WRAPPER ------------------------------------------------------ #

def process_data(data:pd.DataFrame):
  """
  Take the orginal `data`, rename existing columns with clearer
  and shorter names, and add additional columns.

  Args:
    `data`: original dataframe

  Returns:
    Dataframe with the following columns:
      - `subject_id` (int identifying subject)
      - `session_id` (int identifying session within a subject)
      - `run_id` (int identifying the run or block with a session)
      - `trial_id` (int identifying trial within run)
      - `trial_time` (start time of trial, first trial of run starts at 0)
      - `prior_mean` (prior mean in deg, always 225)
      - `prior_sd` (prior sd in deg, one of 10,20,40,80)
      - `stim_deg` (stimulus orientation in deg, one of 5,15,25,...355)
      - `stim_rel` (stimulus orientation relative to prior mean, from -180 to 180)
      - `stim_coh` (stimulus coherence, one of 6,12,24)
      - `init_deg` (initiation angle for response)
      - `rt` (reaction time)
      - `resp_x, resp_y`(cartesian response coords: x and y)
      - `resp_deg, resp_mag` (polar response coords: degrees and magnitude)
      - `resp_rel` (response degrees relative to prior mean, from -180 to 180)
      - `err` (response error in degrees)
      - `err_prev` (previous trial response error within a given run)
      - `err_prior` (same as `err` but sign is positive if in direction of prior mean)
      - `err_priornorm` (same as `err_prior` but normalized to distance of stimulus to prior mean)
  """
  # columns to discard
  cols_remove = ['experiment_id', 'experiment_name', 'raw_response_time']
  # columns to rename
  cols_rename = {
    'subject_id'                  : 'subject_id',
    'session_id'                  : 'session_id',
    'run_id'                      : 'run_id',
    'trial_index'                 : 'trial_id',
    'trial_time'                  : 'trial_time',
    'prior_mean'                  : 'prior_mean',
    'prior_std'                   : 'prior_sd',
    'motion_direction'            : 'stim_deg',
    'motion_coherence'            : 'stim_coh',
    'response_arrow_start_angle'  : 'init_deg',
    'reaction_time'               : 'rt',
    'estimate_x'                  : 'resp_x',
    'estimate_y'                  : 'resp_y'
  }
  # final column order
  cols_final = [
    'subject_id', 'session_id', 'run_id', 'trial_id', 'trial_time',
    'prior_mean', 'prior_sd', 'stim_deg', 'stim_rel', 'stim_coh',
    'init_deg', 'rt', 'resp_x', 'resp_y', 'resp_deg', 'resp_rel',
    'err', 'err_prev', 'err_prior', 'err_priornorm'
  ]
  # create deep copy
  df = data.copy()
  # drop or rename columns
  df.drop(cols_remove, axis=1, inplace=True)
  df.rename(columns=cols_rename, inplace=True)
  # add new columns
  df['stim_rel'] = circdiff(df.stim_deg, df.prior_mean)
  df['resp_deg'], df['resp_mag'] = cart2pol(df.resp_x, df.resp_y)
  df['resp_rel'] = circdiff(df.resp_deg, df.prior_mean)
  df['err'] = circdiff(df.resp_deg, df.stim_deg)
  df['err_prev'] = df.groupby(['subject_id', 'run_id'])['err'].shift(1)
  df['err_prior'] = np.where(df.stim_rel * df.err < 0, np.abs(df.err), -np.abs(df.err))
  df['err_priornorm'] = np.where(df.stim_rel != 0, df.err_prior / np.abs(df.stim_rel), np.nan)
  # reorder columns
  return df[cols_final]


In [5]:
# @title Process Data
# --------------------------------------------------------------------------- #
df = process_data(data)
# print info
df.head()

Unnamed: 0,subject_id,session_id,run_id,trial_id,trial_time,prior_mean,prior_sd,stim_deg,stim_rel,stim_coh,init_deg,rt,resp_x,resp_y,resp_deg,resp_rel,err,err_prev,err_prior,err_priornorm
0,1,1,1,1,0.0,225,10,225,0,0.12,,,-1.749685,-1.785666,225.583113,0.583113,0.583113,,-0.583113,
1,1,1,1,2,2.73073,225,10,225,0,0.12,,,-1.819693,-1.714269,223.291282,-1.708718,-1.708718,0.583113,-1.708718,
2,1,1,1,3,4.91395,225,10,235,10,0.06,,,-1.562674,-1.951422,231.312691,6.312691,-3.687309,-1.708718,3.687309,0.368731
3,1,1,1,4,6.997296,225,10,225,0,0.06,,,-1.601388,-1.919781,230.166776,5.166776,5.166776,-3.687309,-5.166776,
4,1,1,1,5,9.09713,225,10,215,-10,0.24,,,-1.639461,-1.887371,229.02086,4.02086,14.02086,5.166776,14.02086,1.402086


In [6]:
# @title Helpers: Plots
# --------------------------------------------------------------------------- #


# PARAMS / SETTINGS --------------------------------------------------------- #

# color for each prior_sd
colormap = {
    80: [0.5, 0, 0],
    40: [1, 0.2, 0],
    20: [1, 0.6, 0],
    10: [0.75, 0.75, 0]
}


# FUNCTIONS ----------------------------------------------------------------- #

def show_color(color):
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=color))
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
    plt.show()

In [28]:
# @title Plot: Raw Data (per Subject)
def plot_rawdata(subject_id:int):
  dat = df[df.subject_id == subject_id]

  session_ids  = np.sort(dat.session_id.unique())
  run_ids      = np.sort(dat.run_id.unique())
  session_runs = dat.groupby('session_id')['run_id'].nunique().max()

  coh_set      = np.sort(dat['stim_coh'].unique())
  priorsd_set  = np.sort(dat['prior_sd'].unique())


  fig, axs = plt.subplots(len(run_ids), len(coh_set), figsize=(1.5*len(coh_set), 1.5*len(run_ids)))

  # if axs is 1D (i.e. only 1 run or coherence), make it 2D for uniform indexing
  if len(run_ids) == 1 and len(coh_set) == 1:
    axs = np.array([[axs]])
  elif len(run_ids) == 1:
    axs = axs[np.newaxis, :]
  elif len(coh_set) == 1:
    axs = axs[:, np.newaxis]

  for i, run_id in enumerate(run_ids):
    for j, coh in enumerate(coh_set):
      ax = axs[i, j]
      ax.set_xlim(-180, 180)
      ax.set_ylim(-180, 180)
      ax.set_aspect('equal', adjustable='box')
      ax.set_title(f"Coh {coh_set[j]:.0%}", fontsize=9)
      ax.set_xticks(np.linspace(-180,180,5))
      ax.set_yticks(np.linspace(-180,180,5))
      ax.tick_params(axis='both', labelsize=8)

      # Hide x tick labels except for bottom row
      if i < len(run_ids) - 1:
        ax.set_xticklabels([])
      else:
        ax.tick_params(axis='x', labelrotation=45)

      # Hide y tick labels except for leftmost column
      if j > 0:
          ax.set_yticklabels([])
      else:
        ax.annotate(
          f"Run {run_id}",
          xy=(-0.3, 0.5),
          xycoords='axes fraction',
          rotation=90,
          va='center',
          ha='right',
          fontsize=8,
          fontweight='normal'
        )

      if i==0 && j==1:
        ax.annotate(
          f"Subject {subject_id}",
          xy=(0.5, 1.3),
          xycoords='axes fraction',
          rotation=0,
          va='bottom',
          ha='center',
          fontsize=8,
          fontweight='bold'
        )

      # Filter data for this run and coherence
      subset = dat[(dat['run_id'] == run_id) & (dat['stim_coh'] == coh)]
      prior_sd = subset.prior_sd.unique()[0]

      # Scatter plot stimulus_direction vs response_deg
      ax.axline((0, 0), slope=1, linestyle='--', color='gray', linewidth=0.2)
      ax.scatter(
          subset.stim_rel, subset.resp_rel,
          s=7, alpha=0.5, color=colormap[prior_sd]
      )
      ax.text(
        0.98, 0.02, f"n={len(subset)}", transform=ax.transAxes,
        ha='right', va='bottom', fontsize=8, color='black'
      )


  run_id_to_row = {run: i for i, run in enumerate(run_ids)}
  session_to_firstrun = dat.groupby('session_id')['run_id'].min().to_dict()

  for session_id in session_ids:
    first_run_id = session_to_firstrun[session_id]
    row_idx = np.flatnonzero(run_ids == first_run_id)[0]
    ax = axs[row_idx, 0]  # first column of that run's row

    # Add session label on left of that subplot
    ax.annotate(
        f"Session {session_id}",
        xy=(-0.5, 0.5),
        xycoords='axes fraction',
        rotation=90,
        va='center',
        ha='right',
        fontsize=10,
        fontweight='bold'
    )

  # plt.tight_layout()
  # fig.tight_layout(pad=0.5)
  fig.subplots_adjust(hspace=0.2, wspace=0.5)



subject_ids = sorted(df['subject_id'].unique())
dropdown = widgets.Dropdown(
    options=subject_ids,
    value=subject_ids[0],
    description='Subject: '
)

if show_widgets:
  plotoutput = widgets.interactive_output(plot_rawdata, {'subject_id': dropdown})
  container = widgets.VBox([dropdown, plotoutput])
  display(container)

VBox(children=(Dropdown(description='Subject: ', options=(np.int64(1), np.int64(2), np.int64(3), np.int64(4), …

In [8]:
# @title Plot: Error over Trials (per Subject)
def plot_err(subject_id:int, prior_sds:list=[10,20,40,80]):
  dat = df[(df.subject_id == subject_id) & df.prior_sd.isin(prior_sds)]
  dat.sort_values(by=['session_id', 'run_id', 'trial_id']).reset_index(drop=True)

  # max errors
  max_err = df[(df.subject_id == subject_id)].err.abs().max()

  # unique values
  prior_sd_set = sorted(dat.prior_sd.unique())
  stim_coh_set = sorted(dat.stim_coh.unique())

  fig, axes = plt.subplots(
    len(stim_coh_set), 1,
    figsize=(12, 6),
    sharex=True, sharey=True
  )

  for i, stim_coh in enumerate(stim_coh_set):
    ax = axes[i]
    # prep data
    thisdat = dat[(dat.stim_coh == stim_coh)].reset_index(drop=True)
    xall = np.arange(len(thisdat.err)) + 1
    yall = thisdat.err
    # plot data
    for j, prior_sd in enumerate(prior_sd_set):
      mask = thisdat.prior_sd==prior_sd
      x = np.where(mask, xall, np.nan)
      y = np.where(mask, yall, np.nan)
      y_avg = pd.Series(y).rolling(window=12, center=True).mean()
      ax.plot(x, y, ls='-', lw=0.5, alpha=0.5, color=colormap[prior_sd])
      ax.plot(x, y_avg, ls='-', lw=3.0, alpha=0.8, color='white')
      ax.plot(x, y_avg, ls='-', lw=1.0, alpha=1.0, color=colormap[prior_sd])
    # plot text
    ax.set_title(f'Coherence={stim_coh:.0%}', fontsize=10)
    # plot axes
    ax.set_xlim(1, max([len(dat[dat.stim_coh==x]) for x in stim_coh_set]))
    ax.set_ylim(-max_err, max_err)
    ax.set_xlabel('Trial #' if i==len(stim_coh_set)-1 else '')
    ax.set_ylabel('Error [deg]' if i==1 else '')

# interactive
subject_ids = sorted(df['subject_id'].unique())
prior_sds = sorted(df['prior_sd'].unique())

dropdown = widgets.Dropdown(
  options=subject_ids,
  value=subject_ids[0],
  description='Subject: '
)
checklist = widgets.SelectMultiple(
  options=prior_sds,
  value=tuple(prior_sds),  # all selected by default
  description='Prior SDs:',
  layout={'height': '75px'}
)

if show_widgets:
  plotoutput = widgets.interactive_output(plot_err, {'subject_id': dropdown, 'prior_sds': checklist})
  controls   = widgets.HBox([dropdown, checklist])
  container  = widgets.VBox([controls, plotoutput])
  display(container)