# Tuning curves: 4 cardinal directions, pooled across tasks

Conditions shared by all tasks: 0°, 90°, 180°, 270° (cardinal directions only).  
Epochs shared by all tasks: cue (0–500 ms) and delay (500–1700 ms).  
Neurons are pooled across the ODR 1.5s, ODR 3.0s, and ODRd tasks.

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import os, sys
sys.path.insert(0, '.')

from functions import (
    load_cardinal_task_data,
    assign_age_groups, pooled_tuning_by_group, pca_reduce_tuning,
    tuning_to_matrix, plot_3d_representation, wall_projections,
    generalized_procrustes,
    plot_3d_grid, plot_within_monkey_alignment, plot_global_alignment,
    STIM_COLORS, STIM_LABELS, AGE_COLORS, AGE_GROUP_LABELS,
)

# ── Analysis configuration ────────────────────────────────────────────────
DATA_DIR = '../data_raw'          # location of the raw data, relative to this notebook

# Age boundaries handed to pooled_tuning_by_group to split each monkey's
# neurons into age groups (units presumably months — confirm in `functions`).
AGE_EDGES = (48, 60)

# Epoch windows (start_ms, end_ms) common to every task.
COMMON_EPOCHS = dict(cue=(0, 500), delay=(500, 1700))

BIN_MS = 25                       # bin width in ms used when computing tuning
N_PCS = 4                         # principal components kept per monkey x age group
MIN_NEURONS = N_PCS + 1           # minimum group size passed to pca_reduce_tuning

## 1. Load data (cardinal directions only)

In [None]:
cardinal_data = load_cardinal_task_data(DATA_DIR)

# One summary line per task: neuron count (axis 0 of `data`), condition count
# (axis 1), and the monkeys contributing neurons. IDs are cast to plain `str`
# so that NumPy >= 2 string scalars print as 'OLI' rather than np.str_('OLI');
# str() is a no-op if the IDs are already Python strings.
for name, td in cardinal_data.items():
    n_neurons, n_conditions = td["data"].shape[:2]
    monkeys = sorted({str(m) for m in td["ids"]})
    print(f'{name}: {n_neurons} neurons, {n_conditions} conditions, '
          f'monkeys: {monkeys}')

## 2. Compute pooled tuning curves

In [3]:
# Pool tuning across all tasks for the shared epochs, split by monkey and age
# group. `grouped[monkey][age_group]` is an array whose first axis is neurons
# (printed below as the neuron count); `epoch_names` gives the epoch order.
grouped, epoch_names = pooled_tuning_by_group(
    cardinal_data, COMMON_EPOCHS, AGE_EDGES, bin_ms=BIN_MS)

print(f'Epochs: {epoch_names}')
# Cast keys to plain str: NumPy >= 2 scalars otherwise print as np.str_('OLI').
print(f'Monkeys: {[str(m) for m in grouped]}\n')
for mid, groups in grouped.items():
    for g, tc in groups.items():
        print(f'  {mid} / {AGE_GROUP_LABELS[g]}: {tc.shape[0]} neurons, '
              f'shape {tc.shape}')

  neuron 0/1180
  neuron 500/1180
  neuron 1000/1180
  neuron 0/922
  neuron 500/922
  neuron 0/1319
  neuron 500/1319
  neuron 1000/1319
Epochs: ['cue', 'delay']
Monkeys: [np.str_('OLI'), np.str_('PIC'), np.str_('QUA'), np.str_('ROS'), np.str_('SON'), np.str_('TRI'), np.str_('UNI'), np.str_('VIK')]

  OLI / young: 20 neurons, shape (20, 4, 2)
  OLI / middle: 224 neurons, shape (224, 4, 2)
  OLI / old: 334 neurons, shape (334, 4, 2)
  PIC / young: 40 neurons, shape (40, 4, 2)
  PIC / middle: 248 neurons, shape (248, 4, 2)
  PIC / old: 240 neurons, shape (240, 4, 2)
  QUA / middle: 56 neurons, shape (56, 4, 2)
  ROS / young: 84 neurons, shape (84, 4, 2)
  ROS / middle: 132 neurons, shape (132, 4, 2)
  ROS / old: 256 neurons, shape (256, 4, 2)
  SON / young: 62 neurons, shape (62, 4, 2)
  SON / middle: 194 neurons, shape (194, 4, 2)
  SON / old: 74 neurons, shape (74, 4, 2)
  TRI / young: 49 neurons, shape (49, 4, 2)
  TRI / middle: 105 neurons, shape (105, 4, 2)
  TRI / old: 103 neurons

## 3. PCA-reduce tuning curves

In [4]:
reduced = pca_reduce_tuning(grouped, n_pcs=N_PCS, min_neurons=MIN_NEURONS)

