<a href="https://colab.research.google.com/github/motorlearner/neuromatch/blob/main/plot_rawdata.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Start from a clean working directory: remove the helper script (the glob also
# matches suffixed copies such as getdata.py.1), the sample_data directory and
# the data CSV, so the steps below always run against fresh files.
# NOTE(review): sample_data is presumably Colab's default demo folder, and the
# CSV is presumably written by getdata.py -- confirm.
!rm -rf getdata.py* sample_data data01_direction4priors.csv
# Download the helper script from the project repository (-q: no progress output).
!wget -q https://raw.githubusercontent.com/motorlearner/neuromatch/refs/heads/main/getdata.py
# Execute the script; names it defines become available to the following cells.
# NOTE(review): the plotting cell uses `df` and `colormap` without defining
# them, so presumably this script provides both -- confirm.
%run getdata.py

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

# No-op self-assignment: `df` is not created in this cell, so it must already
# exist in the notebook namespace (presumably defined by getdata.py in the
# previous cell -- confirm). Raises NameError if that cell has not been run.
df = df

# plot function
def plot_rawdata(subject_id:int):
  """Plot one subject's raw responses as a grid of scatter panels.

  The grid has one row per run and one column per stimulus coherence. Each
  panel plots ``resp_rel`` against ``stim_rel`` for that run/coherence pair
  over a dashed identity line, coloured via ``colormap`` by the ``prior_sd``
  of the panel's trials, and shows the trial count in its lower-right corner.
  Run, session and subject labels are written to the right of the last column.

  Reads the module-level ``df`` (columns used: subject_id, session_id, run_id,
  stim_coh, prior_sd, stim_rel, resp_rel) and ``colormap`` (indexed by a
  prior_sd value); neither is defined in this cell.

  Args:
    subject_id: Value of ``df.subject_id`` selecting the rows to plot.

  Raises:
    ValueError: If ``df`` contains no rows for ``subject_id``.
  """
  dat = df[df.subject_id == subject_id]
  if dat.empty:
    # Fail with a clear message instead of an opaque error from plt.subplots(0, 0).
    raise ValueError(f"no data for subject_id={subject_id!r}")

  # sets
  coh_set = np.sort(dat['stim_coh'].unique())
  run_ids = np.sort(dat['run_id'].unique())
  n_rows, n_cols = len(run_ids), len(coh_set)

  # init plot; squeeze=False keeps `axs` two-dimensional even when the subject
  # has a single run or a single coherence, so axs[i, j] is always valid
  fig, axs = plt.subplots(
    n_rows, n_cols,
    figsize=(1.5*n_cols, 1.5*n_rows),
    squeeze=False,
  )
  ticks = np.linspace(-180, 180, 5)

  # plot
  for i, run_id in enumerate(run_ids):
    for j, coh in enumerate(coh_set):
      ax = axs[i, j]
      # axes
      ax.set_xlim(-180, 180)
      ax.set_ylim(-180, 180)
      ax.set_aspect('equal', adjustable='box')
      ax.set_xticks(ticks)
      ax.set_yticks(ticks)
      ax.tick_params(axis='both', labelsize=8)
      # title: coherence header on the top row only
      if i == 0:
        ax.set_title(f"Coh {coh:.0%}", fontsize=9)
      # xlabels: tick labels on the bottom row only
      if i == n_rows-1:
        ax.tick_params(axis='x', labelrotation=45)
      else:
        ax.set_xticklabels([])
      # ylabels: tick labels on the first column only
      if j == 0:
        ax.tick_params('y', labelrotation=0)
      else:
        ax.set_yticklabels([])
      # annotations, placed to the right of the last column
      if j == n_cols-1:
        # run id
        ax.annotate(
          f"Run {run_id}", xy=(1.1, 0.5), xycoords='axes fraction', va='center', ha='left',
          rotation=0, fontsize=8, fontweight='normal'
        )
        # subject id, once, above the first row
        if i == 0:
          ax.annotate(
            f"Subject {subject_id}",
            xy=(1.1, 1.3), xycoords='axes fraction', va='bottom', ha='left',
            rotation=0, fontsize=9, fontweight='bold'
          )
      # data this run and coherence
      thisdat = dat[(dat['run_id'] == run_id) & (dat['stim_coh'] == coh)]
      # plot data
      ax.axline((0, 0), slope=1, linestyle='--', color='gray', linewidth=0.2)
      # A run need not contain every coherence: leave such a panel empty
      # (it still gets "n=0") instead of failing on the prior_sd lookup.
      if not thisdat.empty:
        # The first prior_sd value picks the colour.
        # NOTE(review): assumes prior_sd is constant within a run -- confirm.
        prior_sd = thisdat['prior_sd'].iloc[0]
        ax.scatter(
          thisdat.stim_rel, thisdat.resp_rel,
          s=2, alpha=0.5, color=colormap[prior_sd]
        )
      ax.text(
        0.98, 0.02, f"n={len(thisdat)}", transform=ax.transAxes,
        ha='right', va='bottom', fontsize=8, color='black'
      )

  # session labels: one per session, on the row of that session's lowest run id
  run_id_to_row = {run: i for i, run in enumerate(run_ids)}
  session_to_firstrun = dat.groupby('session_id')['run_id'].min().to_dict()

  for session_id, first_run_id in session_to_firstrun.items():
    # last column of that run's row: the label is drawn to the right of it
    ax = axs[run_id_to_row[first_run_id], -1]
    ax.annotate(
      f"Session {session_id}", xy=(1.1, 1), xycoords='axes fraction', va='top', ha='left',
      rotation=0, fontsize=9, fontweight='bold'
    )

  fig.subplots_adjust(wspace=0.1, hspace=0.15)

# interactive subject picker
# One dropdown entry per subject present in df, in ascending order; the first
# subject is preselected.
subject_ids = sorted(df['subject_id'].unique())

dropdown = widgets.Dropdown(
    description='Subject: ',
    options=subject_ids,
    value=subject_ids[0],
)

# Row holding the dropdown, with a 20px bottom margin between it and the plot.
controls = widgets.HBox(
    [dropdown],
    layout=widgets.Layout(margin="0 0 20px 0"),
)
# Output area that renders plot_rawdata for the subject chosen in the dropdown.
plotoutput = widgets.interactive_output(plot_rawdata, {'subject_id': dropdown})
# Stack the controls above the plot and show the result.
container = widgets.VBox([controls, plotoutput])
display(container)