# Cross-epoch geometry distance and behavioral correlations

For each task, compute tuning curves in the cue, delay, and response epochs.
Group neurons by monkey × age bin and build a PCA representation for each epoch.
Compute the Procrustes distance between every pair of epochs (cue–delay, delay–response, cue–response) for each monkey × age group.
Correlate these distances with the behavioral measures DI and RT.

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sts
import os, sys
sys.path.insert(0, '.')

from functions import (
    load_all_task_data,
    build_epoch_representations, cross_epoch_distances,
    load_behavioral_data, get_behavioral_values,
    plot_cross_epoch_correlations,
    TASK_COLORS,
)

# --- Input locations (relative to the notebook's working directory) ---
DATA_DIR = '../data_raw'
BEH_SAC = '../original_repo/behavior/sac_data.csv'
BEH_SAC_ODRD = '../original_repo/GAM/data/sac_odrd.csv'

# --- Analysis parameters (all forwarded to build_epoch_representations) ---
N_PCS = 8          # number of principal components per epoch representation
MIN_NEURONS = 10   # minimum neuron count (presumably per monkey x age group)
BIN_MS = 25        # bin width, in ms, passed as bin_ms
N_AGE_BINS = 3     # number of age bins, passed as n_age_bins

# Per-task analysis window (t_range) and epoch windows, in ms (presumably
# relative to cue onset). The cross-epoch analysis uses late-delay windows,
# which differ from the default TASK_EPOCHS used elsewhere.
TASK_EPOCHS = {
    'ODR 1.5s': {
        't_range': (-1000, 2500),
        'epochs': {'cue': (0, 500), 'delay': (1250, 2000), 'response': (2000, 2500)},
    },
    'ODR 3.0s': {
        # NOTE(review): for 'ODR 1.5s' the response window starts at
        # 500 + 1500 ms; by the same convention a 3.0 s delay would end at
        # 3500 ms, yet here the response window is (3000, 3500) and t_range
        # stops at 3500. Confirm the task timing.
        't_range': (-1000, 3500),
        'epochs': {'cue': (0, 500), 'delay': (1750, 3000), 'response': (3000, 3500)},
    },
    'ODRd': {
        't_range': (-1000, 4000),
        'epochs': {'cue': (0, 500), 'delay': (2200, 3000), 'response': (3000, 3500)},
    },
}

# Epoch pairs whose representations are compared by cross_epoch_distances.
COMPARISONS = [('cue', 'delay'), ('delay', 'response'), ('cue', 'response')]

## 1. Load data

In [135]:
task_data = load_all_task_data(DATA_DIR)
beh_df = load_behavioral_data(BEH_SAC, sac_odrd_path=BEH_SAC_ODRD)

# Report the first two dimensions of each task's "data" array
# (printed as neurons x conditions).
for task_name, task in task_data.items():
    n_neurons, n_conditions = task["data"].shape[:2]
    print(f'{task_name}: {n_neurons} neurons, {n_conditions} conditions')

ODR 1.5s: 1180 neurons, 8 conditions
ODR 3.0s: 922 neurons, 8 conditions
ODRd: 1319 neurons, 20 conditions


## 2. Per-epoch PCA representations

In [None]:
epoch_reps, age_groups, monkey_edges = build_epoch_representations(
    task_data, TASK_EPOCHS, N_PCS, MIN_NEURONS, bin_ms=BIN_MS, n_age_bins=N_AGE_BINS)

# Sanity check: count the entries kept for each task, read from the task's
# first configured epoch (presumably one entry per monkey x age group that
# passed MIN_NEURONS, and the same count for every epoch -- TODO confirm).
for task_name, spec in TASK_EPOCHS.items():
    first_epoch = next(iter(spec['epochs']))  # first key without building a list
    kept = epoch_reps[task_name][first_epoch]
    print(f'{task_name}: {len(kept)} entries per epoch')

## 3. Cross-epoch Procrustes distances

In [None]:
cross_epoch = cross_epoch_distances(epoch_reps, COMPARISONS)

# Summarise the Procrustes distances for each task and comparison label.
# Comparisons with no rows are reported explicitly: np.mean/np.std of an empty
# list would print "nan" and emit "Mean of empty slice" RuntimeWarnings.
for task_name in TASK_EPOCHS:
    for label, rows in cross_epoch[task_name].items():
        if not rows:
            print(f'{task_name} {label}: 0 pairs')
            continue
        dists = [r['distance'] for r in rows]
        print(f'{task_name} {label}: {len(rows)} pairs, '
              f'mean={np.mean(dists):.4f}, std={np.std(dists):.4f}')

## 4. Behavioral correlations

In [None]:
# Plot cross-epoch distances against behaviour for every task and comparison
# (the plotting itself lives in functions.plot_cross_epoch_correlations).
plot_cross_epoch_correlations(
    cross_epoch,
    beh_df,
    monkey_edges,
    TASK_EPOCHS,
    COMPARISONS,
)

## 5. Summary: correlation matrices

In [None]:
from functions.behavior import get_behavioral_values

