# Cross-epoch geometry distance and behavioral correlations

For each task, compute tuning curves in cue, delay and response epochs.
Group neurons by monkey × age. Build PCA representations per epoch.
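Compute the Procrustes distance between the cue and delay representations (cue→delay) and between the delay and response representations (delay→response) for each monkey–age pair.
Correlate these distances with the behavioral measures DI and RT.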

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sts
import os, sys
sys.path.insert(0, '.')

from functions import (
    load_all_task_data,
    compute_single_trial_rates, compute_tuning_curves,
    build_representations, assign_age_groups,
    load_behavioral_data,
    load_odr_data, load_odrd_data, extract_metadata,
)
from functions.load_data import _abs_age_months
from functions.behavior import TASK_MAP_SAC, DI_COLS_8, RT_COLS_8
from scipy.spatial import procrustes

# --- Input locations -------------------------------------------------------
DATA_DIR = '../data_raw'
BEH_SAC = '../original_repo/behavior/sac_data.csv'
BEH_SAC_ODRD = '../original_repo/GAM/data/sac_odrd.csv'

# --- Analysis parameters ---------------------------------------------------
N_PCS = 5                 # passed as n_pcs when building representations
MIN_NEURONS = N_PCS + 1   # passed as min_neurons (groups below this are skipped)
BIN_MS = 25               # step of the time-bin grid used for single-trial rates
N_AGE_BINS = 5            # per-monkey quantile bins

# --- Task timing -----------------------------------------------------------
# Each task has an overall time range and three named epochs.
# The "delay" epoch is always the second/late delay:
#   ODR 1.5s: second half of 500-2000 -> 1250-2000
#   ODR 3.0s: second half of 500-3000 -> 1750-3000
#   ODRd:     post-distractor delay   -> 2200-3000
TASK_EPOCHS = {
    'ODR 1.5s': {
        't_range': (-1000, 2500),
        'epochs': {
            'cue': (0, 500),
            'delay': (1250, 2000),
            'response': (2000, 2500),
        },
    },
    'ODR 3.0s': {
        't_range': (-1000, 3500),
        'epochs': {
            'cue': (0, 500),
            'delay': (1750, 3000),
            'response': (3000, 3500),
        },
    },
    'ODRd': {
        't_range': (-1000, 4000),
        'epochs': {
            'cue': (0, 500),
            'delay': (2200, 3000),
            'response': (3000, 3500),
        },
    },
}

# Epoch pairs whose geometries are compared, and one plot color per task.
COMPARISONS = [
    ('cue', 'delay'),
    ('delay', 'response'),
]
TASK_COLORS = {
    'ODR 1.5s': '#1b9e77',
    'ODR 3.0s': '#d95f02',
    'ODRd': '#7570b3',
}

## 1. Load data

In [None]:
# Load the per-task neural data and the behavioral table(s).
task_data = load_all_task_data(DATA_DIR)
beh_df = load_behavioral_data(BEH_SAC, sac_odrd_path=BEH_SAC_ODRD)

# Sanity check: report the first two dimensions of each task's data array.
for name, T in task_data.items():
    n_neurons, n_conds = T["data"].shape[:2]
    print(f'{name}: {n_neurons} neurons, {n_conds} conditions')

## 2. Per-epoch PCA representations

In [None]:
import pandas as pd

def _age_range_label(edges, g):
    """Return a printable age range for bin ``g`` of ``np.digitize(ages, edges)``.

    Bin 0 is open below (``<first edge``), the last bin is open above
    (``>last edge``) and interior bins read ``lo-hi``.  When ``edges`` is
    empty there is only one bin covering every age, labelled ``'all'``
    (previously this case raised IndexError).
    """
    if len(edges) == 0:
        return 'all'
    if g == 0:
        return f'<{edges[0]:.0f}'
    if g >= len(edges):
        return f'>{edges[-1]:.0f}'
    return f'{edges[g - 1]:.0f}-{edges[g]:.0f}'


epoch_reps = {}     # {task: {epoch_name: {(monkey, group): entry}}}
age_groups = {}     # {task: ndarray}
monkey_edges = {}   # {(task, monkey): tuple of edges}

# Interior percentiles that split a monkey's ages into N_AGE_BINS quantile
# bins.  They do not depend on the task or monkey, so compute them once.
pcts = np.linspace(0, 100, N_AGE_BINS + 1)[1:-1]

for task_name, cfg in TASK_EPOCHS.items():
    data = task_data[task_name]
    ids  = data['ids']
    abs_age = data['abs_age']

    # Per-monkey quantile binning (-1 marks "not assigned"; every neuron is
    # overwritten below because each id belongs to exactly one monkey mask).
    ag = np.full(len(ids), -1, dtype=int)
    monkeys = sorted(set(ids))
    for mid in monkeys:
        mask = ids == mid
        ages_m = abs_age[mask]
        # np.unique collapses duplicate edges (tied ages), so a monkey can
        # end up with fewer than N_AGE_BINS groups.
        edges = tuple(np.unique(np.percentile(ages_m, pcts)))
        monkey_edges[(task_name, mid)] = edges
        ag[mask] = np.digitize(ages_m, edges)
    age_groups[task_name] = ag

    # Compute tuning curves on a BIN_MS grid spanning the task's time range
    bins = np.arange(cfg['t_range'][0], cfg['t_range'][1] + BIN_MS, BIN_MS)
    bc   = (bins[:-1] + bins[1:]) / 2.0   # bin centers
    rates = compute_single_trial_rates(data['data'], bins)
    tuning, enames = compute_tuning_curves(rates, bc, cfg['epochs'])

    # Build per-epoch representations, keyed by (monkey, age group)
    epoch_reps[task_name] = {}
    for ei, ename in enumerate(enames):
        entries = build_representations(tuning[:, :, ei], ids, ag,
                                        n_pcs=N_PCS, min_neurons=MIN_NEURONS)
        epoch_reps[task_name][ename] = {(e['monkey'], e['group']): e for e in entries}

    # Show which monkey-age pairs survived (table).  The count is taken from
    # the first epoch only.
    kept = epoch_reps[task_name][enames[0]]
    print(f'\n{task_name}: {len(kept)} entries per epoch')
    table_rows = []
    for mid in monkeys:
        edges = monkey_edges[(task_name, mid)]
        # len(edges) edges define len(edges) + 1 digitize bins.
        for g in range(len(edges) + 1):
            entry = kept.get((mid, g))
            n_neur = entry['n_neurons'] if entry is not None else None
            table_rows.append(dict(monkey=mid, bin=g,
                                   age_range=_age_range_label(edges, g),
                                   n_neurons=n_neur if n_neur else '-',
                                   kept='Y' if n_neur else ''))
    df_table = pd.DataFrame(table_rows)
    pivot = df_table.pivot(index='monkey', columns='bin', values='n_neurons')
    # Age ranges differ per monkey, so the column headers just use the bin index
    pivot.columns = [f'bin {c}' for c in pivot.columns]
    print(pivot.to_string())
    print(f'  ({MIN_NEURONS}+ neurons required, "-" = skipped)')

## 3. Cross-epoch Procrustes distances

In [None]:
cross_epoch = {}   # {task: {label: list of dict(monkey, group, distance)}}

for task_name in TASK_EPOCHS:
    cross_epoch[task_name] = {}
    for ea, eb in COMPARISONS:
        reps_a = epoch_reps[task_name][ea]
        reps_b = epoch_reps[task_name][eb]
        # Only (monkey, group) keys that have a representation in both epochs
        common = sorted(set(reps_a) & set(reps_b))

        rows = []
        for key in common:
            A = reps_a[key]['matrix'].T   # (n_conds, n_pcs)
            B = reps_b[key]['matrix'].T
            # scipy.spatial.procrustes standardizes both inputs and returns
            # (mtx1, mtx2, disparity); the disparity is what is stored here
            # as "distance".
            _, _, d = procrustes(A, B)
            rows.append(dict(monkey=key[0], group=key[1], distance=d))

        label = f'{ea}\u2192{eb}'
        cross_epoch[task_name][label] = rows

        # Summary statistics.  With no common pairs, report NaN directly
        # instead of calling np.mean/np.std on an empty list (which emits
        # RuntimeWarnings); the printed text is the same either way.
        dists = np.array([r['distance'] for r in rows], dtype=float)
        if dists.size:
            mean_d, std_d = dists.mean(), dists.std()
        else:
            mean_d, std_d = np.nan, np.nan
        print(f'{task_name} {label}: {len(rows)} pairs, '
              f'mean={mean_d:.4f}, std={std_d:.4f}')

## 4. Behavioral correlations

In [None]:
def get_beh_values(beh_df, entries, task_name, monkey_edges_dict):
    """Return the mean DI and RT for each (monkey, age-group) entry.

    Behavioral rows of the task mapped from ``task_name`` are binned by
    ``age_month`` using that monkey's edges from ``monkey_edges_dict``; the
    rows falling in the entry's group are then averaged (NaNs ignored).

    Parameters
    ----------
    beh_df : DataFrame
        Must provide 'Task', 'Monkey' and 'age_month' columns, plus 'DI' and
        'RT' for the 'ODRd' behavioral task or the DI_COLS_8 / RT_COLS_8
        columns for every other task.
    entries : sequence of dict
        Each item needs 'monkey' and 'group' keys.
    task_name : str
        Key into TASK_MAP_SAC and first element of the edge-dict keys.
    monkey_edges_dict : dict
        ``{(task_name, monkey): bin edges}`` as passed to ``np.digitize``.

    Returns
    -------
    (di_vals, rt_vals) : tuple of ndarray
        One value per entry; NaN where no behavioral rows match.
    """
    beh_task = TASK_MAP_SAC[task_name]
    task_rows = beh_df[beh_df['Task'] == beh_task].copy()

    # ODRd averages the single 'DI'/'RT' columns; other tasks average over
    # all of the DI_COLS_8 / RT_COLS_8 columns at once.
    if beh_task == 'ODRd':
        di_cols, rt_cols = 'DI', 'RT'
    else:
        di_cols, rt_cols = DI_COLS_8, RT_COLS_8

    n_entries = len(entries)
    di_vals = np.full(n_entries, np.nan)
    rt_vals = np.full(n_entries, np.nan)

    for idx, entry in enumerate(entries):
        monkey = entry['monkey']
        group = entry['group']
        edges = monkey_edges_dict[(task_name, monkey)]

        monkey_rows = task_rows[task_rows['Monkey'] == monkey]
        if monkey_rows.shape[0] == 0:
            continue

        # Assign each behavioral row to an age bin, then keep this entry's bin.
        binned = monkey_rows.assign(
            ag=np.digitize(monkey_rows['age_month'].values, edges))
        selected = binned[binned['ag'] == group]
        if selected.shape[0] == 0:
            continue

        di_vals[idx] = np.nanmean(selected[di_cols].values)
        rt_vals[idx] = np.nanmean(selected[rt_cols].values)

    return di_vals, rt_vals


# 2x2 grid: rows = behavioral measure (DI, RT), columns = epoch comparison.
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

for ci, (ea, eb) in enumerate(COMPARISONS):
    label = f'{ea}\u2192{eb}'

    # Distances and behavioral values depend only on (comparison, task), so
    # gather them once here instead of once per behavioral measure.
    per_task = []   # list of (task_name, dists, {'DI': ..., 'RT': ...})
    for task_name in TASK_EPOCHS:
        rows = cross_epoch[task_name][label]
        if not rows:
            continue

        entries = [{'monkey': r['monkey'], 'group': r['group']} for r in rows]
        di_vals, rt_vals = get_beh_values(beh_df, entries, task_name, monkey_edges)
        dists = np.array([r['distance'] for r in rows])
        per_task.append((task_name, dists, {'DI': di_vals, 'RT': rt_vals}))

    for ri, beh_name in enumerate(['DI', 'RT']):
        ax = axes[ri, ci]
        all_d, all_beh = [], []

        for task_name, dists, beh_by_name in per_task:
            beh_vals = beh_by_name[beh_name]
            # Drop entries with no behavioral value (NaN from get_beh_values)
            valid = np.isfinite(beh_vals)

            ax.scatter(dists[valid], beh_vals[valid],
                       c=TASK_COLORS[task_name], label=task_name,
                       s=50, alpha=0.7, edgecolors='k', linewidth=0.5)

            # Per-task correlation (needs at least 3 points)
            if valid.sum() >= 3:
                rho_t, p_t = sts.spearmanr(dists[valid], beh_vals[valid])
                print(f'  {task_name} {label} vs {beh_name}: '
                      f'\u03c1={rho_t:.3f}, p={p_t:.3f}, n={int(valid.sum())}')

            all_d.extend(dists[valid])
            all_beh.extend(beh_vals[valid])

        # Pooled Spearman + least-squares regression line across all tasks
        all_d   = np.array(all_d)
        all_beh = np.array(all_beh)
        if len(all_d) >= 3:
            rho, p = sts.spearmanr(all_d, all_beh)
            m, b = np.polyfit(all_d, all_beh, 1)
            x_line = np.linspace(all_d.min(), all_d.max(), 50)
            ax.plot(x_line, m * x_line + b, 'k-', lw=1.5, alpha=0.8)
            ax.set_title(f'{label}: {beh_name} (\u03c1={rho:.3f}, p={p:.3f})', fontsize=9)
            print(f'  Pooled {label} vs {beh_name}: '
                  f'\u03c1={rho:.3f}, p={p:.3f}, n={len(all_d)}')
        else:
            ax.set_title(f'{label}: {beh_name}', fontsize=9)

        ax.set_xlabel(f'Procrustes distance ({label})', fontsize=8)
        ax.set_ylabel(beh_name, fontsize=8)
        ax.tick_params(labelsize=7)
        # One legend (top-left panel) is enough; task colors are shared.
        if ri == 0 and ci == 0:
            ax.legend(fontsize=7)

fig.suptitle('Cross-epoch Procrustes distance vs behavioral measures',
             fontsize=11, y=1.02)
plt.tight_layout()