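# Composite figure

This notebook loads the raw neural recordings and the behavioral tables, then runs the analyses behind panels 4–14. These include:
- per-task 3-D Procrustes alignment;
- cross-task cross-validation;
- cross-age and cross-monkey distances;
- neural-vs-behavior and cross-epoch comparisons.

All panels are drawn on a single 4 × 15 GridSpec figure. Panels 1–3 are placeholders labelled with their numbers.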

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sts
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
import sys
sys.path.insert(0, '.')

from functions import (
    load_cardinal_task_data, load_all_task_data, filter_common_monkeys,
    compute_flat_tuning, cross_task_cv,
    build_representations, procrustes_distance_matrix,
    assign_age_groups, cross_age_analysis, cross_monkey_by_group,
    build_epoch_representations, cross_epoch_distances,
    pooled_tuning_by_group, pca_reduce_tuning,
    tuning_to_matrix, generalized_procrustes,
    load_behavioral_data, behavioral_distance_matrices,
    draw_3d_alignment, draw_cross_task_bars,
    draw_cross_age_bars, draw_cross_monkey_scatter,
    draw_neural_vs_behavior, draw_cross_epoch_vs_behavior,
    draw_correlation_matrices,
    STIM_COLORS, STIM_LABELS, TASK_EPOCHS, TASK_COLORS,
)

In [None]:
# ── Paths (relative to the notebook's working directory) ─────────────────
DATA_DIR = '../data_raw'
BEH_SAC = '../original_repo/behavior/sac_data.csv'
BEH_SAC_ODRD = '../original_repo/GAM/data/sac_odrd.csv'

# ── Age grouping ──────────────────────────────────────────────────────────
# Two edges -> three groups (one label per group); units presumably months — confirm.
AGE_EDGES = (48, 60)
AGE_GROUP_LABELS = ['young', 'middle', 'old']

# ── PCA / neuron thresholds ──────────────────────────────────────────────
# Each analysis keeps N_PCS_* components; MIN_NEURONS_* is what gets passed
# as `min_neurons` to the corresponding helper.
BIN_MS = 25                       # bin width shared by most analyses
N_AGE_BINS = 3                    # passed as n_age_bins (panels 12-14)

N_PCS_3D = 4                      # panels 4-6 (3-D alignment)
MIN_NEURONS_3D = N_PCS_3D + 1

N_PCS = 5                         # panels 8-11 (cross-age / cross-monkey)
MIN_NEURONS = N_PCS + 1

BIN_MS_CT = 50                    # panel 7 uses coarser bins
N_PCS_CT = 8                      # panel 7 (cross-task)
MIN_NEURONS_CT = N_PCS_CT + 1

N_PCS_CE = 8                      # panels 12-14 (cross-epoch)
MIN_NEURONS_CE = 10               # NOTE(review): not N_PCS_CE + 1 (= 9) like the others — confirm intended

# ── Panels 4-6: tasks, epochs & colors ─────────────────────────────────
TASK_LIST = ['ODR 1.5s', 'ODR 3.0s', 'ODRd']
PLOT_EPOCHS = ['cue', 'delay', 'response']
EPOCH_COLORS = {
    'cue': '#2196F3',
    'delay': '#FF9800',
    'response': '#9C27B0',
}

# ── Panel 7: common epochs for 3-way cross-task ─────────────────────────
COMMON_EPOCHS_CT = {
    'cue': (0, 500),
    'delay': (500, 1700),
}
T_RANGES = {
    'ODR 1.5s': (-1000, 2500),
    'ODR 3.0s': (-1000, 3500),
    # NOTE(review): CROSS_EPOCH_DEFS['ODRd'] uses (-1000, 4000) — confirm the difference is intended.
    'ODRd': (-1000, 3500),
}

# ── Panels 12-14: cross-epoch definitions ───────────────────────────────
CROSS_EPOCH_DEFS = {
    'ODR 1.5s': {
        't_range': (-1000, 2500),
        'epochs': {'cue': (0, 500), 'delay': (1250, 2000), 'response': (2000, 2500)},
    },
    'ODR 3.0s': {
        't_range': (-1000, 3500),
        'epochs': {'cue': (0, 500), 'delay': (1750, 3000), 'response': (3000, 3500)},
    },
    'ODRd': {
        't_range': (-1000, 4000),
        'epochs': {'cue': (0, 500), 'delay': (2200, 3000), 'response': (3000, 3500)},
    },
}
# Epoch pairs handed to cross_epoch_distances.
COMPARISONS = [
    ('cue', 'delay'),
    ('delay', 'response'),
    ('cue', 'response'),
]

In [None]:
cardinal_data = load_cardinal_task_data(DATA_DIR)

# For each task: pool tuning by age group over PLOT_EPOCHS, reduce to N_PCS_3D
# PCs, turn every reduced_t[mid][g] entry into a 3-column matrix, and align
# them all with generalized_procrustes.
task_results = {}
for task_name in TASK_LIST:
    task_epochs = TASK_EPOCHS[task_name]['epochs']
    epochs = {ep: task_epochs[ep] for ep in PLOT_EPOCHS}

    grouped_t, enames = pooled_tuning_by_group(
        {task_name: cardinal_data[task_name]}, epochs, AGE_EDGES, bin_ms=BIN_MS)
    reduced_t = pca_reduce_tuning(grouped_t, n_pcs=N_PCS_3D, min_neurons=MIN_NEURONS_3D)

    # Deterministic order: sorted outer keys, then sorted inner (group) keys.
    all_mats = [
        tuning_to_matrix(reduced_t[mid][g], n_dims=3)
        for mid in sorted(reduced_t)
        for g in sorted(reduced_t[mid])
    ]
    aligned_all, grand_mean = generalized_procrustes(all_mats)

    task_results[task_name] = {
        'aligned_all': aligned_all,
        'grand_mean': grand_mean,
        'enames': enames,
        # Symmetric axis limit: largest |coordinate| of the grand mean plus 60% headroom.
        'lim': np.max(np.abs(grand_mean)) * 1.6,
    }

print('Panels 4-6: per-task alignment computed.')

In [None]:
tasks_ct, _ = filter_common_monkeys(cardinal_data)

# Flattened tuning on the shared cue/delay windows (COMMON_EPOCHS_CT), per task.
tuning_flat = {}
for name in tasks_ct:
    flat_tuning, _rates, _bin_centers = compute_flat_tuning(
        tasks_ct[name]['data'], T_RANGES[name], COMMON_EPOCHS_CT, BIN_MS_CT)
    tuning_flat[name] = flat_tuning

task_ids = {name: tasks_ct[name]['ids'] for name in tasks_ct}

# The two trailing positional arguments of cross_task_cv, given names here.
N_CV_REPEATS = 100  # presumably the number of CV repetitions — confirm in functions.cross_task_cv
CV_SEED = 42        # presumably the RNG seed — confirm in functions.cross_task_cv
ct_results = cross_task_cv(tuning_flat, task_ids, N_PCS_CT, MIN_NEURONS_CT, N_CV_REPEATS, CV_SEED)
print('Panel 7: cross-task CV computed.')

In [None]:
task_data = load_all_task_data(DATA_DIR)

# PSTHs and flattened tuning for every task listed in TASK_EPOCHS.
psth_data = {}
for name, cfg in TASK_EPOCHS.items():
    flat, rates, bc = compute_flat_tuning(
        task_data[name]['data'], cfg['t_range'], cfg['epochs'], BIN_MS)
    psth_data[name] = {'rates': rates, 'bc': bc, 'flat': flat}

# Age-group assignment (split at AGE_EDGES) for each task.
age_groups = {
    task_name: assign_age_groups(task_data[task_name]['abs_age'], AGE_EDGES)
    for task_name in task_data
}

# Panel 8: per-task representations, pairwise Procrustes distances, cross-age summary.
indiv_results = {}
for task_name in task_data:
    entries = build_representations(
        psth_data[task_name]['flat'],
        task_data[task_name]['ids'],
        age_groups[task_name],
        n_pcs=N_PCS, min_neurons=MIN_NEURONS, zscore=True,
    )
    dist = procrustes_distance_matrix(entries)
    indiv_results[task_name] = {
        'entries': entries,
        'dist': dist,
        'cross_age': cross_age_analysis(entries, dist),
    }

print('Panel 8: cross-age computed.')

# Panel 9: cross-monkey comparisons within each age group.
results_by_group, pooled = cross_monkey_by_group(
    task_data, psth_data, age_groups, N_PCS, MIN_NEURONS, AGE_GROUP_LABELS)

print('Panel 9: cross-monkey by group computed.')

In [None]:
beh_df = load_behavioral_data(BEH_SAC, sac_odrd_path=BEH_SAC_ODRD)

# DI / RT distance matrices (and the raw values) computed against each
# task's neural `entries` from the cross-age step.
beh_dist = {}
for task_name, task_res in indiv_results.items():
    di_dist, rt_dist, di_vals, rt_vals = behavioral_distance_matrices(
        beh_df, task_res['entries'], AGE_EDGES, task_name)
    beh_dist[task_name] = {
        'di_dist': di_dist,
        'rt_dist': rt_dist,
        'di_vals': di_vals,
        'rt_vals': rt_vals,
    }

print('Panels 10-11: behavioral distances computed.')

In [None]:
# Per-epoch representations for every task in CROSS_EPOCH_DEFS, followed by
# distances for each epoch pair listed in COMPARISONS.
epoch_reps, ce_age_groups, monkey_edges = build_epoch_representations(
    task_data,
    CROSS_EPOCH_DEFS,
    N_PCS_CE,
    MIN_NEURONS_CE,
    bin_ms=BIN_MS,
    n_age_bins=N_AGE_BINS,
)
cross_epoch = cross_epoch_distances(epoch_reps, COMPARISONS)

print('Panels 12-14: cross-epoch distances computed.')

In [None]:
# Number of stimulus conditions passed to draw_3d_alignment
# (presumably the four cardinal directions — confirm against STIM_COLORS).
n_conds = 4

plt.rcParams.update({'font.size': 11})

fig = plt.figure(figsize=(16, 16))
gs = GridSpec(4, 15, figure=fig, hspace=0.35, wspace=0.4)

# ── Axes layout ──────────────────────────────────────────────────────────
# `panels` is 0-indexed; section comments below use 1-based panel numbers.
#   Rows 1-3: three axes per row, 5 of 15 columns each
#             (indices 0-2, 3-5 [3D], 6-8  ->  panels 1-9)
#   Row 4:    four axes, 3 columns each (indices 9-12 -> panels 10-13);
#             columns 12-14 are left free for the panel-14 correlation matrices.
panels = []
for row in range(3):
    subplot_kw = {'projection': '3d'} if row == 1 else {}
    for col in range(3):
        panels.append(fig.add_subplot(gs[row, col*5:(col+1)*5], **subplot_kw))
for col in range(4):
    panels.append(fig.add_subplot(gs[3, col*3:(col+1)*3]))

# ── Expand 3D panels (indices 3-5) about their centres ───────────────────
expand_3d = 1.4
for ax3d in panels[3:6]:
    pos = ax3d.get_position()
    w, h = pos.width * expand_3d, pos.height * expand_3d
    cx, cy = pos.x0 + pos.width / 2, pos.y0 + pos.height / 2
    ax3d.set_position([cx - w/2, cy - h/2, w, h])

# ── Panels 4-6: per-task 3D alignment ────────────────────────────────────
for col, task_name in enumerate(TASK_LIST):
    draw_3d_alignment(panels[3 + col], task_results[task_name],
                      PLOT_EPOCHS, EPOCH_COLORS, STIM_COLORS, n_conds)

# Shared legend for panels 4-6: one marker per stimulus, one line per epoch.
dir_handles = [Line2D([0], [0], marker='o', color='w',
                      markerfacecolor='k', markeredgecolor=c,
                      markeredgewidth=1.5, markersize=10, label=l)
               for c, l in zip(STIM_COLORS, STIM_LABELS)]
epoch_handles = [Line2D([0], [0], color=EPOCH_COLORS[e], lw=3, label=e)
                 for e in PLOT_EPOCHS]
legend_handles = dir_handles + epoch_handles
# One row: ncol follows the handles actually built (zip truncates if the
# stimulus colour/label lists differ in length). Anchor is a hand-tuned
# figure-fraction position.
fig.legend(handles=legend_handles,
           loc='lower center', ncol=len(legend_handles),
           frameon=False, bbox_to_anchor=(0.5, 0.67))

# ── Panel 7: cross-task bar plot ─────────────────────────────────────────
draw_cross_task_bars(panels[6], ct_results)

# ── Panel 8: cross-age bar plot ──────────────────────────────────────────
draw_cross_age_bars(panels[7], indiv_results, TASK_COLORS)

# ── Panel 9: cross-monkey by group ──────────────────────────────────────
draw_cross_monkey_scatter(panels[8], results_by_group, pooled, AGE_GROUP_LABELS, TASK_COLORS)

# ── Panels 10-11: neural vs DI / RT ─────────────────────────────────────
draw_neural_vs_behavior(panels[9], indiv_results, beh_dist, 'di_dist', TASK_COLORS,
                        xlabel='DI distance')
draw_neural_vs_behavior(panels[10], indiv_results, beh_dist, 'rt_dist', TASK_COLORS,
                        xlabel='RT distance', show_ylabel=False, show_left_spine=False)

# ── Panels 12-13: cross-epoch delay→response vs DI / RT ─────────────────
ce_label = 'delay\u2192response'
draw_cross_epoch_vs_behavior(panels[11], cross_epoch, CROSS_EPOCH_DEFS, beh_df, monkey_edges,
                             ce_label, 'DI', TASK_COLORS,
                             xlabel='Procrustes dist.\n(delay\u2192resp.)', ylabel='DI')
# NOTE(review): this panel's y-axis is RT while panel 12's is DI, so the
# y-axes are not shared — confirm hiding the left spine here is intended.
draw_cross_epoch_vs_behavior(panels[12], cross_epoch, CROSS_EPOCH_DEFS, beh_df, monkey_edges,
                             ce_label, 'RT', TASK_COLORS,
                             xlabel='Procrustes dist.\n(delay\u2192resp.)', ylabel='RT',
                             show_left_spine=False)

# ── Panel 14: correlation matrices (uses gridspec slot directly) ─────────
pairs_14 = [('cue', 'delay'), ('cue', 'response'), ('delay', 'response')]
pos_map = {0: (0, 0), 1: (0, 1), 2: (1, 1)}
draw_correlation_matrices(fig, gs[3, 12:15], cross_epoch, CROSS_EPOCH_DEFS,
                          beh_df, monkey_edges, pairs_14, pos_map)

# ── Panels 1-3: not yet drawn — label the placeholders ──────────────────
for number, ax in enumerate(panels[:3], start=1):
    ax.text(0.5, 0.5, str(number), transform=ax.transAxes,
            fontsize=28, fontweight='bold', ha='center', va='center', color='0.3')
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_edgecolor('0.7')

# Use the explicit figure API and finish with plt.show(), so the cell does not
# echo the suptitle's Text(...) repr as its output.
fig.suptitle('Figure layout', fontsize=16, y=0.98)
plt.show()