# Pooled Spearman correlation between cross-epoch Procrustes distance and each
# behavioural measure (DI, RT), one value per epoch pair. Within every task,
# both the distances and the behavioural values are mean-centred before pooling,
# so differences between task means do not drive the pooled correlation.
epoch_names = ['cue', 'delay', 'response']
pairs = [('cue', 'delay'), ('cue', 'response'), ('delay', 'response')]


def _pooled_spearman(label, beh_name):
    """Return ``(rho, p)`` for distance vs. behaviour, pooled across tasks.

    Parameters
    ----------
    label : str
        Comparison key looked up in ``cross_epoch[task_name]``, of the form
        '<epochA>\u2192<epochB>'.
    beh_name : {'DI', 'RT'}
        Selects the first ('DI') or second (anything else) value returned by
        ``get_behavioral_values``.

    Returns
    -------
    (rho, p) from ``scipy.stats.spearmanr``, or ``(nan, nan)`` when fewer
    than three usable points remain after filtering.
    """
    all_d, all_beh = [], []
    for task_name in TASK_EPOCHS:
        rows = cross_epoch[task_name].get(label, [])
        if not rows:
            continue
        entries = [{'monkey': r['monkey'], 'group': r['group']} for r in rows]
        # Assumed to return one DI and one RT value per entry, aligned with
        # `rows` -- TODO confirm against functions.behavior.
        di_vals, rt_vals = get_behavioral_values(beh_df, entries, task_name, monkey_edges)
        # asarray so boolean masking works even if plain lists are returned
        beh_vals = np.asarray(di_vals if beh_name == 'DI' else rt_vals, dtype=float)
        dists = np.array([r['distance'] for r in rows], dtype=float)
        # Require BOTH values to be finite: a NaN distance would otherwise make
        # the task mean NaN and poison every pooled point.
        valid = np.isfinite(beh_vals) & np.isfinite(dists)
        if not valid.any():
            continue  # avoid np.mean of an empty slice
        d, b = dists[valid], beh_vals[valid]
        all_d.extend(d - d.mean())
        all_beh.extend(b - b.mean())
    if len(all_d) < 3:
        return np.nan, np.nan
    rho, p = sts.spearmanr(all_d, all_beh)
    return rho, p


rho_vals = {}   # {beh_name: list of 3 rho, ordered like `pairs`}
p_vals = {}     # {beh_name: list of 3 p, ordered like `pairs`}

for beh_name in ['DI', 'RT']:
    results = [_pooled_spearman(f'{ea}\u2192{eb}', beh_name) for ea, eb in pairs]
    rho_vals[beh_name] = [rho for rho, _ in results]
    p_vals[beh_name] = [p for _, p in results]

# Layout: 2x2 grid, position (1,0) is blank
#         delay     response
# cue   [cue-del]  [cue-resp]
# delay    --      [del-resp]
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

pos_map = {0: (0, 0), 1: (0, 1), 2: (1, 1)}  # pair index -> (row, col)
row_labels = ['cue', 'delay']
col_labels = ['delay', 'response']

cmap = plt.cm.RdBu_r
norm = Normalize(vmin=-1, vmax=1)  # rho is bounded to [-1, 1]

# Constrained layout instead of tight_layout: a colorbar shared by several axes
# (fig.colorbar(..., ax=axes)) gets an axes that tight_layout cannot handle
# (it warns and may misplace the colorbar/suptitle). Constrained layout
# reserves room for both.
fig, axes = plt.subplots(1, 2, figsize=(6, 2.8), constrained_layout=True)

for ax, beh_name in zip(axes, ['DI', 'RT']):
    # One coloured square per epoch pair; grey when rho is undefined.
    for k, (r, c) in pos_map.items():
        rho = rho_vals[beh_name][k]
        p = p_vals[beh_name][k]
        color = cmap(norm(rho)) if np.isfinite(rho) else 'lightgrey'
        ax.add_patch(plt.Rectangle((c, r), 1, 1, facecolor=color,
                                   edgecolor='k', linewidth=1.5))
        if np.isfinite(rho):
            stars = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else ''
            ax.text(c + 0.5, r + 0.5, f'{rho:.2f}{stars}',
                    ha='center', va='center', fontsize=11, fontweight='bold')

    ax.set_xlim(0, 2)
    ax.set_ylim(2, 0)  # row 0 on top
    ax.set_xticks([0.5, 1.5])
    ax.set_xticklabels(col_labels, fontsize=9)
    ax.set_yticks([0.5, 1.5])
    ax.set_yticklabels(row_labels, fontsize=9)
    ax.tick_params(length=0)
    ax.set_aspect('equal')
    ax.set_title(beh_name, fontsize=11)
    ax.spines[:].set_visible(False)

sm = ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # older matplotlib requires an array on a data-less mappable
fig.colorbar(sm, ax=axes, shrink=0.8, label='Spearman \u03c1')
fig.suptitle('Cross-epoch distance vs behavior (pooled \u03c1)', fontsize=12)