# Report each (monkey, age group) returned by pca_reduce_tuning: the neuron
# count it recorded, the shape of the reduced tuning array, and the fraction
# of variance captured by the retained PCs.
for mid, g, info in ((m, age, d) for m, per_age in reduced.items()
                     for age, d in per_age.items()):
    print(f'{mid} / {AGE_GROUP_LABELS[g]}: {info["n_neurons"]} neurons -> '
          f'{info["tc"].shape} (PCs x conds x epochs), '
          f'var explained {info["var_explained"]:.1%}')

OLI / young: 20 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 85.6%
OLI / middle: 214 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 74.4%
OLI / old: 334 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 74.7%
PIC / young: 35 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 75.1%
PIC / middle: 248 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 76.4%
PIC / old: 228 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 71.8%
QUA / middle: 56 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 75.5%
ROS / young: 82 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 76.1%
ROS / middle: 132 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 70.7%
ROS / old: 256 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 72.3%
SON / young: 62 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 79.1%
SON / middle: 194 neurons -> (4, 4, 2) (PCs x conds x epochs), var explained 70.5%
SON / old: 74 neurons 

In [None]:
# Grid of per-group 3-D representations for the cue epoch. The epoch position
# is looked up by name in `epoch_names` (returned by pooled_tuning_by_group)
# instead of assuming cue is always epoch 0.
plot_3d_grid(reduced, epoch_idx=epoch_names.index('cue'),
             title='Cue epoch representations')

## 4. Within-monkey alignment (age groups aligned per monkey)

In [None]:
# NOTE(review): N_DIMS is not passed to any call visible in this notebook
# (the per-task section uses n_dims=3 directly); kept for compatibility.
N_DIMS = 3

# Rows of the flattened tuning matrix are indexed as
# row = cond * n_epochs + epoch (condition-major, epochs interleaved), the
# layout this notebook's np.arange(epoch, n_conds * n_epochs, n_epochs)
# selections rely on. Sizes come from the data, not hard-coded constants:
# each reduced `tc` is (PCs x conds x epochs).
cond_counts = {info['tc'].shape[1]
               for groups in reduced.values() for info in groups.values()}
if len(cond_counts) != 1:
    raise ValueError(
        f'Expected a single condition count across groups, got {sorted(cond_counts)}')
n_conds = cond_counts.pop()
n_epochs = len(epoch_names)
cue_idx = np.arange(epoch_names.index('cue'), n_conds * n_epochs, n_epochs)

plot_within_monkey_alignment(reduced, cue_idx, title='Within-monkey alignment (age groups)')

## 5. Global alignment (all monkey x age-group combinations)

In [None]:
# Global alignment across every monkey x age group, restricted to the
# cue-epoch rows selected by `cue_idx`.
plot_global_alignment(
    reduced,
    cue_idx,
    title='Global alignment (all monkeys x ages)',
)

## 6. Within-monkey alignment – delay epoch

In [None]:
# Delay-epoch rows (one per condition). The epoch offset is looked up by name
# rather than hard-coded as 1; n_conds / n_epochs come from the section-4 cell.
delay_idx = np.arange(epoch_names.index('delay'), n_conds * n_epochs, n_epochs)

plot_within_monkey_alignment(reduced, delay_idx, title='Within-monkey alignment – delay epoch')

## 7. Global alignment – delay epoch

In [None]:
# Global alignment across every monkey x age group, restricted to the
# delay-epoch rows selected by `delay_idx`.
plot_global_alignment(
    reduced,
    delay_idx,
    title='Global alignment – delay epoch',
)

## 8. Per-task global alignment (cue / delay / response)

All tasks share the 4 cardinal directions.  
Because response-epoch timing differs between tasks, tuning, PCA reduction and alignment are computed separately for each task.

In [None]:
from functions import TASK_EPOCHS, plot_surface_patch, wall_surface_projections
from matplotlib.lines import Line2D

PLOT_EPOCHS = ['cue', 'delay', 'response']
EPOCH_COLORS = {'cue': '#2196F3', 'delay': '#FF9800', 'response': '#9C27B0'}
task_list = ['ODR 1.5s', 'ODR 3.0s', 'ODRd']

# ── First pass: per task, compute tuning over that task's own epoch windows,
# PCA-reduce each monkey x age group, align all groups with generalized
# Procrustes, and pick per-task axis limits from the grand mean. ──────────
# (Loop variables such as reduced_t / mid / g are intentionally left at
# module level; the inspection cell after this one reads them.)
task_results = {}

