# Tuning curves: 4 cardinal directions, pooled across tasks

Common conditions: the four cardinal directions 0°, 90°, 180°, 270° (the only directions shared by all tasks).  
Common epochs: cue (0–500 ms) and delay (500–1700 ms).  
Neurons are pooled across the ODR 1.5 s, ODR 3.0 s, and ODRd tasks.

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import os, sys
sys.path.insert(0, '.')

from functions import (
    load_odr_data, load_odrd_data, extract_metadata,
    assign_age_groups, pooled_tuning_by_group, pca_reduce_tuning,
)
from functions.load_data import _abs_age_months

DATA_DIR = '../data_raw'

# Age group edges (absolute age in months). N edges define N + 1 groups,
# labelled in order by AGE_GROUP_LABELS: young / middle / old.
# NOTE(review): whether an age equal to an edge falls in the lower or upper
# group is decided inside the grouping helpers, not here — confirm there.
AGE_EDGES = (48, 60)
AGE_GROUP_LABELS = ['young', 'middle', 'old']

# Later cells look up AGE_GROUP_LABELS[g] by group index, so the two settings
# must stay in sync; fail fast here rather than mislabel groups downstream.
if len(AGE_GROUP_LABELS) != len(AGE_EDGES) + 1:
    raise ValueError(
        f'AGE_GROUP_LABELS has {len(AGE_GROUP_LABELS)} labels but AGE_EDGES '
        f'defines {len(AGE_EDGES) + 1} groups'
    )
if any(lo >= hi for lo, hi in zip(AGE_EDGES, AGE_EDGES[1:])):
    raise ValueError(f'AGE_EDGES must be strictly increasing, got {AGE_EDGES}')

# Indices into the ODR 8-direction data. Presumably columns are ordered in
# 45° steps from 0°, making these 0°, 90°, 180°, 270° — TODO confirm.
CARDINAL_COLS = [0, 2, 4, 6]
# Epoch windows (start, end) in ms, shared by all pooled tasks.
COMMON_EPOCHS = {'cue': (0, 500), 'delay': (500, 1700)}
BIN_MS = 25  # bin width in ms, passed as bin_ms to the pooling helper
N_PCS = 5  # number of principal components kept per group
# Presumably ensures each group has more neurons than requested PCs so that
# N_PCS components can be estimated — confirm against pca_reduce_tuning.
MIN_NEURONS = N_PCS + 1

## 1. Load data (cardinal directions only)

In [None]:
# ODR data (8 directions) -> keep only the 4 cardinal columns (CARDINAL_COLS)
odr_all, ws_odr = load_odr_data(os.path.join(DATA_DIR, 'odr_data_both_sig_is_best_20240109.mat'))
ids_all, age_all, mature_all, delay_all = extract_metadata(ws_odr, odr_all.shape[0])

# ODRd raw data (already 4 cardinal directions, not split by distractor)
odrd_raw, ws_odrd = load_odrd_data(os.path.join(DATA_DIR, 'odrd_data_sig_on_best_20231018.mat'))
odrd_ids, odrd_age, odrd_mat, _ = extract_metadata(ws_odrd, odrd_raw.shape[0])

# Pooling across tasks requires every task to share the same condition axis
# (axis 1). NOTE(review): this only checks the count; it assumes the ODRd
# condition order matches the CARDINAL_COLS order — TODO confirm.
if odrd_raw.shape[1] != len(CARDINAL_COLS):
    raise ValueError(
        f'ODRd data has {odrd_raw.shape[1]} conditions; expected '
        f'{len(CARDINAL_COLS)} cardinal directions to match ODR'
    )

# Build task dict with cardinal-only conditions:
# task name -> {data: neurons x conditions, ids: monkey id per neuron,
#               abs_age: absolute age in months per neuron}
cardinal_data = {}
for delay, name in [(1.5, 'ODR 1.5s'), (3.0, 'ODR 3.0s')]:
    mask = delay_all == delay
    # An empty task would otherwise flow silently into the pooling step.
    if not np.any(mask):
        raise ValueError(f'No ODR neurons found with delay {delay} s for task {name!r}')
    cardinal_data[name] = dict(
        data=odr_all[mask][:, CARDINAL_COLS],
        ids=ids_all[mask],
        abs_age=_abs_age_months(age_all[mask], mature_all[mask]),
    )

cardinal_data['ODRd'] = dict(
    data=odrd_raw,
    ids=odrd_ids,
    abs_age=_abs_age_months(odrd_age, odrd_mat),
)

for name, td in cardinal_data.items():
    print(f'{name}: {td["data"].shape[0]} neurons, {td["data"].shape[1]} conditions, '
          f'monkeys: {sorted(set(td["ids"]))}')

## 2. Compute pooled tuning curves

In [None]:
# Pool neurons from all tasks and split them by monkey, then by age-group
# index (the inner keys are used to index AGE_GROUP_LABELS below).
grouped, epoch_names = pooled_tuning_by_group(
    cardinal_data,
    COMMON_EPOCHS,
    AGE_EDGES,
    bin_ms=BIN_MS,
)

print(f'Epochs: {epoch_names}')
print(f'Monkeys: {list(grouped)}\n')
for monkey_id in grouped:
    per_group = grouped[monkey_id]
    for group_idx in per_group:
        tuning = per_group[group_idx]
        label = AGE_GROUP_LABELS[group_idx]
        # Axis 0 of each pooled tuning array is reported as the neuron count.
        print(f'  {monkey_id} / {label}: {tuning.shape[0]} neurons, '
              f'shape {tuning.shape}')

## 3. PCA-reduce tuning curves

In [None]:
# Reduce each (monkey, age group) population to N_PCS components. MIN_NEURONS
# is passed through; how undersized groups are treated is decided by
# pca_reduce_tuning (not visible here).
reduced = pca_reduce_tuning(grouped, n_pcs=N_PCS, min_neurons=MIN_NEURONS)

for monkey_id, per_group in reduced.items():
    for group_idx, summary in per_group.items():
        header = f'{monkey_id} / {AGE_GROUP_LABELS[group_idx]}'
        print(f'{header}: {summary["n_neurons"]} neurons -> '
              f'{summary["tc"].shape} (PCs x conds x epochs), '
              f'var explained {summary["var_explained"]:.1%}')