# Longitudinal RSA Analysis: VOTC Resection Study

## Hypothesis
Bilateral visual categories (Object, House) show greater representational reorganization than unilateral categories (Face, Word) in OTC patients. The rationale is that bilateral representations are **collaborative** (not redundant) across hemispheres, so losing one hemisphere forces the remaining hemisphere to compensate.

## Contrast Scheme Justification

### Liu Distinctiveness (Selectivity Change)
- Measures how **selective** a region is for its preferred category
- Uses contrasts that **define** category selectivity (following Liu et al.):
  - FFA: Face > Object (cope 1)
  - VWFA: Word > Scramble (cope 12); Word > Face cannot be used because face signal dominates VWFA's neighborhood
  - PPA: House > Object (cope 2)
  - LOC: Object > Scramble (cope 3)
- Question: "How correlated is the preferred category's pattern with the non-preferred patterns?" This concerns the ROI's functional **identity**
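
The question above can be sketched in a few lines. This is a minimal illustration on synthetic data, not the study's pipeline: `liu_distinctiveness` and `toy_patterns` are hypothetical names, and the random vectors stand in for voxel patterns inside an ROI. Lower scores mean the preferred pattern is less correlated with the others, i.e. more distinct.

```python
import numpy as np

def liu_distinctiveness(patterns, preferred):
    """Mean Fisher-z correlation between the preferred pattern and every other pattern."""
    z_values = []
    for name, pattern in patterns.items():
        if name == preferred:
            continue
        r = np.corrcoef(patterns[preferred], pattern)[0, 1]
        # Clip before arctanh so |r| close to 1 cannot produce infinities
        z_values.append(np.arctanh(np.clip(r, -0.999, 0.999)))
    return float(np.mean(z_values))

# Synthetic stand-ins for voxel patterns (assumption: 200 voxels per pattern)
rng = np.random.default_rng(0)
shared = rng.normal(size=200)
toy_patterns = {
    'face': rng.normal(size=200),                  # unrelated to the other three
    'word': shared + 0.1 * rng.normal(size=200),
    'object': shared + 0.1 * rng.normal(size=200),
    'house': shared + 0.1 * rng.normal(size=200),
}

face_score = liu_distinctiveness(toy_patterns, 'face')
word_score = liu_distinctiveness(toy_patterns, 'word')
print(f"face: {face_score:.2f}, word: {word_score:.2f}")
```

In this toy setup the face pattern is generated independently, so its score sits near zero, while the word pattern shares a latent component with object and house and scores clearly higher.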

### RSA Measures (Geometry Preservation, MDS Shift)
- Measures representational **structure** — how categories relate to each other
- Requires comparing patterns across all four categories simultaneously
- **Must use the same baseline** for a fair RDM comparison
- All categories use Category > Scramble (Face: cope 10, Word: cope 12, Object: cope 3, House: cope 11):
  - Consistent reference point
  - Each pattern reflects category response above low-level visual baseline
  - RDM comparisons are apples-to-apples
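
A minimal sketch of the RDM comparison described above, again on synthetic patterns. The function names and the toy latent structure are assumptions made for illustration. Each row stands in for one Category > Scramble pattern; the RDM is 1 minus the Pearson correlation between rows, and two sessions are compared by correlating the off-diagonal upper triangles.

```python
import numpy as np

def rdm_from_patterns(patterns):
    """Correlation-distance RDM: 1 - Pearson r between every pair of rows."""
    return 1.0 - np.corrcoef(patterns)

def geometry_preservation(rdm_a, rdm_b):
    """Pearson r between the off-diagonal upper triangles of two RDMs."""
    iu = np.triu_indices(rdm_a.shape[0], k=1)
    return float(np.corrcoef(rdm_a[iu], rdm_b[iu])[0, 1])

rng = np.random.default_rng(1)
a, b = rng.normal(size=(2, 300))   # two latent patterns (toy data)

def toy_session(latents, noise=0.5):
    """Four noisy patterns, one per category, built from the given latents."""
    latents = np.stack(latents)
    return latents + noise * rng.normal(size=latents.shape)

rdm_t1 = rdm_from_patterns(toy_session([a, a, b, b]))         # categories pair up as (1,2) and (3,4)
rdm_same = rdm_from_patterns(toy_session([a, a, b, b]))       # same grouping, fresh noise
rdm_regrouped = rdm_from_patterns(toy_session([a, b, a, b]))  # grouping changed

stable = geometry_preservation(rdm_t1, rdm_same)
reorganized = geometry_preservation(rdm_t1, rdm_regrouped)
print(f"stable: {stable:.2f}, reorganized: {reorganized:.2f}")
```

With only four categories each RDM has six off-diagonal entries, so a single geometry-preservation value is a correlation over six points and is correspondingly noisy.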

**Key distinction:** Selectivity is about a region's *identity*; RSA is about representational *structure*.

## Cell 1: Setup & Configuration

In [1]:
import numpy as np
import nibabel as nib
from pathlib import Path
import pandas as pd
from scipy.ndimage import label, center_of_mass
from scipy.stats import pearsonr, ttest_ind, ttest_rel
from scipy.linalg import orthogonal_procrustes
import warnings
warnings.filterwarnings('ignore')

# === PATHS ===
BASE_DIR = Path("/user_data/csimmon2/long_pt")
CSV_FILE = Path('/user_data/csimmon2/git_repos/long_pt/long_pt_sub_info.csv')
OUTPUT_DIR = Path('/user_data/csimmon2/git_repos/long_pt/B_analyses')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# === SUBJECT INFO ===
df = pd.read_csv(CSV_FILE)
SESSION_START = {'sub-010': 2, 'sub-018': 2, 'sub-068': 2}
EXCLUDE_SUBJECTS = ['sub-025', 'sub-027', 'sub-045', 'sub-072']

# === CATEGORIES ===
CATEGORIES = ['face', 'word', 'object', 'house']
BILATERAL = ['object', 'house']
UNILATERAL = ['face', 'word']

# === CONTRAST SCHEMES ===

# For Liu Distinctiveness (Selectivity Change)
# Uses ROI-defining contrasts per Liu et al.
COPE_MAP_LIU = {
    'face': (1, 1),    # Face > Object
    'word': (12, 1),   # Word > Scramble
    'object': (3, 1),  # Object > Scramble
    'house': (2, 1)    # House > Object
}

# For RSA measures (Geometry, MDS, Drift)
# Consistent baseline across all categories
COPE_MAP_SCRAMBLE = {
    'face': (10, 1),   # Face > Scramble
    'word': (12, 1),   # Word > Scramble
    'object': (3, 1),  # Object > Scramble
    'house': (11, 1)   # House > Scramble
}

print("✓ Configuration loaded")
print(f"  Excluding: {EXCLUDE_SUBJECTS}")
print(f"  Liu Distinctiveness: COPE_MAP_LIU")
print(f"  RSA Measures: COPE_MAP_SCRAMBLE")

✓ Configuration loaded
  Excluding: ['sub-025', 'sub-027', 'sub-045', 'sub-072']
  Liu Distinctiveness: COPE_MAP_LIU
  RSA Measures: COPE_MAP_SCRAMBLE


## Cell 2: Load Subjects

In [2]:
def load_subjects():
    """Load subjects from the CSV, skipping EXCLUDE_SUBJECTS, missing directories, and subjects with fewer than two usable sessions"""
    subjects = {}
    
    for _, row in df.iterrows():
        subject_id = row['sub']
        
        if subject_id in EXCLUDE_SUBJECTS:
            continue
            
        subj_dir = BASE_DIR / subject_id
        if not subj_dir.exists():
            continue
        
        sessions = sorted(
            [d.name.replace('ses-', '') for d in subj_dir.glob('ses-*') if d.is_dir()], 
            key=int
        )
        start_session = SESSION_START.get(subject_id, 1)
        sessions = [s for s in sessions if int(s) >= start_session]
        
        if len(sessions) < 2:
            continue
        
        hemi = 'l' if row.get('intact_hemi', 'left') == 'left' else 'r'
        
        subjects[subject_id] = {
            'code': f"{row['group']}{subject_id.split('-')[1]}",
            'sessions': sessions,
            'hemi': hemi,
            'group': row['group'],
            'patient': row['patient'] == 1,
            'surgery_side': row.get('SurgerySide', 'na'),
            'sex': row.get('sex', 'na'),
            'age_1': row.get('age_1', np.nan),
            'age_2': row.get('age_2', np.nan)
        }
    
    return subjects

SUBJECTS = load_subjects()

print(f"✓ Loaded {len(SUBJECTS)} subjects (after exclusions)")
for group in ['OTC', 'nonOTC', 'control']:
    n = sum(1 for s in SUBJECTS.values() if s['group'] == group)
    print(f"  {group}: {n}")

✓ Loaded 20 subjects (after exclusions)
  OTC: 6
  nonOTC: 7
  control: 7


## Cell 3: Helper Functions

In [3]:
def create_sphere(center_coord, affine, brain_shape, radius=6):
    """Create spherical mask around coordinate"""
    grid = np.array(np.meshgrid(
        np.arange(brain_shape[0]),
        np.arange(brain_shape[1]),
        np.arange(brain_shape[2]),
        indexing='ij'
    )).reshape(3, -1).T
    
    world = nib.affines.apply_affine(affine, grid)
    distances = np.linalg.norm(world - center_coord, axis=1)
    
    # Grid rows follow C order (indexing='ij'), so the distance test can be
    # reshaped straight into the volume instead of looping over voxels.
    mask = (distances <= radius).reshape(brain_shape)
    
    return mask


def filter_to_intact_hemisphere(df_results):
    """Filter results to intact hemisphere for patients, keep both for controls"""
    filtered = []
    for _, row in df_results.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    return pd.DataFrame(filtered)


print("✓ Helper functions defined")

✓ Helper functions defined
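
The sphere helper measures distance in world space (mm), so the number of voxels it selects depends on voxel size. The sketch below is a NumPy-only illustration of that point. `sphere_mask` is a hypothetical stand-in that applies the affine by hand (assumed equivalent to `nib.affines.apply_affine` for a 4x4 affine), and the affines and volume shape are toy values.

```python
import numpy as np

def sphere_mask(center_mm, affine, shape, radius_mm=6.0):
    """Boolean mask of voxels whose world-space position lies within radius_mm of center_mm."""
    ijk = np.indices(shape).reshape(3, -1).T          # (n_voxels, 3) voxel indices, C order
    world = ijk @ affine[:3, :3].T + affine[:3, 3]    # voxel indices -> world coordinates (mm)
    dist = np.linalg.norm(world - np.asarray(center_mm, dtype=float), axis=1)
    return (dist <= radius_mm).reshape(shape)

shape = (9, 9, 9)

# 1 mm isotropic voxels: world coordinates equal voxel indices
affine_1mm = np.eye(4)
mask_1mm = sphere_mask((4, 4, 4), affine_1mm, shape, radius_mm=2)

# 2 mm isotropic voxels: the same central voxel sits at (8, 8, 8) mm
affine_2mm = np.diag([2.0, 2.0, 2.0, 1.0])
mask_2mm = sphere_mask((8, 8, 8), affine_2mm, shape, radius_mm=2)

print(mask_1mm.sum(), mask_2mm.sum())
```

On the 1 mm grid the 2 mm sphere covers 33 voxels; on the 2 mm grid it covers only 7 (the centre and its six face neighbours). Pattern correlations computed in spheres are therefore based on different voxel counts whenever resolution differs between datasets.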


## Cell 4: ROI Extraction Function

In [4]:
def extract_rois(cope_map, threshold_z=2.3, min_voxels=20):
    """Extract ROIs for all subjects: the largest cluster with z > threshold_z inside each category search mask"""
    
    all_rois = {}
    
    for sid, info in SUBJECTS.items():
        first_ses = info['sessions'][0]
        roi_dir = BASE_DIR / sid / f'ses-{first_ses}' / 'ROIs'
        
        if not roi_dir.exists():
            continue
        
        all_rois[sid] = {}
        
        # For controls, extract both hemispheres
        hemis = ['l', 'r'] if info['group'] == 'control' else [info['hemi']]
        
        for hemi in hemis:
            for category in CATEGORIES:
                cope_num, mult = cope_map[category]
                
                # Load search mask
                mask_file = roi_dir / f'{hemi}_{category}_searchmask.nii.gz'
                if not mask_file.exists():
                    continue
                
                try:
                    mask_img = nib.load(mask_file)
                    search_mask = mask_img.get_fdata() > 0
                    affine = mask_img.affine
                except Exception:  # unreadable search mask: skip this ROI
                    continue
                
                roi_key = f'{hemi}_{category}'
                all_rois[sid][roi_key] = {}
                
                for session in info['sessions']:
                    feat_dir = BASE_DIR / sid / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                    z_name = 'zstat1.nii.gz' if session == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        continue
                    
                    try:
                        z_data = nib.load(cope_file).get_fdata() * mult
                        suprathresh = (z_data > threshold_z) & search_mask
                        
                        if suprathresh.sum() < min_voxels:
                            continue
                        
                        labeled, n_clusters = label(suprathresh)
                        if n_clusters == 0:
                            continue
                        
                        # Largest cluster
                        sizes = [(labeled == i).sum() for i in range(1, n_clusters + 1)]
                        best_idx = np.argmax(sizes) + 1
                        roi_mask = (labeled == best_idx)
                        
                        if roi_mask.sum() < min_voxels:
                            continue
                        
                        peak_idx = np.unravel_index(np.argmax(z_data * roi_mask), z_data.shape)
                        
                        all_rois[sid][roi_key][session] = {
                            'n_voxels': int(roi_mask.sum()),
                            'peak_z': z_data[peak_idx],
                            'centroid': nib.affines.apply_affine(affine, center_of_mass(roi_mask)),
                            'peak_coord': nib.affines.apply_affine(affine, peak_idx),
                            'roi_mask': roi_mask,
                            'affine': affine,
                            'shape': z_data.shape
                        }
                    except Exception:  # skip this session on any load or extraction error
                        continue
    
    return all_rois

print("✓ ROI extraction function defined")

✓ ROI extraction function defined
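
The cluster-selection step inside `extract_rois` can be isolated as a small sketch on a toy volume. `largest_cluster` is a hypothetical helper written for this illustration, and the z-map is synthetic. It thresholds inside a search mask, labels connected components with `scipy.ndimage.label`, and keeps the largest component if it meets the minimum size.

```python
import numpy as np
from scipy.ndimage import label

def largest_cluster(z_map, search_mask, threshold_z=2.3, min_voxels=20):
    """Return the largest suprathreshold cluster inside search_mask, or None if too small."""
    supra = (z_map > threshold_z) & search_mask
    labeled, n_clusters = label(supra)            # default: face-connected components
    if n_clusters == 0:
        return None
    sizes = np.bincount(labeled.ravel())[1:]      # drop the background count (label 0)
    best = int(np.argmax(sizes)) + 1
    roi = labeled == best
    return roi if roi.sum() >= min_voxels else None

# Toy z-map with two separate blobs
z = np.zeros((10, 10, 10))
z[1:4, 1:4, 1:4] = 3.0     # 27 voxels, moderate z
z[6:8, 6:8, 6:8] = 5.0     # 8 voxels, higher z
mask = np.ones(z.shape, dtype=bool)

roi = largest_cluster(z, mask)
strict = largest_cluster(z, mask, threshold_z=4.0)
print(None if roi is None else int(roi.sum()), strict)
```

Note that selection is by cluster size, not peak height: the 27-voxel blob wins even though the smaller blob has the higher z-values, and raising the threshold to 4.0 leaves only the 8-voxel blob, which fails the `min_voxels` check.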


## Cell 5: Liu Distinctiveness (Selectivity Change)

In [5]:
def compute_selectivity_change(rois, pattern_cope_map):
    """
    Selectivity Change (Liu Distinctiveness):
    - Mean Fisher-z correlation between the preferred category's pattern and
      each non-preferred category's pattern (6 mm sphere at the ROI centroid)
    - Absolute change between the first and last available session
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys(), key=int)  # numeric order, as in load_subjects
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            distinctiveness = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius=6)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = {}
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns[cat] = pattern
                
                if valid and len(patterns) == 4:
                    pref_pattern = patterns[category]
                    nonpref_corrs = []
                    for other_cat in CATEGORIES:
                        if other_cat != category:
                            r, _ = pearsonr(pref_pattern, patterns[other_cat])
                            nonpref_corrs.append(np.arctanh(np.clip(r, -0.999, 0.999)))
                    
                    distinctiveness[ses] = np.mean(nonpref_corrs)
            
            if len(distinctiveness) == 2:
                change = abs(distinctiveness[sessions[-1]] - distinctiveness[sessions[0]])
                results.append({
                    'subject': sid,
                    'code': info['code'],
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'selectivity_change': change
                })
    
    return pd.DataFrame(results)

print("✓ Selectivity change function defined")

✓ Selectivity change function defined


## Cell 6: RSA Measures (Geometry Preservation, MDS Shift)

In [6]:
def compute_geometry_preservation(rois, pattern_cope_map, radius=6):
    """
    Geometry Preservation: RDM stability across sessions
    - Extract patterns from sphere at each session's centroid
    - Correlate T1 and T2 RDMs
    - Higher = more stable; lower in bilateral = MORE reorganization
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys(), key=int)  # numeric order, as in load_subjects
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            rdms = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = []
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns.append(pattern)
                
                if valid and len(patterns) == 4:
                    corr_matrix = np.corrcoef(patterns)
                    rdm = 1 - corr_matrix
                    rdms[ses] = rdm
            
            if len(rdms) == 2:
                triu = np.triu_indices(4, k=1)
                r, _ = pearsonr(rdms[sessions[0]][triu], rdms[sessions[-1]][triu])
                
                results.append({
                    'subject': sid,
                    'code': info['code'],
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'geometry_preservation': r
                })
    
    return pd.DataFrame(results)


def compute_mds_shift(rois, pattern_cope_map, radius=6):
    """
    MDS Shift: Procrustes-aligned embedding distance
    - MDS embed RDMs to 2D
    - Align with Procrustes
    - Measure movement of each category
    """
    def mds_2d(rdm):
        n = rdm.shape[0]
        H = np.eye(n) - np.ones((n, n)) / n
        B = -0.5 * H @ (rdm ** 2) @ H
        eigvals, eigvecs = np.linalg.eigh(B)
        idx = np.argsort(eigvals)[::-1]
        coords = eigvecs[:, idx[:2]] * np.sqrt(np.maximum(eigvals[idx[:2]], 0))
        return coords
    
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys(), key=int)  # numeric order, as in load_subjects
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            roi_category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            rdms = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = []
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns.append(pattern)
                
                if valid and len(patterns) == 4:
                    corr_matrix = np.corrcoef(patterns)
                    rdm = 1 - corr_matrix
                    rdms[ses] = rdm
            
            if len(rdms) == 2:
                try:
                    coords_t1 = mds_2d(rdms[sessions[0]])
                    coords_t2 = mds_2d(rdms[sessions[-1]])
                    
                    R, _ = orthogonal_procrustes(coords_t1, coords_t2)
                    coords_t1_aligned = coords_t1 @ R
                    
                    for i, cat in enumerate(CATEGORIES):
                        dist = np.linalg.norm(coords_t1_aligned[i] - coords_t2[i])
                        results.append({
                            'subject': sid,
                            'code': info['code'],
                            'group': info['group'],
                            'hemi': hemi,
                            'roi_category': roi_category,
                            'measured_category': cat,
                            'mds_shift': dist
                        })
                except Exception:  # e.g. non-finite RDM entries: skip this ROI
                    continue
    
    return pd.DataFrame(results)


def compute_spatial_drift(rois):
    """
    Spatial Drift: Euclidean distance (mm) between the first- and last-session ROI centroids
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys(), key=int)  # numeric order, as in load_subjects
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            c1 = sessions_data[sessions[0]]['centroid']
            c2 = sessions_data[sessions[-1]]['centroid']
            drift = np.linalg.norm(np.array(c2) - np.array(c1))
            
            results.append({
                'subject': sid,
                'code': info['code'],
                'group': info['group'],
                'hemi': hemi,
                'category': category,
                'spatial_drift_mm': drift,
                't1_peak_z': sessions_data[sessions[0]]['peak_z']
            })
    
    return pd.DataFrame(results)

print("✓ RSA metric functions defined")

✓ RSA metric functions defined
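
The MDS shift measure combines classical MDS with an orthogonal Procrustes alignment. The sketch below reproduces those two steps on a toy 2D configuration, to show what the alignment does and does not remove. The helper names and the point coordinates are assumptions for illustration; only `scipy.linalg.orthogonal_procrustes` is taken from the notebook's imports.

```python
import numpy as np
from scipy.linalg import orthogonal_procrustes

def classical_mds(dissim, n_dims=2):
    """Classical (Torgerson) MDS via eigendecomposition of the double-centred squared dissimilarities."""
    n = dissim.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ (dissim ** 2) @ centering
    eigvals, eigvecs = np.linalg.eigh(gram)
    top = np.argsort(eigvals)[::-1][:n_dims]
    # Negative eigenvalues (possible for non-Euclidean dissimilarities) are clipped to 0
    return eigvecs[:, top] * np.sqrt(np.maximum(eigvals[top], 0))

def pairwise_distances(points):
    return np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)

def aligned_shifts(dissim_t1, dissim_t2):
    """Per-item displacement after rotating/reflecting the T1 embedding onto the T2 embedding."""
    x1, x2 = classical_mds(dissim_t1), classical_mds(dissim_t2)
    rotation, _ = orthogonal_procrustes(x1, x2)
    return np.linalg.norm(x1 @ rotation - x2, axis=1)

points = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 2.0], [0.0, 3.0]])
theta = np.deg2rad(40)
rot = np.array([[np.cos(theta), -np.sin(theta)],
                [np.sin(theta),  np.cos(theta)]])

moved = points.copy()
moved[3] += [2.0, 0.0]     # displace one item only

shift_rotated = aligned_shifts(pairwise_distances(points), pairwise_distances(points @ rot.T))
shift_moved = aligned_shifts(pairwise_distances(points), pairwise_distances(moved))
print(np.round(shift_rotated, 6), np.round(shift_moved, 3))
```

A pure rotation yields shifts of essentially zero, as intended. When a single item moves, the alignment spreads the residual across items, so per-category shifts are not independent of one another. The alignment also does not rescale, so a global change in dissimilarity magnitude between sessions would register as shift.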


## Cell 7: Extract ROIs and Compute All Measures

In [7]:
print("="*70)
print("EXTRACTING ROIs")
print("="*70)

# Liu ROIs for Selectivity Change
print("\nExtracting Liu ROIs (for Selectivity Change)...")
rois_liu = extract_rois(COPE_MAP_LIU)
n_liu = sum(len([k for k, v in roi_data.items() if len(v) >= 2]) for roi_data in rois_liu.values())
print(f"  ✓ {len(rois_liu)} subjects, {n_liu} ROIs with 2+ sessions")

# Scramble ROIs for RSA measures
print("\nExtracting Scramble ROIs (for RSA measures)...")
rois_scramble = extract_rois(COPE_MAP_SCRAMBLE)
n_scr = sum(len([k for k, v in roi_data.items() if len(v) >= 2]) for roi_data in rois_scramble.values())
print(f"  ✓ {len(rois_scramble)} subjects, {n_scr} ROIs with 2+ sessions")

print("\n" + "="*70)
print("COMPUTING MEASURES")
print("="*70)

# Selectivity Change (using Liu ROIs and Liu patterns)
print("\nComputing Selectivity Change (Liu ROIs + Liu patterns)...")
selectivity_df = compute_selectivity_change(rois_liu, COPE_MAP_LIU)
print(f"  ✓ {len(selectivity_df)} measurements")

# Geometry Preservation (using Scramble ROIs and Scramble patterns)
print("\nComputing Geometry Preservation (Scramble ROIs + Scramble patterns)...")
geometry_df = compute_geometry_preservation(rois_scramble, COPE_MAP_SCRAMBLE)
print(f"  ✓ {len(geometry_df)} measurements")

# MDS Shift (using Scramble ROIs and Scramble patterns)
print("\nComputing MDS Shift (Scramble ROIs + Scramble patterns)...")
mds_df = compute_mds_shift(rois_scramble, COPE_MAP_SCRAMBLE)
print(f"  ✓ {len(mds_df)} measurements")

# Spatial Drift (using Scramble ROIs)
print("\nComputing Spatial Drift (Scramble ROIs)...")
drift_df = compute_spatial_drift(rois_scramble)
print(f"  ✓ {len(drift_df)} measurements")

EXTRACTING ROIs

Extracting Liu ROIs (for Selectivity Change)...
  ✓ 20 subjects, 107 ROIs with 2+ sessions

Extracting Scramble ROIs (for RSA measures)...
  ✓ 20 subjects, 107 ROIs with 2+ sessions

COMPUTING MEASURES

Computing Selectivity Change (Liu ROIs + Liu patterns)...
  ✓ 107 measurements

Computing Geometry Preservation (Scramble ROIs + Scramble patterns)...
  ✓ 107 measurements

Computing MDS Shift (Scramble ROIs + Scramble patterns)...
  ✓ 428 measurements

Computing Spatial Drift (Scramble ROIs)...
  ✓ 107 measurements


## Cell 8: Statistical Tests - Bilateral vs Unilateral

In [8]:
def test_bilateral_effect(df, metric_col, metric_name, higher_means_more_change=True):
    """Test if bilateral differs from unilateral within each group"""
    
    print(f"\n{'='*70}")
    print(f"{metric_name}")
    if higher_means_more_change:
        print("(Higher = more change; expect bilateral > unilateral in OTC)")
    else:
        print("(Lower = more change; expect bilateral < unilateral in OTC)")
    print("="*70)
    
    # Filter to intact hemisphere
    filtered_df = filter_to_intact_hemisphere(df)
    filtered_df['cat_type'] = filtered_df['category'].apply(
        lambda x: 'Bilateral' if x in BILATERAL else 'Unilateral'
    )
    
    print(f"\n{'Group':<12} {'Bilateral':<18} {'Unilateral':<18} {'Diff':<10} {'t':<8} {'p':<10}")
    print("-"*75)
    
    results = []
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        bil = gd[gd['cat_type'] == 'Bilateral'][metric_col]
        uni = gd[gd['cat_type'] == 'Unilateral'][metric_col]
        
        if len(bil) > 1 and len(uni) > 1:
            # Each row is one ROI, so ROIs from the same subject enter the
            # test as independent observations.
            t, p = ttest_ind(bil, uni)
            diff = bil.mean() - uni.mean()
            sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else ''
            
            print(f"{group:<12} {bil.mean():.3f}±{bil.std():.3f}      {uni.mean():.3f}±{uni.std():.3f}      {diff:+.3f}     {t:.2f}    {p:.4f} {sig}")
            
            results.append({
                'group': group,
                'bilateral_mean': bil.mean(),
                'unilateral_mean': uni.mean(),
                'difference': diff,
                't': t,
                'p': p
            })
    
    return pd.DataFrame(results)


# Run tests
print("\n" + "#"*70)
print("BILATERAL vs UNILATERAL STATISTICAL TESTS")
print("#"*70)

selectivity_stats = test_bilateral_effect(
    selectivity_df, 'selectivity_change', 'SELECTIVITY CHANGE', 
    higher_means_more_change=True
)

geometry_stats = test_bilateral_effect(
    geometry_df, 'geometry_preservation', 'GEOMETRY PRESERVATION',
    higher_means_more_change=False
)


######################################################################
BILATERAL vs UNILATERAL STATISTICAL TESTS
######################################################################

SELECTIVITY CHANGE
(Higher = more change; expect bilateral > unilateral in OTC)

Group        Bilateral          Unilateral         Diff       t        p         
---------------------------------------------------------------------------
OTC          0.396±0.273      0.143±0.121      +0.253     2.82    0.0102 *
nonOTC       0.148±0.104      0.125±0.112      +0.023     0.56    0.5781 
control      0.233±0.179      0.150±0.121      +0.083     2.02    0.0482 *

GEOMETRY PRESERVATION
(Lower = more change; expect bilateral < unilateral in OTC)

Group        Bilateral          Unilateral         Diff       t        p         
---------------------------------------------------------------------------
OTC          -0.025±0.465      0.240±0.466      -0.265     -1.36    0.1872 
nonOTC       0.646±0.282      0.5

## Cell 9: Bootstrap Analysis

In [9]:
def bootstrap_group_comparison(df, metric_col, n_boot=10000, seed=42):
    """Bootstrap test for OTC bilateral advantage vs other groups"""
    np.random.seed(seed)
    
    # Filter to intact hemisphere
    filtered_df = filter_to_intact_hemisphere(df)
    filtered_df['cat_type'] = filtered_df['category'].apply(
        lambda x: 'Bilateral' if x in BILATERAL else 'Unilateral'
    )
    
    # Calculate subject-level bilateral advantage (gap)
    subject_gaps = {}
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        gaps = []
        for sid in gd['subject'].unique():
            sd = gd[gd['subject'] == sid]
            bil = sd[sd['cat_type'] == 'Bilateral'][metric_col].mean()
            uni = sd[sd['cat_type'] == 'Unilateral'][metric_col].mean()
            if pd.notna(bil) and pd.notna(uni):
                gaps.append(bil - uni)
        subject_gaps[group] = np.array(gaps)
    
    print(f"\n  Subject-level gaps:")
    for g, gaps in subject_gaps.items():
        print(f"    {g}: n={len(gaps)}, mean={gaps.mean():.3f}")
    
    results = []
    for comp_group in ['nonOTC', 'control']:
        g1 = subject_gaps['OTC']
        g2 = subject_gaps[comp_group]
        
        if len(g1) < 2 or len(g2) < 2:
            continue
        
        observed_diff = np.mean(g1) - np.mean(g2)
        
        boot_diffs = []
        for _ in range(n_boot):
            s1 = np.random.choice(g1, size=len(g1), replace=True)
            s2 = np.random.choice(g2, size=len(g2), replace=True)
            boot_diffs.append(np.mean(s1) - np.mean(s2))
        
        boot_diffs = np.array(boot_diffs)
        ci_low = np.percentile(boot_diffs, 2.5)
        ci_high = np.percentile(boot_diffs, 97.5)
        
        # Two-sided percentile p-value, capped at 1
        if observed_diff > 0:
            p_val = min(1.0, 2 * np.mean(boot_diffs <= 0))
        else:
            p_val = min(1.0, 2 * np.mean(boot_diffs >= 0))
        
        sig = '***' if p_val < 0.001 else '**' if p_val < 0.01 else '*' if p_val < 0.05 else ''
        print(f"\n  OTC vs {comp_group}: diff={observed_diff:.3f}, 95%CI=[{ci_low:.3f}, {ci_high:.3f}], p={p_val:.4f} {sig}")
        
        results.append({
            'comparison': f'OTC vs {comp_group}',
            'observed_diff': observed_diff,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'p_value': p_val
        })
    
    return pd.DataFrame(results)


print("\n" + "#"*70)
print("BOOTSTRAP ANALYSIS")
print("#"*70)

print("\n--- SELECTIVITY CHANGE ---")
print("(Positive diff = OTC shows MORE bilateral advantage)")
boot_selectivity = bootstrap_group_comparison(selectivity_df, 'selectivity_change')

print("\n--- GEOMETRY PRESERVATION ---")
print("(Negative diff = OTC shows MORE bilateral disadvantage = more reorganization)")
boot_geometry = bootstrap_group_comparison(geometry_df, 'geometry_preservation')


######################################################################
BOOTSTRAP ANALYSIS
######################################################################

--- SELECTIVITY CHANGE ---
(Positive diff = OTC shows MORE bilateral advantage)

  Subject-level gaps:
    OTC: n=6, mean=0.257
    nonOTC: n=7, mean=0.023
    control: n=7, mean=0.083

  OTC vs nonOTC: diff=0.234, 95%CI=[0.029, 0.454], p=0.0132 *

  OTC vs control: diff=0.174, 95%CI=[-0.034, 0.398], p=0.1142 

--- GEOMETRY PRESERVATION ---
(Negative diff = OTC shows MORE bilateral disadvantage = more reorganization)

  Subject-level gaps:
    OTC: n=6, mean=-0.291
    nonOTC: n=7, mean=0.075
    control: n=7, mean=0.051

  OTC vs nonOTC: diff=-0.366, 95%CI=[-0.816, 0.079], p=0.1104 

  OTC vs control: diff=-0.342, 95%CI=[-0.650, -0.016], p=0.0414 *
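
The resampling logic in this cell can be written compactly with a NumPy `Generator`. The version below is a sketch under stated assumptions, not a drop-in replacement: `bootstrap_mean_diff` is a hypothetical name, the gap values are invented toy numbers rather than study data, and because it draws random numbers differently from `np.random.seed` plus `np.random.choice` it would not reproduce the exact figures printed above.

```python
import numpy as np

def bootstrap_mean_diff(g1, g2, n_boot=10000, seed=42):
    """Percentile bootstrap for the difference in means of two independent groups."""
    rng = np.random.default_rng(seed)
    g1 = np.asarray(g1, dtype=float)
    g2 = np.asarray(g2, dtype=float)

    # Resample subjects with replacement within each group, all draws at once
    idx1 = rng.integers(0, len(g1), size=(n_boot, len(g1)))
    idx2 = rng.integers(0, len(g2), size=(n_boot, len(g2)))
    boot = g1[idx1].mean(axis=1) - g2[idx2].mean(axis=1)

    observed = g1.mean() - g2.mean()
    ci_low, ci_high = np.percentile(boot, [2.5, 97.5])

    # Two-sided p-value: share of resamples on the far side of zero, doubled and capped at 1
    tail = np.mean(boot <= 0) if observed > 0 else np.mean(boot >= 0)
    return float(observed), float(ci_low), float(ci_high), float(min(1.0, 2 * tail))

# Invented subject-level gaps (bilateral minus unilateral), for illustration only
toy_otc_gaps = [0.50, 0.60, 0.40, 0.55, 0.45, 0.50]
toy_other_gaps = [0.00, 0.10, -0.10, 0.05, -0.05, 0.00, 0.02]

observed, ci_low, ci_high, p_val = bootstrap_mean_diff(toy_otc_gaps, toy_other_gaps)
print(f"diff={observed:.3f}, 95% CI=[{ci_low:.3f}, {ci_high:.3f}], p={p_val:.4f}")
```

With six or seven subjects per group, each resample can take only a limited set of distinct values, so percentile intervals from samples this small are best read as approximate.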


## Cell 10: Category-Level Results

In [10]:
print("\n" + "="*70)
print("CATEGORY-LEVEL RESULTS")
print("="*70)

def print_category_table(df, metric_col, title):
    filtered = filter_to_intact_hemisphere(df)
    
    print(f"\n--- {title} ---")
    print(f"{'Group':<12} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*55)
    
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered[filtered['group'] == group]
        vals = []
        for cat in ['face', 'word', 'object', 'house']:
            cd = gd[gd['category'] == cat][metric_col]
            if len(cd) > 0:
                vals.append(f"{cd.mean():.2f}")
            else:
                vals.append("--")
        print(f"{group:<12} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")

print_category_table(selectivity_df, 'selectivity_change', 'SELECTIVITY CHANGE (higher = more change)')
print_category_table(geometry_df, 'geometry_preservation', 'GEOMETRY PRESERVATION (lower = more change)')
print_category_table(drift_df, 'spatial_drift_mm', 'SPATIAL DRIFT (mm)')

# MDS shift needs special handling (has roi_category and measured_category)
print("\n--- MDS SHIFT (averaged across ROI locations) ---")
mds_filtered = filter_to_intact_hemisphere(
    mds_df.rename(columns={'roi_category': 'category'})
)
mds_avg = mds_filtered.groupby(['group', 'measured_category'])['mds_shift'].mean().reset_index()

print(f"{'Group':<12} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
print("-"*55)
for group in ['OTC', 'nonOTC', 'control']:
    gd = mds_avg[mds_avg['group'] == group]
    vals = []
    for cat in ['face', 'word', 'object', 'house']:
        cd = gd[gd['measured_category'] == cat]['mds_shift']
        if len(cd) > 0:
            vals.append(f"{cd.values[0]:.2f}")
        else:
            vals.append("--")
    print(f"{group:<12} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")


CATEGORY-LEVEL RESULTS

--- SELECTIVITY CHANGE (higher = more change) ---
Group        Face       Word       Object     House     
-------------------------------------------------------
OTC          0.15       0.14       0.48       0.31      
nonOTC       0.11       0.14       0.17       0.13      
control      0.18       0.12       0.20       0.27      

--- GEOMETRY PRESERVATION (lower = more change) ---
Group        Face       Word       Object     House     
-------------------------------------------------------
OTC          0.32       0.15       0.06       -0.11     
nonOTC       0.55       0.59       0.69       0.60      
control      0.64       0.30       0.55       0.49      

--- SPATIAL DRIFT (mm) ---
Group        Face       Word       Object     House     
-------------------------------------------------------
OTC          6.99       19.01      5.93       13.19     
nonOTC       3.40       8.96       2.52       5.69      
control      5.74       10.00      3.66       6.4
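
The tables above are built by looping over groups and categories. As a minimal sketch, the same group-by-category means can also be produced with `pandas.DataFrame.pivot_table`. It assumes a long-format frame with `group`, `category`, and one metric column, as `selectivity_df` appears to have. The toy data and the helper name `category_table` are hypothetical and are not part of the analysis.

```python
import pandas as pd

# Hypothetical toy data in the same long format as selectivity_df
# (columns 'group', 'category', and one metric column); values are made up.
toy_df = pd.DataFrame({
    'group':    ['OTC', 'OTC', 'OTC', 'control', 'control', 'control'],
    'category': ['face', 'face', 'object', 'face', 'object', 'object'],
    'selectivity_change': [0.10, 0.20, 0.50, 0.30, 0.10, 0.30],
})

def category_table(df, metric_col,
                   groups=('OTC', 'nonOTC', 'control'),
                   categories=('face', 'word', 'object', 'house')):
    """Return a group x category table of means; missing cells stay NaN."""
    table = df.pivot_table(index='group', columns='category',
                           values=metric_col, aggfunc='mean')
    # reindex fixes the row/column order and adds absent groups/categories as NaN
    return table.reindex(index=list(groups), columns=list(categories))

table = category_table(toy_df, 'selectivity_change')
# na_rep prints "--" for empty cells, like the loop-based table
print(table.to_string(na_rep='--', float_format='{:.2f}'.format))
```

Keeping the result as a DataFrame, rather than printing inside the loop, means the same table can be exported or compared across metrics without recomputing it.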

## Cell 11: Export Final Results

In [11]:
print("="*70)
print("EXPORTING FINAL RESULTS")
print("="*70)

# Build comprehensive export DataFrame
# Use selectivity as base (Liu ROIs)
export_data = []

selectivity_filt = filter_to_intact_hemisphere(selectivity_df)
geometry_filt = filter_to_intact_hemisphere(geometry_df)
drift_filt = filter_to_intact_hemisphere(drift_df)

for _, row in selectivity_filt.iterrows():
    sid = row['subject']
    info = SUBJECTS[sid]
    
    # Match geometry and drift (from scramble ROIs)
    geom_match = geometry_filt[
        (geometry_filt['subject'] == sid) & 
        (geometry_filt['hemi'] == row['hemi']) &
        (geometry_filt['category'] == row['category'])
    ]
    
    drift_match = drift_filt[
        (drift_filt['subject'] == sid) & 
        (drift_filt['hemi'] == row['hemi']) &
        (drift_filt['category'] == row['category'])
    ]
    
    # MDS shift (average across ROI locations for this measured category)
    mds_match = mds_df[
        (mds_df['subject'] == sid) & 
        (mds_df['hemi'] == row['hemi']) &
        (mds_df['measured_category'] == row['category'])
    ]['mds_shift'].mean()
    
    export_row = {
        'Subject': row['code'],
        'Group': info['group'],
        'Surgery_Side': info['surgery_side'],
        'Intact_Hemisphere': 'left' if info['hemi'] == 'l' else 'right',
        'Sex': info['sex'],
        'nonpt_hemi': row['hemi'].upper() if info['group'] == 'control' else 'na',
        'Category': row['category'].title(),
        'Category_Type': 'Bilateral' if row['category'] in BILATERAL else 'Unilateral',
        'age_1': info['age_1'],
        'age_2': info['age_2'],
        'yr_gap': info['age_2'] - info['age_1'] if pd.notna(info['age_1']) and pd.notna(info['age_2']) else np.nan,
        'Selectivity_Change': row['selectivity_change'],
        'Spatial_Relocation_mm': drift_match['spatial_drift_mm'].values[0] if len(drift_match) > 0 else np.nan,
        'Geometry_Preservation_6mm': geom_match['geometry_preservation'].values[0] if len(geom_match) > 0 else np.nan,
        'MDS_Shift': mds_match if pd.notna(mds_match) else np.nan
    }
    
    export_data.append(export_row)

export_df = pd.DataFrame(export_data)

# Save
output_file = OUTPUT_DIR / 'results_final_corrected.csv'
export_df.to_csv(output_file, index=False)

print(f"\n✓ Saved to: {output_file}")
print(f"  Shape: {export_df.shape}")
print(f"\nColumns: {list(export_df.columns)}")

# Summary
print("\n" + "-"*70)
print("FINAL SUMMARY")
print("-"*70)
print("\nSubjects per group:")
print(export_df.groupby('Group')['Subject'].nunique())
print("\nMeasurements per category:")
print(export_df.groupby('Category').size())

EXPORTING FINAL RESULTS

✓ Saved to: /user_data/csimmon2/git_repos/long_pt/B_analyses/results_final_corrected.csv
  Shape: (107, 15)

Columns: ['Subject', 'Group', 'Surgery_Side', 'Intact_Hemisphere', 'Sex', 'nonpt_hemi', 'Category', 'Category_Type', 'age_1', 'age_2', 'yr_gap', 'Selectivity_Change', 'Spatial_Relocation_mm', 'Geometry_Preservation_6mm', 'MDS_Shift']

----------------------------------------------------------------------
FINAL SUMMARY
----------------------------------------------------------------------

Subjects per group:
Group
OTC        6
control    7
nonOTC     7
Name: Subject, dtype: int64

Measurements per category:
Category
Face      27
House     27
Object    27
Word      26
dtype: int64
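
The export loop in Cell 11 aligns geometry, drift, and MDS values to each selectivity row by filtering on subject, hemisphere, and category. A left merge on those keys expresses the same alignment more compactly. The sketch below assumes each (subject, hemi, category) triple appears at most once per table, which `validate='one_to_one'` checks. The toy frames are hypothetical and their values are made up.

```python
import pandas as pd

KEYS = ['subject', 'hemi', 'category']

# Hypothetical toy tables mirroring the column names used in Cell 11.
sel = pd.DataFrame({
    'subject': ['s1', 's1', 's2'],
    'hemi': ['l', 'l', 'r'],
    'category': ['face', 'object', 'face'],
    'selectivity_change': [0.1, 0.4, 0.2],
})
geom = pd.DataFrame({
    'subject': ['s1', 's2'],
    'hemi': ['l', 'r'],
    'category': ['face', 'face'],
    'geometry_preservation': [0.6, 0.3],
})
mds = pd.DataFrame({
    'subject': ['s1', 's1', 's2'],
    'hemi': ['l', 'l', 'r'],
    'roi_category': ['face', 'object', 'face'],
    'measured_category': ['face', 'face', 'face'],
    'mds_shift': [1.0, 3.0, 2.0],
})

# Average MDS shift across ROI locations for each measured category
mds_mean = (mds.groupby(['subject', 'hemi', 'measured_category'],
                        as_index=False)['mds_shift']
               .mean()
               .rename(columns={'measured_category': 'category'}))

# Left merges keep every selectivity row; unmatched metrics become NaN
merged = (sel
          .merge(geom, on=KEYS, how='left', validate='one_to_one')
          .merge(mds_mean, on=KEYS, how='left', validate='one_to_one'))
print(merged)
```

If a key were duplicated in any table, the merge would raise an error instead of silently taking the first match, which is what `.values[0]` does in the loop.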


## Cell 12: Summary Statistics

In [12]:
print("="*70)
print("FINAL SUMMARY STATISTICS")
print("="*70)

print("\n--- BY GROUP AND CATEGORY TYPE ---")
summary = export_df.groupby(['Group', 'Category_Type']).agg({
    'Selectivity_Change': ['mean', 'std', 'count'],
    'Geometry_Preservation_6mm': ['mean', 'std'],
    'Spatial_Relocation_mm': ['mean', 'std'],
    'MDS_Shift': ['mean', 'std']
}).round(3)

print(summary)

print("\n" + "="*70)
print("KEY FINDINGS")
print("="*70)

# Extract OTC stats
otc_bil = export_df[(export_df['Group'] == 'OTC') & (export_df['Category_Type'] == 'Bilateral')]
otc_uni = export_df[(export_df['Group'] == 'OTC') & (export_df['Category_Type'] == 'Unilateral')]

# NOTE: ttest_ind treats every row as an independent sample, although each OTC
# subject contributes rows to both category types. Read the p-values with that in mind.
print("\nOTC Selectivity Change:")
print(f"  Bilateral: {otc_bil['Selectivity_Change'].mean():.3f} ± {otc_bil['Selectivity_Change'].std():.3f}")
print(f"  Unilateral: {otc_uni['Selectivity_Change'].mean():.3f} ± {otc_uni['Selectivity_Change'].std():.3f}")
# dropna() keeps a stray NaN from turning the test result into NaN
t, p = ttest_ind(otc_bil['Selectivity_Change'].dropna(), otc_uni['Selectivity_Change'].dropna())
print(f"  Bil - Uni = {otc_bil['Selectivity_Change'].mean() - otc_uni['Selectivity_Change'].mean():.3f}, p = {p:.4f}")

print("\nOTC Geometry Preservation:")
print(f"  Bilateral: {otc_bil['Geometry_Preservation_6mm'].mean():.3f} ± {otc_bil['Geometry_Preservation_6mm'].std():.3f}")
print(f"  Unilateral: {otc_uni['Geometry_Preservation_6mm'].mean():.3f} ± {otc_uni['Geometry_Preservation_6mm'].std():.3f}")
t, p = ttest_ind(otc_bil['Geometry_Preservation_6mm'].dropna(), otc_uni['Geometry_Preservation_6mm'].dropna())
print(f"  Bil - Uni = {otc_bil['Geometry_Preservation_6mm'].mean() - otc_uni['Geometry_Preservation_6mm'].mean():.3f}, p = {p:.4f}")

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)
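
The OTC comparison in this cell uses `ttest_ind`, which treats every measurement as independent even though each subject contributes both bilateral and unilateral rows. One possible alternative, offered as an assumption rather than as the analysis reported here, is to average within subject and category type and run a paired test with `scipy.stats.ttest_rel` (already imported in Cell 1). The toy data and the helper name `paired_bil_vs_uni` are hypothetical.

```python
import pandas as pd
from scipy.stats import ttest_rel

# Hypothetical toy data shaped like export_df for a single group; values are made up.
toy = pd.DataFrame({
    'Subject': ['p1'] * 4 + ['p2'] * 4 + ['p3'] * 4,
    'Category_Type': ['Bilateral', 'Bilateral', 'Unilateral', 'Unilateral'] * 3,
    'Selectivity_Change': [0.5, 0.3, 0.1, 0.2,
                           0.4, 0.6, 0.2, 0.1,
                           0.3, 0.2, 0.1, 0.1],
})

def paired_bil_vs_uni(df, metric_col):
    """Paired t-test on per-subject means (Bilateral vs Unilateral)."""
    per_subject = (df.groupby(['Subject', 'Category_Type'])[metric_col]
                     .mean()                    # one value per subject and type
                     .unstack('Category_Type')  # columns: Bilateral, Unilateral
                     .dropna())                 # keep subjects that have both types
    t, p = ttest_rel(per_subject['Bilateral'], per_subject['Unilateral'])
    return per_subject, t, p

per_subject, t, p = paired_bil_vs_uni(toy, 'Selectivity_Change')
print(per_subject)
print(f"paired t = {t:.3f}, p = {p:.4f}, n = {len(per_subject)}")
```

With only six OTC subjects the paired test has few degrees of freedom, so a permutation test or a mixed-effects model may be more suitable. This sketch only shows how the data would be reshaped.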

FINAL SUMMARY STATISTICS

--- BY GROUP AND CATEGORY TYPE ---
                      Selectivity_Change               \
                                    mean    std count   
Group   Category_Type                                   
OTC     Bilateral                  0.396  0.273    12   
        Unilateral                 0.143  0.121    11   
control Bilateral                  0.233  0.179    28   
        Unilateral                 0.150  0.121    28   
nonOTC  Bilateral                  0.148  0.104    14   
        Unilateral                 0.125  0.112    14   

                      Geometry_Preservation_6mm        Spatial_Relocation_mm  \
                                           mean    std                  mean   
Group   Category_Type                                                          
OTC     Bilateral                        -0.025  0.465                 9.556   
        Unilateral                        0.240  0.466                12.457   
control Bilateral        