for task_name in task_list:
    all_epochs = TASK_EPOCHS[task_name]['epochs']
    epochs = {k: all_epochs[k] for k in PLOT_EPOCHS}

    grouped_t, enames = pooled_tuning_by_group(
        {task_name: cardinal_data[task_name]}, epochs, AGE_EDGES, bin_ms=BIN_MS)
    reduced_t = pca_reduce_tuning(grouped_t, n_pcs=N_PCS, min_neurons=MIN_NEURONS)

    all_mats = []
    for mid in sorted(reduced_t.keys()):
        for g in sorted(reduced_t[mid].keys()):
            # Keep 3 dimensions: the figure below is a 3-D plot.
            all_mats.append(tuning_to_matrix(reduced_t[mid][g], n_dims=3))
    if not all_mats:
        raise ValueError(f'{task_name}: no monkey/age groups left after PCA reduction')

    # Each reduced `tc` is (PCs x conds x epochs); all groups must agree on the
    # condition count for the row indexing below to be meaningful.
    cond_counts = {info['tc'].shape[1]
                   for groups in reduced_t.values() for info in groups.values()}
    if len(cond_counts) != 1:
        raise ValueError(
            f'{task_name}: inconsistent condition counts across groups: {sorted(cond_counts)}')

    aligned_all, grand_mean = generalized_procrustes(all_mats)

    # Symmetric cube half-width with headroom around the grand-mean points.
    lim = np.max(np.abs(grand_mean)) * 1.6

    task_results[task_name] = dict(aligned_all=aligned_all, grand_mean=grand_mean,
                                   enames=enames, lim=lim,
                                   n_conds=cond_counts.pop())

# ── Plot: one subplot per task, all epochs overlaid ───────────────────────
fig = plt.figure(figsize=(6 * len(task_list), 6))

for col, task_name in enumerate(task_list, start=1):
    res = task_results[task_name]
    enames = res['enames']
    lim = res['lim']
    n_cond_t = res['n_conds']
    # Row stride of the flattened matrix = number of epochs this task was
    # computed with (row = cond * n_ep + epoch).
    n_ep = len(enames)

    ax = fig.add_subplot(1, len(task_list), col, projection='3d')
    # Fix the cube before drawing so every element shares the same limits.
    ax.set_xlim(-lim, lim); ax.set_ylim(-lim, lim); ax.set_zlim(-lim, lim)

    for ename in PLOT_EPOCHS:
        ei = enames.index(ename)
        epoch_idx = np.arange(ei, n_cond_t * n_ep, n_ep)
        ec = EPOCH_COLORS[ename]
        mean_pts = res['grand_mean'][epoch_idx]

        plot_surface_patch(ax, mean_pts, color=ec, alpha=0.2)

        # Close the polygon so the last condition connects back to the first.
        loop = np.vstack([mean_pts, mean_pts[:1]])
        ax.plot(loop[:, 0], loop[:, 1], loop[:, 2], '-', color=ec,
                lw=2.5, alpha=0.6, zorder=1)
        # Black markers whose edge colour identifies the stimulus direction.
        for i, pt in enumerate(mean_pts):
            ax.scatter(pt[0], pt[1], pt[2],
                       s=120, color='k', alpha=1.0,
                       edgecolors=STIM_COLORS[i], linewidths=1.5,
                       zorder=2, clip_on=False)

        wall_projections(ax, mean_pts, color=ec, alpha=0.15)
        wall_surface_projections(ax, mean_pts, color=ec, alpha=0.10)

    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    ax.set_xlabel('PC1', fontsize=8)
    ax.set_ylabel('PC2', fontsize=8)
    ax.set_zlabel('PC3', fontsize=8)
    ax.set_title(task_name, fontsize=13)

# Combined legend: directions (dot edges) + epochs (surface/line)
dir_handles = [Line2D([0], [0], marker='o', color='w',
                      markerfacecolor='k', markeredgecolor=c,
                      markeredgewidth=1.5, markersize=10, label=l)
               for c, l in zip(STIM_COLORS, STIM_LABELS)]
epoch_handles = [Line2D([0], [0], color=EPOCH_COLORS[e], lw=3, label=e)
                 for e in PLOT_EPOCHS]
# Anchor the legend inside the canvas (the bottom strip is reserved below) so
# it is not clipped when the figure is saved without bbox_inches='tight'.
fig.legend(handles=dir_handles + epoch_handles, loc='lower center',
           ncol=len(STIM_COLORS) + len(PLOT_EPOCHS), fontsize=12,
           frameon=False, bbox_to_anchor=(0.5, 0.0))

fig.suptitle('Per-task global alignment (cue / delay / response)', fontsize=15)
# Leave room at the bottom for the legend and at the top for the suptitle.
fig.tight_layout(rect=(0, 0.07, 1, 0.95))

In [49]:
# Quick sanity check: shape of the reduced tuning array (PCs x conds x epochs)
# for the last (task, monkey, age group) visited by the previous cell's
# first-pass loop — `reduced_t`, `mid` and `g` are left over from that loop.
last_group_info = reduced_t[mid][g]
last_group_info["tc"].shape

(4, 4, 3)