# Cross-temporal geometry distance

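For each task, pool all neurons across monkeys.  
For each sliding window (500 ms wide, 25 ms step), average each neuron's PSTH over the window and reduce the neuron × condition matrix to a PCA representation (up to 8 PCs).  
Compare every pair of windows with the Procrustes distance, giving a time × time matrix.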

ODR: 8 stimulus directions. ODRd: 20 conditions (4 cue × 5 distractor).

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import os, sys, warnings
sys.path.insert(0, '.')

from functions import (
    load_odr_data, load_odrd_data, split_odrd_by_distractor,
    extract_metadata, compute_single_trial_rates,
)
from functions.temporal import rates_to_psth
from functions.representations import zscore_neurons, pca_reduce, _clean_neurons
from scipy.spatial import procrustes as scipy_procrustes

# Location of the raw .mat files and PSTH bin width (ms).
DATA_DIR = '../data_raw'
BIN_MS = 25

# Analysis window per task, in ms relative to cue onset (t = 0);
# the delay period begins 500 ms after cue onset.
TASK_T_RANGE = {
    'ODR 1.5s': (-500, 2000),   # cue, then a 1.5 s delay
    'ODR 3.0s': (-500, 3500),   # cue, then a 3.0 s delay
    'ODRd': (-500, 3500),       # cue, pre-distractor, distractor, post-distractor
}

# Sliding-window PCA settings: window length (ms), step (ms), number of PCs.
WINDOW_MS, STEP_MS, N_PCS = 500, 25, 8

## 1. Load data

In [None]:
# Load both datasets and assemble {task name: neuron x condition data}.
odr_all, ws_odr = load_odr_data(os.path.join(DATA_DIR, 'odr_data_both_sig_is_best_20240109.mat'))
_, _, _, delay_all = extract_metadata(ws_odr, odr_all.shape[0])

odrd_raw, _ = load_odrd_data(os.path.join(DATA_DIR, 'odrd_data_sig_on_best_20231018.mat'))
odrd_split = split_odrd_by_distractor(odrd_raw)

# Split the ODR neurons by delay length. Compare with a tolerance rather than
# exact float equality, and fail loudly if a delay group is empty: an empty
# task would otherwise only surface as an obscure error in a later cell.
tasks = {}
for delay_val, name in [(1.5, 'ODR 1.5s'), (3.0, 'ODR 3.0s')]:
    mask = np.isclose(delay_all, delay_val)
    if not np.any(mask):
        raise ValueError(f'No ODR neurons found with delay == {delay_val} s')
    tasks[name] = odr_all[mask]
tasks['ODRd'] = odrd_split

for name, data in tasks.items():
    print(f'{name}: {data.shape[0]} neurons, {data.shape[1]} conditions')

## 2. Compute PSTHs and time × time Procrustes matrices

In [None]:
# For every task: slide a window over the trial, build a PCA representation of
# the condition responses in each window, and compare every pair of windows
# with the Procrustes disparity -> symmetric time x time distance matrix.
txn_matrices = {}   # {task: ndarray (n_t, n_t)}
txn_centers = {}    # {task: ndarray of window centers, ms}

for task_name, data in tasks.items():
    t_range = TASK_T_RANGE[task_name]
    bins = np.arange(t_range[0], t_range[1] + BIN_MS, BIN_MS)   # bin edges (ms)
    bc = (bins[:-1] + bins[1:]) / 2.0                           # bin centers (ms)

    n_conds = data.shape[1]
    # Cap the PC count at n_conds - 1 (presumably the rank limit of the
    # condition responses after centering -- TODO confirm against pca_reduce).
    n_pcs = min(N_PCS, n_conds - 1)

    print(f'{task_name}: {data.shape[0]} neurons, {n_conds} conds, {n_pcs} PCs, '
          f'range {t_range[0]}–{t_range[1]} ms')

    rates = compute_single_trial_rates(data, bins)
    psth = rates_to_psth(rates)

    # Windows start on bin centers and pool the bins whose centers fall in
    # [t0, t0 + WINDOW_MS). Only full windows are kept: the last valid start
    # is bc[-1] - (WINDOW_MS - BIN_MS); the +1 ms absorbs float round-off.
    window_starts = np.arange(bc[0], bc[-1] - (WINDOW_MS - BIN_MS) + 1, STEP_MS)
    n_t = len(window_starts)
    wc = np.empty(n_t)

    # PCA representation at each time window
    reps = []
    for k, t0 in enumerate(window_starts):
        bmask = (bc >= t0) & (bc < t0 + WINDOW_MS)
        # Label the window by the mean of the bin centers it pools, i.e. the
        # midpoint of the time span it actually covers (t0 is a bin center,
        # so t0 + WINDOW_MS / 2 would be half a bin late).
        wc[k] = bc[bmask].mean()
        with warnings.catch_warnings():
            # np.nanmean warns on all-NaN slices; those neurons are handled
            # by _clean_neurons below.
            warnings.simplefilter('ignore', RuntimeWarning)
            tuning_t = np.nanmean(psth[:, :, bmask], axis=2)

        X = _clean_neurons(tuning_t)
        if X.shape[0] < n_pcs + 1:
            reps.append(None)
            continue
        X_z = zscore_neurons(X)
        if X_z.shape[0] < n_pcs + 1:
            reps.append(None)
            continue
        # presumably mat is (n_pcs, n_conds); it is transposed below so that
        # each condition is one point for Procrustes -- TODO confirm.
        mat, _ = pca_reduce(X_z, n_pcs)
        reps.append(mat)

    # Time x time Procrustes disparity: symmetric, 0 on the diagonal, NaN
    # wherever either window has no valid representation.
    dist = np.full((n_t, n_t), np.nan)
    for i in range(n_t):
        if reps[i] is None:
            continue
        dist[i, i] = 0.0
        for j in range(i + 1, n_t):
            if reps[j] is None:
                continue
            _, _, d = scipy_procrustes(reps[i].T, reps[j].T)
            dist[i, j] = dist[j, i] = d

    txn_matrices[task_name] = dist
    txn_centers[task_name] = wc
    # Count valid windows directly (row 0 alone under-reports when window 0
    # is invalid, and does not exist when n_t == 0).
    n_valid = sum(r is not None for r in reps)
    print(f'  {n_t} time windows, {n_valid} valid')

## 3. Plot

In [None]:
# One time x time distance heatmap per task; dashed lines mark cue onset (0 ms)
# and dotted lines mark 500 ms after cue onset.
n_tasks = len(tasks)
fig, axes = plt.subplots(1, n_tasks, figsize=(5 * n_tasks, 4.5), squeeze=False)

for ti, task_name in enumerate(tasks):
    ax = axes[0, ti]
    m = txn_matrices[task_name]
    wc = txn_centers[task_name]

    # Robust color limits from the finite off-diagonal distances; fall back to
    # matplotlib's autoscaling if there are none (e.g. no valid windows).
    vals = m[np.isfinite(m) & (m > 0)]
    if vals.size:
        vmin, vmax = np.percentile(vals, [2, 98])
    else:
        vmin = vmax = None

    # imshow's extent gives the outer pixel *edges*; pad by half a step so
    # pixel centers sit on the window centers and line up with the markers.
    half = STEP_MS / 2
    lo, hi = wc[0] - half, wc[-1] + half

    im = ax.imshow(m, origin='lower', cmap='magma',
                   vmin=vmin, vmax=vmax, aspect='equal',
                   extent=[lo, hi, lo, hi])
    ax.axvline(0, color='w', ls='--', lw=0.6, alpha=0.6)
    ax.axhline(0, color='w', ls='--', lw=0.6, alpha=0.6)
    ax.axvline(500, color='w', ls=':', lw=0.6, alpha=0.6)
    ax.axhline(500, color='w', ls=':', lw=0.6, alpha=0.6)
    ax.set_title(task_name, fontsize=10)
    ax.set_xlabel('Time (ms)', fontsize=8)
    ax.set_ylabel('Time (ms)', fontsize=8)
    ax.tick_params(labelsize=7)
    fig.colorbar(im, ax=ax, shrink=0.8, label='Procrustes distance')

fig.suptitle('Cross-temporal geometry distance (all neurons pooled)',
             fontsize=11, y=1.02)
plt.tight_layout()