In [2]:
# %% [markdown]
# # Longitudinal RSA Analysis: Contrast Scheme Comparison
# 
# **Study:** VOTC Resection - Bilateral vs Unilateral Visual Category Reorganization
# 
# **Hypothesis:** Bilateral categories (Object, House) show greater representational 
# reorganization than unilateral categories (Face, Word) in OTC patients because 
# the hemispheres work collaboratively (not redundantly), so losing one forces compensation.
# 
# **This notebook:**
# 1. Compares 4 contrast schemes for ROI localization and pattern extraction
# 2. Tests hybrid approach (Liu for localization, Scramble for patterns)
# 3. Computes all RSA measures with each approach
# 4. Determines optimal contrast scheme based on results

# %% [markdown]
# ## Cell 1: Setup & Configuration

# %%
import numpy as np
import nibabel as nib
from pathlib import Path
import pandas as pd
from scipy.ndimage import label, center_of_mass
from scipy.stats import pearsonr, ttest_ind, ttest_rel
from scipy.spatial.distance import squareform
from scipy.linalg import orthogonal_procrustes
import warnings
warnings.filterwarnings('ignore')

# Paths
BASE_DIR = Path("/user_data/csimmon2/long_pt")
CSV_FILE = Path('/user_data/csimmon2/git_repos/long_pt/long_pt_sub_info.csv')
OUTPUT_DIR = Path('/user_data/csimmon2/git_repos/long_pt/B_analyses')

# Load subject info
df = pd.read_csv(CSV_FILE)

# Session overrides
SESSION_START = {'sub-010': 2, 'sub-018': 2, 'sub-068': 2}

# Subjects to exclude
EXCLUDE_SUBJECTS = ['sub-025', 'sub-027', 'sub-045', 'sub-072']

# Categories
CATEGORIES = ['face', 'word', 'object', 'house']
BILATERAL = ['object', 'house']
UNILATERAL = ['face', 'word']

# ============================================================
# CONTRAST SCHEME DEFINITIONS
# ============================================================

# Scheme 1: Liu's original mixed contrasts (for ROI localization)
COPE_LIU_MIXED = {
    'face': (1, 1),    # Face > Object
    'word': (4, 1),    # Word > Object  
    'object': (3, 1),  # Object > Scramble
    'house': (2, 1)    # House > Object
}

# Scheme 2: All vs Scramble (consistent baseline)
COPE_ALL_SCRAMBLE = {
    'face': (10, 1),   # Face > Scramble
    'word': (12, 1),   # Word > Scramble
    'object': (3, 1),  # Object > Scramble
    'house': (11, 1)   # House > Scramble
}

# Scheme 3: All vs Others (category vs mean of others)
COPE_ALL_VS_OTHERS = {
    'face': (6, 1),    # Face > mean(House+Object+Word+Scramble)
    'word': (9, 1),    # Word > mean(Face+House+Object+Scramble)
    'object': (8, 1),  # Object > mean(Face+House+Word+Scramble)
    'house': (7, 1)    # House > mean(Face+Object+Word+Scramble)
}

# Scheme 4: Current differential (what Script 1 used - problematic)
COPE_DIFFERENTIAL = {
    'face': (10, 1),   # Face > Scramble
    'word': (13, -1),  # Face > Word, flipped to Word > Face
    'object': (3, 1),  # Object > Scramble
    'house': (11, 1)   # House > Scramble
}

CONTRAST_SCHEMES = {
    'liu_mixed': COPE_LIU_MIXED,
    'all_scramble': COPE_ALL_SCRAMBLE,
    'all_vs_others': COPE_ALL_VS_OTHERS,
    'current_differential': COPE_DIFFERENTIAL
}

print("✓ Configuration loaded")
print(f"  Excluding: {EXCLUDE_SUBJECTS}")

✓ Configuration loaded
  Excluding: ['sub-025', 'sub-027', 'sub-045', 'sub-072']


In [3]:
# %% [markdown]
# ## Cell 2: Load Subjects

# %%
def load_subjects():
    """Load all subjects from CSV, excluding problematic ones"""
    subjects = {}
    
    for _, row in df.iterrows():
        subject_id = row['sub']
        
        if subject_id in EXCLUDE_SUBJECTS:
            continue
            
        subj_dir = BASE_DIR / subject_id
        if not subj_dir.exists():
            continue
        
        sessions = sorted([d.name.replace('ses-', '') for d in subj_dir.glob('ses-*') if d.is_dir()], key=int)
        start_session = SESSION_START.get(subject_id, 1)
        sessions = [s for s in sessions if int(s) >= start_session]
        
        if len(sessions) < 2:
            continue
        
        hemi = 'l' if row.get('intact_hemi', 'left') == 'left' else 'r'
        
        subjects[subject_id] = {
            'code': f"{row['group']}{subject_id.split('-')[1]}",
            'sessions': sessions,
            'hemi': hemi,
            'group': row['group'],
            'patient': row['patient'] == 1,
            'surgery_side': row.get('SurgerySide', None),
            'sex': row.get('sex', None),
            'age_1': row.get('age_1', None),
            'age_2': row.get('age_2', None)
        }
    
    return subjects

SUBJECTS = load_subjects()

# Summary
print(f"✓ Loaded {len(SUBJECTS)} subjects (after exclusions)")
for group in ['OTC', 'nonOTC', 'control']:
    n = sum(1 for s in SUBJECTS.values() if s['group'] == group)
    print(f"  {group}: {n}")

# %% [markdown]

✓ Loaded 20 subjects (after exclusions)
  OTC: 6
  nonOTC: 7
  control: 7


In [4]:
# ## Cell 3: ROI Extraction Functions

# %%
def create_sphere(center_coord, affine, brain_shape, radius=6):
    """Create spherical mask around coordinate"""
    grid = np.array(np.meshgrid(
        np.arange(brain_shape[0]),
        np.arange(brain_shape[1]),
        np.arange(brain_shape[2]),
        indexing='ij'
    )).reshape(3, -1).T
    
    world = nib.affines.apply_affine(affine, grid)
    distances = np.linalg.norm(world - center_coord, axis=1)
    
    mask = np.zeros(brain_shape, dtype=bool)
    within = grid[distances <= radius]
    for c in within:
        mask[c[0], c[1], c[2]] = True
    
    return mask


def extract_rois_single_scheme(cope_map, threshold_z=2.3, min_voxels=20):
    """Extract ROIs for all subjects using a single contrast scheme"""
    
    all_rois = {}
    
    for sid, info in SUBJECTS.items():
        first_ses = info['sessions'][0]
        roi_dir = BASE_DIR / sid / f'ses-{first_ses}' / 'ROIs'
        
        if not roi_dir.exists():
            continue
        
        all_rois[sid] = {}
        
        # For controls, extract both hemispheres
        hemis = ['l', 'r'] if info['group'] == 'control' else [info['hemi']]
        
        for hemi in hemis:
            for category in CATEGORIES:
                cope_num, mult = cope_map[category]
                
                # Load search mask
                mask_file = roi_dir / f'{hemi}_{category}_searchmask.nii.gz'
                if not mask_file.exists():
                    continue
                
                try:
                    mask_img = nib.load(mask_file)
                    search_mask = mask_img.get_fdata() > 0
                    affine = mask_img.affine
                except:
                    continue
                
                roi_key = f'{hemi}_{category}'
                all_rois[sid][roi_key] = {}
                
                for session in info['sessions']:
                    feat_dir = BASE_DIR / sid / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                    z_name = 'zstat1.nii.gz' if session == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        continue
                    
                    try:
                        z_data = nib.load(cope_file).get_fdata() * mult
                        suprathresh = (z_data > threshold_z) & search_mask
                        
                        if suprathresh.sum() < min_voxels:
                            continue
                        
                        labeled, n_clusters = label(suprathresh)
                        if n_clusters == 0:
                            continue
                        
                        # Largest cluster
                        sizes = [(labeled == i).sum() for i in range(1, n_clusters + 1)]
                        best_idx = np.argmax(sizes) + 1
                        roi_mask = (labeled == best_idx)
                        
                        if roi_mask.sum() < min_voxels:
                            continue
                        
                        peak_idx = np.unravel_index(np.argmax(z_data * roi_mask), z_data.shape)
                        
                        all_rois[sid][roi_key][session] = {
                            'n_voxels': int(roi_mask.sum()),
                            'peak_z': z_data[peak_idx],
                            'centroid': nib.affines.apply_affine(affine, center_of_mass(roi_mask)),
                            'peak_coord': nib.affines.apply_affine(affine, peak_idx),
                            'roi_mask': roi_mask,
                            'affine': affine,
                            'shape': z_data.shape
                        }
                    except Exception as e:
                        continue
    
    return all_rois

print("✓ ROI extraction functions defined")

# %% [markdown]

✓ ROI extraction functions defined


In [5]:
# ## Cell 4: RSA Metric Functions

# %%
def compute_rdm(patterns):
    """Compute RDM from pattern matrix (categories x voxels)"""
    corr = np.corrcoef(patterns)
    rdm = 1 - corr
    return rdm, corr


def compute_geometry_preservation(rois, pattern_cope_map, radius=6):
    """
    Geometry Preservation: RDM stability across sessions
    - Extract patterns from sphere at each session's centroid
    - Correlate T1 and T2 RDMs
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            # Get reference for sphere creation
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            rdms = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = []
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns.append(pattern)
                
                if valid and len(patterns) == 4:
                    rdm, _ = compute_rdm(np.array(patterns))
                    rdms[ses] = rdm
            
            if len(rdms) == 2:
                triu = np.triu_indices(4, k=1)
                r, _ = pearsonr(rdms[sessions[0]][triu], rdms[sessions[-1]][triu])
                
                results.append({
                    'subject': sid,
                    'code': info['code'],
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'geometry_preservation': r
                })
    
    return pd.DataFrame(results)


def compute_mds_shift(rois, pattern_cope_map, radius=6):
    """
    MDS Shift: Procrustes-aligned embedding distance
    - MDS embed RDMs to 2D
    - Align with Procrustes
    - Measure movement of each category
    """
    def mds_2d(rdm):
        n = rdm.shape[0]
        H = np.eye(n) - np.ones((n, n)) / n
        B = -0.5 * H @ (rdm ** 2) @ H
        eigvals, eigvecs = np.linalg.eigh(B)
        idx = np.argsort(eigvals)[::-1]
        coords = eigvecs[:, idx[:2]] * np.sqrt(np.maximum(eigvals[idx[:2]], 0))
        return coords
    
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            roi_category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            rdms = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = []
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns.append(pattern)
                
                if valid and len(patterns) == 4:
                    rdm, _ = compute_rdm(np.array(patterns))
                    rdms[ses] = rdm
            
            if len(rdms) == 2:
                try:
                    coords_t1 = mds_2d(rdms[sessions[0]])
                    coords_t2 = mds_2d(rdms[sessions[-1]])
                    
                    R, _ = orthogonal_procrustes(coords_t1, coords_t2)
                    coords_t1_aligned = coords_t1 @ R
                    
                    for i, cat in enumerate(CATEGORIES):
                        dist = np.linalg.norm(coords_t1_aligned[i] - coords_t2[i])
                        results.append({
                            'subject': sid,
                            'code': info['code'],
                            'group': info['group'],
                            'hemi': hemi,
                            'roi_category': roi_category,
                            'measured_category': cat,
                            'mds_shift': dist
                        })
                except:
                    continue
    
    return pd.DataFrame(results)


def compute_spatial_drift(rois):
    """
    Spatial Drift: Euclidean distance between T1 and T2 peak centroids
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            c1 = sessions_data[sessions[0]]['centroid']
            c2 = sessions_data[sessions[-1]]['centroid']
            drift = np.linalg.norm(np.array(c2) - np.array(c1))
            
            results.append({
                'subject': sid,
                'code': info['code'],
                'group': info['group'],
                'hemi': hemi,
                'category': category,
                'spatial_drift_mm': drift,
                't1_peak_z': sessions_data[sessions[0]]['peak_z']
            })
    
    return pd.DataFrame(results)


def compute_selectivity_change(rois, pattern_cope_map):
    """
    Selectivity Change (Liu Distinctiveness):
    - Correlation of preferred category with non-preferred categories
    - Change from T1 to T2
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            distinctiveness = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius=6)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = {}
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns[cat] = pattern
                
                if valid and len(patterns) == 4:
                    # Compute correlation of preferred with non-preferred
                    pref_pattern = patterns[category]
                    nonpref_corrs = []
                    for other_cat in CATEGORIES:
                        if other_cat != category:
                            r, _ = pearsonr(pref_pattern, patterns[other_cat])
                            nonpref_corrs.append(np.arctanh(np.clip(r, -0.999, 0.999)))
                    
                    distinctiveness[ses] = np.mean(nonpref_corrs)
            
            if len(distinctiveness) == 2:
                change = abs(distinctiveness[sessions[-1]] - distinctiveness[sessions[0]])
                results.append({
                    'subject': sid,
                    'code': info['code'],
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'selectivity_change': change,
                    't1_distinctiveness': distinctiveness[sessions[0]],
                    't2_distinctiveness': distinctiveness[sessions[-1]]
                })
    
    return pd.DataFrame(results)

print("✓ RSA metric functions defined")


✓ RSA metric functions defined


In [6]:
# %% [markdown]
# ## Cell 5: Extract ROIs with All Contrast Schemes

# %%
print("="*70)
print("EXTRACTING ROIs WITH ALL CONTRAST SCHEMES")
print("="*70)

all_rois = {}
for scheme_name, cope_map in CONTRAST_SCHEMES.items():
    print(f"\n{scheme_name}...")
    all_rois[scheme_name] = extract_rois_single_scheme(cope_map)
    
    # Count ROIs
    n_rois = sum(len(roi_data) for roi_data in all_rois[scheme_name].values())
    n_subjects = len([s for s in all_rois[scheme_name] if all_rois[scheme_name][s]])
    print(f"  ✓ {n_subjects} subjects, {n_rois} ROIs")

print("\n" + "="*70)

EXTRACTING ROIs WITH ALL CONTRAST SCHEMES

liu_mixed...


KeyboardInterrupt: 

In [None]:
# %% [markdown]
# ## Cell 6: Compare ROI Yields Across Schemes

# %%
print("="*70)
print("ROI YIELD COMPARISON BY CATEGORY")
print("="*70)

# Count ROIs per category per scheme
yield_data = []
for scheme_name, rois in all_rois.items():
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        for roi_key, sessions_data in roi_data.items():
            if len(sessions_data) >= 2:  # Has both timepoints
                hemi = roi_key.split('_')[0]
                category = roi_key.split('_')[1]
                n_voxels_t1 = sessions_data[sorted(sessions_data.keys())[0]]['n_voxels']
                
                yield_data.append({
                    'scheme': scheme_name,
                    'subject': sid,
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'n_voxels': n_voxels_t1
                })

yield_df = pd.DataFrame(yield_data)

print("\nROI counts by scheme and category:")
print("-"*60)
pivot = yield_df.groupby(['scheme', 'category']).size().unstack(fill_value=0)
print(pivot)

print("\n\nMean voxel count by scheme and category:")
print("-"*60)
pivot_vox = yield_df.groupby(['scheme', 'category'])['n_voxels'].mean().unstack()
print(pivot_vox.round(0))

# Word ROI problem visualization
print("\n\nWORD ROI VOXEL COUNTS (the problematic category):")
print("-"*60)
word_data = yield_df[yield_df['category'] == 'word']
for scheme in CONTRAST_SCHEMES.keys():
    scheme_word = word_data[word_data['scheme'] == scheme]
    print(f"\n{scheme}:")
    print(f"  N ROIs: {len(scheme_word)}")
    print(f"  Mean voxels: {scheme_word['n_voxels'].mean():.0f}")
    print(f"  Min voxels: {scheme_word['n_voxels'].min()}")
    print(f"  Max voxels: {scheme_word['n_voxels'].max()}")

# %% [markdown]
# ## Cell 7: Compute All Metrics - Standard Approach (Same scheme for localization and patterns)

# %%
print("="*70)
print("COMPUTING METRICS: STANDARD APPROACH")
print("(Same contrast scheme for ROI localization AND pattern extraction)")
print("="*70)

standard_results = {}

for scheme_name in CONTRAST_SCHEMES.keys():
    print(f"\n{scheme_name}...")
    
    rois = all_rois[scheme_name]
    cope_map = CONTRAST_SCHEMES[scheme_name]
    
    # Compute all metrics
    geom = compute_geometry_preservation(rois, cope_map)
    mds = compute_mds_shift(rois, cope_map)
    drift = compute_spatial_drift(rois)
    select = compute_selectivity_change(rois, cope_map)
    
    standard_results[scheme_name] = {
        'geometry': geom,
        'mds': mds,
        'drift': drift,
        'selectivity': select
    }
    
    print(f"  Geometry: {len(geom)} ROIs")
    print(f"  MDS: {len(mds)} measurements")
    print(f"  Drift: {len(drift)} ROIs")
    print(f"  Selectivity: {len(select)} ROIs")

# %% [markdown]
# ## Cell 8: Compute Metrics - Hybrid Approach

# %%
print("="*70)
print("COMPUTING METRICS: HYBRID APPROACH")
print("(Liu mixed for localization, All Scramble for patterns)")
print("="*70)

# Use Liu mixed ROIs, but All Scramble patterns
hybrid_rois = all_rois['liu_mixed']
pattern_cope = COPE_ALL_SCRAMBLE

print("\nUsing ROIs from: liu_mixed")
print("Using patterns from: all_scramble")

hybrid_results = {
    'geometry': compute_geometry_preservation(hybrid_rois, pattern_cope),
    'mds': compute_mds_shift(hybrid_rois, pattern_cope),
    'drift': compute_spatial_drift(hybrid_rois),
    'selectivity': compute_selectivity_change(hybrid_rois, pattern_cope)
}

print(f"\n✓ Geometry: {len(hybrid_results['geometry'])} ROIs")
print(f"✓ MDS: {len(hybrid_results['mds'])} measurements")
print(f"✓ Drift: {len(hybrid_results['drift'])} ROIs")
print(f"✓ Selectivity: {len(hybrid_results['selectivity'])} ROIs")

# %% [markdown]
# ## Cell 9: Compare Results Across Approaches

# %%
def summarize_by_category(df, metric_col, groups=['OTC', 'nonOTC', 'control']):
    """Summarize metric by group and category"""
    summary = []
    for group in groups:
        group_data = df[df['group'] == group]
        for cat in CATEGORIES:
            cat_data = group_data[group_data['category'] == cat][metric_col]
            if len(cat_data) > 0:
                summary.append({
                    'group': group,
                    'category': cat,
                    'mean': cat_data.mean(),
                    'std': cat_data.std(),
                    'n': len(cat_data)
                })
    return pd.DataFrame(summary)


print("="*70)
print("RESULTS COMPARISON: GEOMETRY PRESERVATION")
print("(Higher = more stable, lower in bilateral = MORE reorganization)")
print("="*70)

for approach_name, results in [('liu_mixed (standard)', standard_results['liu_mixed']),
                                ('all_scramble (standard)', standard_results['all_scramble']),
                                ('hybrid (liu ROI + scramble pattern)', {'geometry': hybrid_results['geometry']})]:
    print(f"\n--- {approach_name} ---")
    geom_df = results['geometry'] if 'geometry' in results else results.get('geometry')
    
    if geom_df is None or len(geom_df) == 0:
        print("  No data")
        continue
    
    # Filter to intact hemisphere for patients
    filtered = []
    for _, row in geom_df.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    
    filtered_df = pd.DataFrame(filtered)
    
    print(f"\n{'Group':<10} {'Face':<12} {'Word':<12} {'Object':<12} {'House':<12}")
    print("-"*60)
    
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['geometry_preservation']
            if len(cd) > 0:
                vals.append(f"{cd.mean():.3f}")
            else:
                vals.append("--")
        print(f"{group:<10} {vals[0]:<12} {vals[1]:<12} {vals[2]:<12} {vals[3]:<12}")

# %% [markdown]
# ## Cell 10: Statistical Tests - Compare Bilateral vs Unilateral Effect

# %%
def test_bilateral_effect(df, metric_col, metric_name):
    """Test if bilateral > unilateral change (or < for stability metrics)"""
    
    print(f"\n{'='*70}")
    print(f"BILATERAL vs UNILATERAL: {metric_name}")
    print(f"{'='*70}")
    
    # Filter to intact hemisphere for patients
    filtered = []
    for _, row in df.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    
    filtered_df = pd.DataFrame(filtered)
    filtered_df['cat_type'] = filtered_df['category'].apply(
        lambda x: 'Bilateral' if x in BILATERAL else 'Unilateral'
    )
    
    print(f"\n{'Group':<12} {'Bilateral':<15} {'Unilateral':<15} {'Diff':<10} {'t':<8} {'p':<8}")
    print("-"*70)
    
    results = []
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        bil = gd[gd['cat_type'] == 'Bilateral'][metric_col]
        uni = gd[gd['cat_type'] == 'Unilateral'][metric_col]
        
        if len(bil) > 1 and len(uni) > 1:
            t, p = ttest_ind(bil, uni)
            diff = bil.mean() - uni.mean()
            sig = '*' if p < 0.05 else ''
            print(f"{group:<12} {bil.mean():.3f}±{bil.std():.3f}   {uni.mean():.3f}±{uni.std():.3f}   {diff:+.3f}     {t:.2f}    {p:.4f} {sig}")
            
            results.append({
                'group': group,
                'bilateral_mean': bil.mean(),
                'unilateral_mean': uni.mean(),
                'difference': diff,
                't': t,
                'p': p
            })
    
    return pd.DataFrame(results)


print("\n" + "="*70)
print("COMPARING BILATERAL vs UNILATERAL EFFECTS ACROSS APPROACHES")
print("="*70)

# Test each approach
approach_tests = {}

for scheme_name in ['liu_mixed', 'all_scramble', 'current_differential']:
    print(f"\n\n{'#'*70}")
    print(f"APPROACH: {scheme_name}")
    print(f"{'#'*70}")
    
    results = standard_results[scheme_name]
    
    approach_tests[scheme_name] = {}
    
    # Geometry Preservation (lower bilateral = more change)
    if len(results['geometry']) > 0:
        approach_tests[scheme_name]['geometry'] = test_bilateral_effect(
            results['geometry'], 'geometry_preservation', 'Geometry Preservation'
        )
    
    # Selectivity Change (higher bilateral = more change)
    if len(results['selectivity']) > 0:
        approach_tests[scheme_name]['selectivity'] = test_bilateral_effect(
            results['selectivity'], 'selectivity_change', 'Selectivity Change'
        )

# Hybrid approach
print(f"\n\n{'#'*70}")
print("APPROACH: HYBRID (liu ROI + scramble patterns)")
print(f"{'#'*70}")

approach_tests['hybrid'] = {}
approach_tests['hybrid']['geometry'] = test_bilateral_effect(
    hybrid_results['geometry'], 'geometry_preservation', 'Geometry Preservation'
)
approach_tests['hybrid']['selectivity'] = test_bilateral_effect(
    hybrid_results['selectivity'], 'selectivity_change', 'Selectivity Change'
)

# %% [markdown]
# ## Cell 11: Summary Comparison Table

# %%
print("="*70)
print("SUMMARY: OTC BILATERAL vs UNILATERAL EFFECT BY APPROACH")
print("="*70)
print("\nKey question: Does OTC show significantly greater change in bilateral")
print("categories compared to unilateral? (This supports our hypothesis)")
print()

print(f"{'Approach':<25} {'Measure':<20} {'Bil-Uni Diff':<15} {'p-value':<10} {'Sig?':<5}")
print("-"*75)

for approach_name in ['liu_mixed', 'all_scramble', 'current_differential', 'hybrid']:
    if approach_name not in approach_tests:
        continue
    
    for measure in ['geometry', 'selectivity']:
        if measure not in approach_tests[approach_name]:
            continue
        
        test_df = approach_tests[approach_name][measure]
        otc_row = test_df[test_df['group'] == 'OTC']
        
        if len(otc_row) == 0:
            continue
        
        diff = otc_row['difference'].values[0]
        p = otc_row['p'].values[0]
        sig = '✓' if p < 0.05 else ''
        
        # For geometry, negative diff means bilateral has MORE change (lower stability)
        # For selectivity, positive diff means bilateral has MORE change
        measure_label = 'Geometry Pres.' if measure == 'geometry' else 'Selectivity Chg.'
        
        print(f"{approach_name:<25} {measure_label:<20} {diff:+.4f}        {p:.4f}     {sig}")

print("\n" + "-"*75)
print("Note: For Geometry Preservation, NEGATIVE diff = bilateral shows MORE change")
print("      For Selectivity Change, POSITIVE diff = bilateral shows MORE change")

# %% [markdown]
# ## Cell 12: Detailed Category-Level Results for Best Approach

# %%
print("="*70)
print("DETAILED CATEGORY-LEVEL RESULTS")
print("="*70)

# Determine best approach based on OTC significance
# Let's show results for multiple approaches for comparison

for approach_name in ['all_scramble', 'hybrid']:
    print(f"\n\n{'#'*70}")
    print(f"APPROACH: {approach_name}")
    print(f"{'#'*70}")
    
    if approach_name == 'hybrid':
        geom_df = hybrid_results['geometry']
        select_df = hybrid_results['selectivity']
        drift_df = hybrid_results['drift']
    else:
        geom_df = standard_results[approach_name]['geometry']
        select_df = standard_results[approach_name]['selectivity']
        drift_df = standard_results[approach_name]['drift']
    
    # Filter to intact hemisphere
    def filter_intact(df):
        filtered = []
        for _, row in df.iterrows():
            sid = row['subject']
            info = SUBJECTS[sid]
            if info['group'] == 'control':
                filtered.append(row)
            elif row['hemi'] == info['hemi']:
                filtered.append(row)
        return pd.DataFrame(filtered)
    
    geom_filt = filter_intact(geom_df)
    select_filt = filter_intact(select_df)
    drift_filt = filter_intact(drift_df)
    
    print("\n--- GEOMETRY PRESERVATION (higher = more stable) ---")
    print(f"{'Group':<10} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*50)
    for group in ['OTC', 'nonOTC', 'control']:
        gd = geom_filt[geom_filt['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['geometry_preservation']
            vals.append(f"{cd.mean():.2f}" if len(cd) > 0 else "--")
        print(f"{group:<10} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")
    
    print("\n--- SELECTIVITY CHANGE (higher = more change) ---")
    print(f"{'Group':<10} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*50)
    for group in ['OTC', 'nonOTC', 'control']:
        gd = select_filt[select_filt['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['selectivity_change']
            vals.append(f"{cd.mean():.2f}" if len(cd) > 0 else "--")
        print(f"{group:<10} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")
    
    print("\n--- SPATIAL DRIFT (mm) ---")
    print(f"{'Group':<10} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*50)
    for group in ['OTC', 'nonOTC', 'control']:
        gd = drift_filt[drift_filt['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['spatial_drift_mm']
            vals.append(f"{cd.mean():.1f}" if len(cd) > 0 else "--")
        print(f"{group:<10} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")

# %% [markdown]
# ## Cell 13: Bootstrap Analysis for Robust Inference

# %%
def bootstrap_group_comparison(df, metric_col, n_boot=10000, seed=42):
    """Bootstrap test for OTC bilateral advantage vs other groups"""
    np.random.seed(seed)
    
    # Filter to intact hemisphere
    filtered = []
    for _, row in df.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    
    filtered_df = pd.DataFrame(filtered)
    filtered_df['cat_type'] = filtered_df['category'].apply(
        lambda x: 'Bilateral' if x in BILATERAL else 'Unilateral'
    )
    
    # Calculate subject-level bilateral advantage (gap)
    subject_gaps = {}
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        gaps = []
        for sid in gd['subject'].unique():
            sd = gd[gd['subject'] == sid]
            bil = sd[sd['cat_type'] == 'Bilateral'][metric_col].mean()
            uni = sd[sd['cat_type'] == 'Unilateral'][metric_col].mean()
            if pd.notna(bil) and pd.notna(uni):
                gaps.append(bil - uni)
        subject_gaps[group] = np.array(gaps)
    
    results = []
    
    # Compare OTC vs each other group
    for comp_group in ['nonOTC', 'control']:
        g1 = subject_gaps['OTC']
        g2 = subject_gaps[comp_group]
        
        if len(g1) < 2 or len(g2) < 2:
            continue
        
        observed_diff = np.mean(g1) - np.mean(g2)
        
        boot_diffs = []
        for _ in range(n_boot):
            s1 = np.random.choice(g1, size=len(g1), replace=True)
            s2 = np.random.choice(g2, size=len(g2), replace=True)
            boot_diffs.append(np.mean(s1) - np.mean(s2))
        
        boot_diffs = np.array(boot_diffs)
        ci_low = np.percentile(boot_diffs, 2.5)
        ci_high = np.percentile(boot_diffs, 97.5)
        
        # Two-sided p-value
        if observed_diff > 0:
            p_val = 2 * np.mean(boot_diffs <= 0)
        else:
            p_val = 2 * np.mean(boot_diffs >= 0)
        
        results.append({
            'comparison': f'OTC vs {comp_group}',
            'observed_diff': observed_diff,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'p_value': p_val
        })
    
    return pd.DataFrame(results), subject_gaps


print("="*70)
print("BOOTSTRAP ANALYSIS: OTC BILATERAL ADVANTAGE")
print("="*70)

for approach_name in ['all_scramble', 'hybrid']:
    print(f"\n--- {approach_name.upper()} ---")
    
    if approach_name == 'hybrid':
        select_df = hybrid_results['selectivity']
        geom_df = hybrid_results['geometry']
    else:
        select_df = standard_results[approach_name]['selectivity']
        geom_df = standard_results[approach_name]['geometry']
    
    print("\nSelectivity Change (bilateral advantage = more reorganization):")
    boot_results, gaps = bootstrap_group_comparison(select_df, 'selectivity_change')
    print(f"  Subject gaps - OTC: {gaps['OTC'].mean():.3f}, nonOTC: {gaps['nonOTC'].mean():.3f}, control: {gaps['control'].mean():.3f}")
    for _, row in boot_results.iterrows():
        sig = '***' if row['p_value'] < 0.001 else '**' if row['p_value'] < 0.01 else '*' if row['p_value'] < 0.05 else ''
        print(f"  {row['comparison']}: diff={row['observed_diff']:.3f}, 95%CI=[{row['ci_low']:.3f}, {row['ci_high']:.3f}], p={row['p_value']:.4f} {sig}")
    
    print("\nGeometry Preservation (bilateral disadvantage = more reorganization):")
    boot_results, gaps = bootstrap_group_comparison(geom_df, 'geometry_preservation')
    print(f"  Subject gaps - OTC: {gaps['OTC'].mean():.3f}, nonOTC: {gaps['nonOTC'].mean():.3f}, control: {gaps['control'].mean():.3f}")
    for _, row in boot_results.iterrows():
        sig = '***' if row['p_value'] < 0.001 else '**' if row['p_value'] < 0.01 else '*' if row['p_value'] < 0.05 else ''
        print(f"  {row['comparison']}: diff={row['observed_diff']:.3f}, 95%CI=[{row['ci_low']:.3f}, {row['ci_high']:.3f}], p={row['p_value']:.4f} {sig}")

# %% [markdown]
# ## Cell 14: Decision and Final Export

# %%
print("="*70)
print("DECISION: WHICH CONTRAST SCHEME TO USE?")
print("="*70)

print("""
Based on the comparisons above, evaluate:

1. ROI YIELD: Does the scheme find meaningful ROIs for all categories?
   - Check Word ROI voxel counts (liu_mixed may have very few)
   
2. THEORETICAL CONSISTENCY: Same baseline across categories for RSA?
   - all_scramble and all_vs_others are consistent
   - liu_mixed and current_differential mix baselines
   
3. STATISTICAL POWER: Does the scheme detect the hypothesized effect?
   - OTC bilateral > unilateral change

4. HYBRID APPROACH: Does separating localization from patterns help?
   - May get better ROI localization (liu_mixed) 
   - While maintaining RSA consistency (all_scramble patterns)

RECOMMENDATION: Review the summary tables above and decide.
""")

# Show final recommendation based on results
print("\n" + "-"*70)
print("QUANTITATIVE COMPARISON:")
print("-"*70)

comparison_data = []
for approach in ['liu_mixed', 'all_scramble', 'current_differential', 'hybrid']:
    if approach not in approach_tests:
        continue
    
    row = {'approach': approach}
    
    # Get OTC p-values
    if 'selectivity' in approach_tests[approach]:
        sel_df = approach_tests[approach]['selectivity']
        otc_sel = sel_df[sel_df['group'] == 'OTC']
        if len(otc_sel) > 0:
            row['selectivity_p'] = otc_sel['p'].values[0]
            row['selectivity_diff'] = otc_sel['difference'].values[0]
    
    if 'geometry' in approach_tests[approach]:
        geo_df = approach_tests[approach]['geometry']
        otc_geo = geo_df[geo_df['group'] == 'OTC']
        if len(otc_geo) > 0:
            row['geometry_p'] = otc_geo['p'].values[0]
            row['geometry_diff'] = otc_geo['difference'].values[0]
    
    comparison_data.append(row)

comp_df = pd.DataFrame(comparison_data)
print(comp_df.to_string(index=False))

# %% [markdown]
# ## Cell 15: Export Final Results with Chosen Approach

# %%
# Set the chosen approach here after reviewing results
CHOSEN_APPROACH = 'hybrid'  # Change this based on Cell 14 analysis

print(f"="*70)
print(f"EXPORTING FINAL RESULTS: {CHOSEN_APPROACH}")
print(f"="*70)

if CHOSEN_APPROACH == 'hybrid':
    final_results = hybrid_results
else:
    final_results = standard_results[CHOSEN_APPROACH]

# Build comprehensive results DataFrame
export_data = []

for _, row in final_results['selectivity'].iterrows():
    sid = row['subject']
    info = SUBJECTS[sid]
    
    # Get matching geometry and drift
    geom_match = final_results['geometry'][
        (final_results['geometry']['subject'] == sid) & 
        (final_results['geometry']['hemi'] == row['hemi']) &
        (final_results['geometry']['category'] == row['category'])
    ]
    
    drift_match = final_results['drift'][
        (final_results['drift']['subject'] == sid) & 
        (final_results['drift']['hemi'] == row['hemi']) &
        (final_results['drift']['category'] == row['category'])
    ]
    
    export_row = {
        'Subject': row['code'],
        'Group': info['group'],
        'Surgery_Side': info.get('surgery_side', 'na'),
        'Intact_Hemisphere': 'left' if info['hemi'] == 'l' else 'right',
        'Sex': info.get('sex', 'na'),
        'nonpt_hemi': row['hemi'].upper() if info['group'] == 'control' else 'na',
        'Category': row['category'].title(),
        'Category_Type': 'Bilateral' if row['category'] in BILATERAL else 'Unilateral',
        'age_1': info.get('age_1', np.nan),
        'age_2': info.get('age_2', np.nan),
        'yr_gap': info.get('age_2', 0) - info.get('age_1', 0) if info.get('age_1') and info.get('age_2') else np.nan,
        'Selectivity_Change': row['selectivity_change'],
        'Spatial_Relocation_mm': drift_match['spatial_drift_mm'].values[0] if len(drift_match) > 0 else np.nan,
        'Geometry_Preservation_6mm': geom_match['geometry_preservation'].values[0] if len(geom_match) > 0 else np.nan,
    }
    
    export_data.append(export_row)

export_df = pd.DataFrame(export_data)

# Add MDS shift (averaged across ROI categories for each measured category)
mds_df = final_results['mds']
mds_summary = mds_df.groupby(['subject', 'measured_category'])['mds_shift'].mean().reset_index()

for i, row in export_df.iterrows():
    sid = [s for s in SUBJECTS if SUBJECTS[s]['code'] == row['Subject']][0]
    cat = row['Category'].lower()
    
    mds_match = mds_summary[
        (mds_summary['subject'] == sid) & 
        (mds_summary['measured_category'] == cat)
    ]
    
    if len(mds_match) > 0:
        export_df.loc[i, 'MDS_Shift'] = mds_match['mds_shift'].values[0]

# Save
output_file = OUTPUT_DIR / 'results_final_corrected.csv'
export_df.to_csv(output_file, index=False)
print(f"\n✓ Saved to: {output_file}")
print(f"  Shape: {export_df.shape}")

# Final summary
print("\n" + "-"*70)
print("FINAL SUMMARY BY GROUP AND CATEGORY:")
print("-"*70)

print(f"\n{'Measure':<25} {'Group':<10} {'Face':<8} {'Word':<8} {'Object':<8} {'House':<8}")
print("-"*70)

for measure in ['Selectivity_Change', 'Geometry_Preservation_6mm', 'Spatial_Relocation_mm', 'MDS_Shift']:
    for group in ['OTC', 'nonOTC', 'control']:
        gd = export_df[export_df['Group'] == group]
        vals = []
        for cat in ['Face', 'Word', 'Object', 'House']:
            cd = gd[gd['Category'] == cat][measure]
            vals.append(f"{cd.mean():.2f}" if len(cd) > 0 and cd.notna().any() else "--")
        print(f"{measure:<25} {group:<10} {vals[0]:<8} {vals[1]:<8} {vals[2]:<8} {vals[3]:<8}")
    print()

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

# new

In [7]:
# %% [markdown]
# # Longitudinal RSA Analysis: Contrast Scheme Comparison
# 
# **Study:** VOTC Resection - Bilateral vs Unilateral Visual Category Reorganization
# 
# **Hypothesis:** Bilateral categories (Object, House) show greater representational 
# reorganization than unilateral categories (Face, Word) in OTC patients because 
# the hemispheres work collaboratively (not redundantly), so losing one forces compensation.
# 
# **This notebook:**
# 1. Compares 4 contrast schemes for ROI localization and pattern extraction
# 2. Tests hybrid approach (Liu for localization, Scramble for patterns)
# 3. Computes all RSA measures with each approach
# 4. Determines optimal contrast scheme based on results

# %% [markdown]
# ## Cell 1: Setup & Configuration

# %%
import numpy as np
import nibabel as nib
from pathlib import Path
import pandas as pd
from scipy.ndimage import label, center_of_mass
from scipy.stats import pearsonr, ttest_ind, ttest_rel
from scipy.spatial.distance import squareform
from scipy.linalg import orthogonal_procrustes
import warnings
warnings.filterwarnings('ignore')

# Paths
BASE_DIR = Path("/user_data/csimmon2/long_pt")
CSV_FILE = Path('/user_data/csimmon2/git_repos/long_pt/long_pt_sub_info.csv')
OUTPUT_DIR = Path('/user_data/csimmon2/git_repos/long_pt/B_analyses')

# Load subject info
df = pd.read_csv(CSV_FILE)

# Session overrides
SESSION_START = {'sub-010': 2, 'sub-018': 2, 'sub-068': 2}

# Subjects to exclude
EXCLUDE_SUBJECTS = ['sub-025', 'sub-027', 'sub-045', 'sub-072']

# Categories
CATEGORIES = ['face', 'word', 'object', 'house']
BILATERAL = ['object', 'house']
UNILATERAL = ['face', 'word']

# ============================================================
# CONTRAST SCHEME DEFINITIONS
# ============================================================

# Scheme 1: Liu's original mixed contrasts (for ROI localization)
COPE_LIU_MIXED = {
    'face': (1, 1),    # Face > Object
    'word': (4, 1),    # Word > Object  
    'object': (3, 1),  # Object > Scramble
    'house': (2, 1)    # House > Object
}

# Scheme 2: All vs Scramble (consistent baseline)
COPE_ALL_SCRAMBLE = {
    'face': (10, 1),   # Face > Scramble
    'word': (12, 1),   # Word > Scramble
    'object': (3, 1),  # Object > Scramble
    'house': (11, 1)   # House > Scramble
}

# Scheme 3: All vs Others (category vs mean of others)
COPE_ALL_VS_OTHERS = {
    'face': (6, 1),    # Face > mean(House+Object+Word+Scramble)
    'word': (9, 1),    # Word > mean(Face+House+Object+Scramble)
    'object': (8, 1),  # Object > mean(Face+House+Word+Scramble)
    'house': (7, 1)    # House > mean(Face+Object+Word+Scramble)
}

# Scheme 4: Current differential (what Script 1 used - problematic)
COPE_DIFFERENTIAL = {
    'face': (10, 1),   # Face > Scramble
    'word': (13, -1),  # Face > Word, flipped to Word > Face
    'object': (3, 1),  # Object > Scramble
    'house': (11, 1)   # House > Scramble
}

CONTRAST_SCHEMES = {
    'liu_mixed': COPE_LIU_MIXED,
    'all_scramble': COPE_ALL_SCRAMBLE,
    'all_vs_others': COPE_ALL_VS_OTHERS,
    'current_differential': COPE_DIFFERENTIAL
}

print("✓ Configuration loaded")
print(f"  Excluding: {EXCLUDE_SUBJECTS}")

# %% [markdown]
# ## Cell 2: Load Subjects

# %%
def load_subjects():
    """Load all subjects from CSV, excluding problematic ones"""
    subjects = {}
    
    for _, row in df.iterrows():
        subject_id = row['sub']
        
        if subject_id in EXCLUDE_SUBJECTS:
            continue
            
        subj_dir = BASE_DIR / subject_id
        if not subj_dir.exists():
            continue
        
        sessions = sorted([d.name.replace('ses-', '') for d in subj_dir.glob('ses-*') if d.is_dir()], key=int)
        start_session = SESSION_START.get(subject_id, 1)
        sessions = [s for s in sessions if int(s) >= start_session]
        
        if len(sessions) < 2:
            continue
        
        hemi = 'l' if row.get('intact_hemi', 'left') == 'left' else 'r'
        
        subjects[subject_id] = {
            'code': f"{row['group']}{subject_id.split('-')[1]}",
            'sessions': sessions,
            'hemi': hemi,
            'group': row['group'],
            'patient': row['patient'] == 1,
            'surgery_side': row.get('SurgerySide', None),
            'sex': row.get('sex', None),
            'age_1': row.get('age_1', None),
            'age_2': row.get('age_2', None)
        }
    
    return subjects

SUBJECTS = load_subjects()

# Summary
print(f"✓ Loaded {len(SUBJECTS)} subjects (after exclusions)")
for group in ['OTC', 'nonOTC', 'control']:
    n = sum(1 for s in SUBJECTS.values() if s['group'] == group)
    print(f"  {group}: {n}")

# %% [markdown]
# ## Cell 3: ROI Extraction Functions

# %%
def create_sphere(center_coord, affine, brain_shape, radius=6):
    """Create spherical mask around coordinate"""
    grid = np.array(np.meshgrid(
        np.arange(brain_shape[0]),
        np.arange(brain_shape[1]),
        np.arange(brain_shape[2]),
        indexing='ij'
    )).reshape(3, -1).T
    
    world = nib.affines.apply_affine(affine, grid)
    distances = np.linalg.norm(world - center_coord, axis=1)
    
    mask = np.zeros(brain_shape, dtype=bool)
    within = grid[distances <= radius]
    for c in within:
        mask[c[0], c[1], c[2]] = True
    
    return mask


def extract_rois_single_scheme(cope_map, threshold_z=2.3, min_voxels=20):
    """Extract ROIs for all subjects using a single contrast scheme"""
    
    all_rois = {}
    
    for sid, info in SUBJECTS.items():
        first_ses = info['sessions'][0]
        roi_dir = BASE_DIR / sid / f'ses-{first_ses}' / 'ROIs'
        
        if not roi_dir.exists():
            continue
        
        all_rois[sid] = {}
        
        # For controls, extract both hemispheres
        hemis = ['l', 'r'] if info['group'] == 'control' else [info['hemi']]
        
        for hemi in hemis:
            for category in CATEGORIES:
                cope_num, mult = cope_map[category]
                
                # Load search mask
                mask_file = roi_dir / f'{hemi}_{category}_searchmask.nii.gz'
                if not mask_file.exists():
                    continue
                
                try:
                    mask_img = nib.load(mask_file)
                    search_mask = mask_img.get_fdata() > 0
                    affine = mask_img.affine
                except:
                    continue
                
                roi_key = f'{hemi}_{category}'
                all_rois[sid][roi_key] = {}
                
                for session in info['sessions']:
                    feat_dir = BASE_DIR / sid / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                    z_name = 'zstat1.nii.gz' if session == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        continue
                    
                    try:
                        z_data = nib.load(cope_file).get_fdata() * mult
                        suprathresh = (z_data > threshold_z) & search_mask
                        
                        if suprathresh.sum() < min_voxels:
                            continue
                        
                        labeled, n_clusters = label(suprathresh)
                        if n_clusters == 0:
                            continue
                        
                        # Largest cluster
                        sizes = [(labeled == i).sum() for i in range(1, n_clusters + 1)]
                        best_idx = np.argmax(sizes) + 1
                        roi_mask = (labeled == best_idx)
                        
                        if roi_mask.sum() < min_voxels:
                            continue
                        
                        peak_idx = np.unravel_index(np.argmax(z_data * roi_mask), z_data.shape)
                        
                        all_rois[sid][roi_key][session] = {
                            'n_voxels': int(roi_mask.sum()),
                            'peak_z': z_data[peak_idx],
                            'centroid': nib.affines.apply_affine(affine, center_of_mass(roi_mask)),
                            'peak_coord': nib.affines.apply_affine(affine, peak_idx),
                            'roi_mask': roi_mask,
                            'affine': affine,
                            'shape': z_data.shape
                        }
                    except Exception as e:
                        continue
    
    return all_rois

print("✓ ROI extraction functions defined")

# %% [markdown]
# ## Cell 4: RSA Metric Functions

# %%
def compute_rdm(patterns):
    """Compute RDM from pattern matrix (categories x voxels)"""
    corr = np.corrcoef(patterns)
    rdm = 1 - corr
    return rdm, corr


def compute_geometry_preservation(rois, pattern_cope_map, radius=6):
    """
    Geometry Preservation: RDM stability across sessions
    - Extract patterns from sphere at each session's centroid
    - Correlate T1 and T2 RDMs
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            # Get reference for sphere creation
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            rdms = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = []
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns.append(pattern)
                
                if valid and len(patterns) == 4:
                    rdm, _ = compute_rdm(np.array(patterns))
                    rdms[ses] = rdm
            
            if len(rdms) == 2:
                triu = np.triu_indices(4, k=1)
                r, _ = pearsonr(rdms[sessions[0]][triu], rdms[sessions[-1]][triu])
                
                results.append({
                    'subject': sid,
                    'code': info['code'],
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'geometry_preservation': r
                })
    
    return pd.DataFrame(results)


def compute_mds_shift(rois, pattern_cope_map, radius=6):
    """
    MDS Shift: Procrustes-aligned embedding distance
    - MDS embed RDMs to 2D
    - Align with Procrustes
    - Measure movement of each category
    """
    def mds_2d(rdm):
        n = rdm.shape[0]
        H = np.eye(n) - np.ones((n, n)) / n
        B = -0.5 * H @ (rdm ** 2) @ H
        eigvals, eigvecs = np.linalg.eigh(B)
        idx = np.argsort(eigvals)[::-1]
        coords = eigvecs[:, idx[:2]] * np.sqrt(np.maximum(eigvals[idx[:2]], 0))
        return coords
    
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            roi_category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            rdms = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = []
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns.append(pattern)
                
                if valid and len(patterns) == 4:
                    rdm, _ = compute_rdm(np.array(patterns))
                    rdms[ses] = rdm
            
            if len(rdms) == 2:
                try:
                    coords_t1 = mds_2d(rdms[sessions[0]])
                    coords_t2 = mds_2d(rdms[sessions[-1]])
                    
                    R, _ = orthogonal_procrustes(coords_t1, coords_t2)
                    coords_t1_aligned = coords_t1 @ R
                    
                    for i, cat in enumerate(CATEGORIES):
                        dist = np.linalg.norm(coords_t1_aligned[i] - coords_t2[i])
                        results.append({
                            'subject': sid,
                            'code': info['code'],
                            'group': info['group'],
                            'hemi': hemi,
                            'roi_category': roi_category,
                            'measured_category': cat,
                            'mds_shift': dist
                        })
                except:
                    continue
    
    return pd.DataFrame(results)


def compute_spatial_drift(rois):
    """
    Spatial Drift: Euclidean distance between T1 and T2 peak centroids
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            c1 = sessions_data[sessions[0]]['centroid']
            c2 = sessions_data[sessions[-1]]['centroid']
            drift = np.linalg.norm(np.array(c2) - np.array(c1))
            
            results.append({
                'subject': sid,
                'code': info['code'],
                'group': info['group'],
                'hemi': hemi,
                'category': category,
                'spatial_drift_mm': drift,
                't1_peak_z': sessions_data[sessions[0]]['peak_z']
            })
    
    return pd.DataFrame(results)


def compute_selectivity_change(rois, pattern_cope_map):
    """
    Selectivity Change (Liu Distinctiveness):
    - Correlation of preferred category with non-preferred categories
    - Change from T1 to T2
    """
    results = []
    
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        first_ses = info['sessions'][0]
        
        for roi_key, sessions_data in roi_data.items():
            sessions = sorted(sessions_data.keys())
            if len(sessions) < 2:
                continue
            
            hemi = roi_key.split('_')[0]
            category = roi_key.split('_')[1]
            
            ref_data = sessions_data[sessions[0]]
            affine = ref_data['affine']
            shape = ref_data['shape']
            
            distinctiveness = {}
            for ses in [sessions[0], sessions[-1]]:
                if ses not in sessions_data:
                    continue
                
                centroid = sessions_data[ses]['centroid']
                sphere = create_sphere(centroid, affine, shape, radius=6)
                
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                
                patterns = {}
                valid = True
                for cat in CATEGORIES:
                    cope_num, mult = pattern_cope_map[cat]
                    z_name = 'zstat1.nii.gz' if ses == first_ses else f'zstat1_ses{first_ses}.nii.gz'
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    
                    if not cope_file.exists():
                        valid = False
                        break
                    
                    data = nib.load(cope_file).get_fdata() * mult
                    pattern = data[sphere]
                    
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        valid = False
                        break
                    
                    patterns[cat] = pattern
                
                if valid and len(patterns) == 4:
                    # Compute correlation of preferred with non-preferred
                    pref_pattern = patterns[category]
                    nonpref_corrs = []
                    for other_cat in CATEGORIES:
                        if other_cat != category:
                            r, _ = pearsonr(pref_pattern, patterns[other_cat])
                            nonpref_corrs.append(np.arctanh(np.clip(r, -0.999, 0.999)))
                    
                    distinctiveness[ses] = np.mean(nonpref_corrs)
            
            if len(distinctiveness) == 2:
                change = abs(distinctiveness[sessions[-1]] - distinctiveness[sessions[0]])
                results.append({
                    'subject': sid,
                    'code': info['code'],
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'selectivity_change': change,
                    't1_distinctiveness': distinctiveness[sessions[0]],
                    't2_distinctiveness': distinctiveness[sessions[-1]]
                })
    
    return pd.DataFrame(results)

print("✓ RSA metric functions defined")

# %% [markdown]
# ## Cell 5: Extract ROIs with All Contrast Schemes

# %%
print("="*70)
print("EXTRACTING ROIs WITH ALL CONTRAST SCHEMES")
print("="*70)

all_rois = {}
for scheme_name, cope_map in CONTRAST_SCHEMES.items():
    print(f"\n{scheme_name}...")
    all_rois[scheme_name] = extract_rois_single_scheme(cope_map)
    
    # Count ROIs
    n_rois = sum(len(roi_data) for roi_data in all_rois[scheme_name].values())
    n_subjects = len([s for s in all_rois[scheme_name] if all_rois[scheme_name][s]])
    print(f"  ✓ {n_subjects} subjects, {n_rois} ROIs")

print("\n" + "="*70)

# %% [markdown]
# ## Cell 6: Compare ROI Yields Across Schemes

# %%
print("="*70)
print("ROI YIELD COMPARISON BY CATEGORY")
print("="*70)

# Count ROIs per category per scheme
yield_data = []
for scheme_name, rois in all_rois.items():
    for sid, roi_data in rois.items():
        info = SUBJECTS[sid]
        for roi_key, sessions_data in roi_data.items():
            if len(sessions_data) >= 2:  # Has both timepoints
                hemi = roi_key.split('_')[0]
                category = roi_key.split('_')[1]
                n_voxels_t1 = sessions_data[sorted(sessions_data.keys())[0]]['n_voxels']
                
                yield_data.append({
                    'scheme': scheme_name,
                    'subject': sid,
                    'group': info['group'],
                    'hemi': hemi,
                    'category': category,
                    'n_voxels': n_voxels_t1
                })

yield_df = pd.DataFrame(yield_data)

print("\nROI counts by scheme and category:")
print("-"*60)
pivot = yield_df.groupby(['scheme', 'category']).size().unstack(fill_value=0)
print(pivot)

print("\n\nMean voxel count by scheme and category:")
print("-"*60)
pivot_vox = yield_df.groupby(['scheme', 'category'])['n_voxels'].mean().unstack()
print(pivot_vox.round(0))

# Word ROI problem visualization
print("\n\nWORD ROI VOXEL COUNTS (the problematic category):")
print("-"*60)
word_data = yield_df[yield_df['category'] == 'word']
for scheme in CONTRAST_SCHEMES.keys():
    scheme_word = word_data[word_data['scheme'] == scheme]
    print(f"\n{scheme}:")
    print(f"  N ROIs: {len(scheme_word)}")
    print(f"  Mean voxels: {scheme_word['n_voxels'].mean():.0f}")
    print(f"  Min voxels: {scheme_word['n_voxels'].min()}")
    print(f"  Max voxels: {scheme_word['n_voxels'].max()}")

# %% [markdown]
# ## Cell 7: Compute All Metrics - Standard Approach (Same scheme for localization and patterns)

# %%
print("="*70)
print("COMPUTING METRICS: STANDARD APPROACH")
print("(Same contrast scheme for ROI localization AND pattern extraction)")
print("="*70)

standard_results = {}

for scheme_name in CONTRAST_SCHEMES.keys():
    print(f"\n{scheme_name}...")
    
    rois = all_rois[scheme_name]
    cope_map = CONTRAST_SCHEMES[scheme_name]
    
    # Compute all metrics
    geom = compute_geometry_preservation(rois, cope_map)
    mds = compute_mds_shift(rois, cope_map)
    drift = compute_spatial_drift(rois)
    select = compute_selectivity_change(rois, cope_map)
    
    standard_results[scheme_name] = {
        'geometry': geom,
        'mds': mds,
        'drift': drift,
        'selectivity': select
    }
    
    print(f"  Geometry: {len(geom)} ROIs")
    print(f"  MDS: {len(mds)} measurements")
    print(f"  Drift: {len(drift)} ROIs")
    print(f"  Selectivity: {len(select)} ROIs")

# %% [markdown]
# ## Cell 8: Compute Metrics - Hybrid Approach

# %%
print("="*70)
print("COMPUTING METRICS: HYBRID APPROACH")
print("(Liu mixed for localization, All Scramble for patterns)")
print("="*70)

# Use Liu mixed ROIs, but All Scramble patterns
hybrid_rois = all_rois['liu_mixed']
pattern_cope = COPE_ALL_SCRAMBLE

print("\nUsing ROIs from: liu_mixed")
print("Using patterns from: all_scramble")

hybrid_results = {
    'geometry': compute_geometry_preservation(hybrid_rois, pattern_cope),
    'mds': compute_mds_shift(hybrid_rois, pattern_cope),
    'drift': compute_spatial_drift(hybrid_rois),
    'selectivity': compute_selectivity_change(hybrid_rois, pattern_cope)
}

print(f"\n✓ Geometry: {len(hybrid_results['geometry'])} ROIs")
print(f"✓ MDS: {len(hybrid_results['mds'])} measurements")
print(f"✓ Drift: {len(hybrid_results['drift'])} ROIs")
print(f"✓ Selectivity: {len(hybrid_results['selectivity'])} ROIs")

# %% [markdown]
# ## Cell 9: Compare Results Across Approaches

# %%
def summarize_by_category(df, metric_col, groups=['OTC', 'nonOTC', 'control']):
    """Summarize metric by group and category"""
    summary = []
    for group in groups:
        group_data = df[df['group'] == group]
        for cat in CATEGORIES:
            cat_data = group_data[group_data['category'] == cat][metric_col]
            if len(cat_data) > 0:
                summary.append({
                    'group': group,
                    'category': cat,
                    'mean': cat_data.mean(),
                    'std': cat_data.std(),
                    'n': len(cat_data)
                })
    return pd.DataFrame(summary)


print("="*70)
print("RESULTS COMPARISON: GEOMETRY PRESERVATION")
print("(Higher = more stable, lower in bilateral = MORE reorganization)")
print("="*70)

for approach_name, results in [('liu_mixed (standard)', standard_results['liu_mixed']),
                                ('all_scramble (standard)', standard_results['all_scramble']),
                                ('hybrid (liu ROI + scramble pattern)', {'geometry': hybrid_results['geometry']})]:
    print(f"\n--- {approach_name} ---")
    geom_df = results['geometry'] if 'geometry' in results else results.get('geometry')
    
    if geom_df is None or len(geom_df) == 0:
        print("  No data")
        continue
    
    # Filter to intact hemisphere for patients
    filtered = []
    for _, row in geom_df.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    
    filtered_df = pd.DataFrame(filtered)
    
    print(f"\n{'Group':<10} {'Face':<12} {'Word':<12} {'Object':<12} {'House':<12}")
    print("-"*60)
    
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['geometry_preservation']
            if len(cd) > 0:
                vals.append(f"{cd.mean():.3f}")
            else:
                vals.append("--")
        print(f"{group:<10} {vals[0]:<12} {vals[1]:<12} {vals[2]:<12} {vals[3]:<12}")

# %% [markdown]
# ## Cell 10: Statistical Tests - Compare Bilateral vs Unilateral Effect

# %%
def test_bilateral_effect(df, metric_col, metric_name):
    """Test if bilateral > unilateral change (or < for stability metrics)"""
    
    print(f"\n{'='*70}")
    print(f"BILATERAL vs UNILATERAL: {metric_name}")
    print(f"{'='*70}")
    
    # Filter to intact hemisphere for patients
    filtered = []
    for _, row in df.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    
    filtered_df = pd.DataFrame(filtered)
    filtered_df['cat_type'] = filtered_df['category'].apply(
        lambda x: 'Bilateral' if x in BILATERAL else 'Unilateral'
    )
    
    print(f"\n{'Group':<12} {'Bilateral':<15} {'Unilateral':<15} {'Diff':<10} {'t':<8} {'p':<8}")
    print("-"*70)
    
    results = []
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        bil = gd[gd['cat_type'] == 'Bilateral'][metric_col]
        uni = gd[gd['cat_type'] == 'Unilateral'][metric_col]
        
        if len(bil) > 1 and len(uni) > 1:
            t, p = ttest_ind(bil, uni)
            diff = bil.mean() - uni.mean()
            sig = '*' if p < 0.05 else ''
            print(f"{group:<12} {bil.mean():.3f}±{bil.std():.3f}   {uni.mean():.3f}±{uni.std():.3f}   {diff:+.3f}     {t:.2f}    {p:.4f} {sig}")
            
            results.append({
                'group': group,
                'bilateral_mean': bil.mean(),
                'unilateral_mean': uni.mean(),
                'difference': diff,
                't': t,
                'p': p
            })
    
    return pd.DataFrame(results)


print("\n" + "="*70)
print("COMPARING BILATERAL vs UNILATERAL EFFECTS ACROSS APPROACHES")
print("="*70)

# Test each approach
approach_tests = {}

for scheme_name in ['liu_mixed', 'all_scramble', 'current_differential']:
    print(f"\n\n{'#'*70}")
    print(f"APPROACH: {scheme_name}")
    print(f"{'#'*70}")
    
    results = standard_results[scheme_name]
    
    approach_tests[scheme_name] = {}
    
    # Geometry Preservation (lower bilateral = more change)
    if len(results['geometry']) > 0:
        approach_tests[scheme_name]['geometry'] = test_bilateral_effect(
            results['geometry'], 'geometry_preservation', 'Geometry Preservation'
        )
    
    # Selectivity Change (higher bilateral = more change)
    if len(results['selectivity']) > 0:
        approach_tests[scheme_name]['selectivity'] = test_bilateral_effect(
            results['selectivity'], 'selectivity_change', 'Selectivity Change'
        )

# Hybrid approach
print(f"\n\n{'#'*70}")
print("APPROACH: HYBRID (liu ROI + scramble patterns)")
print(f"{'#'*70}")

approach_tests['hybrid'] = {}
approach_tests['hybrid']['geometry'] = test_bilateral_effect(
    hybrid_results['geometry'], 'geometry_preservation', 'Geometry Preservation'
)
approach_tests['hybrid']['selectivity'] = test_bilateral_effect(
    hybrid_results['selectivity'], 'selectivity_change', 'Selectivity Change'
)

# %% [markdown]
# ## Cell 11: Summary Comparison Table

# %%
print("="*70)
print("SUMMARY: OTC BILATERAL vs UNILATERAL EFFECT BY APPROACH")
print("="*70)
print("\nKey question: Does OTC show significantly greater change in bilateral")
print("categories compared to unilateral? (This supports our hypothesis)")
print()

print(f"{'Approach':<25} {'Measure':<20} {'Bil-Uni Diff':<15} {'p-value':<10} {'Sig?':<5}")
print("-"*75)

for approach_name in ['liu_mixed', 'all_scramble', 'current_differential', 'hybrid']:
    if approach_name not in approach_tests:
        continue
    
    for measure in ['geometry', 'selectivity']:
        if measure not in approach_tests[approach_name]:
            continue
        
        test_df = approach_tests[approach_name][measure]
        otc_row = test_df[test_df['group'] == 'OTC']
        
        if len(otc_row) == 0:
            continue
        
        diff = otc_row['difference'].values[0]
        p = otc_row['p'].values[0]
        sig = '✓' if p < 0.05 else ''
        
        # For geometry, negative diff means bilateral has MORE change (lower stability)
        # For selectivity, positive diff means bilateral has MORE change
        measure_label = 'Geometry Pres.' if measure == 'geometry' else 'Selectivity Chg.'
        
        print(f"{approach_name:<25} {measure_label:<20} {diff:+.4f}        {p:.4f}     {sig}")

print("\n" + "-"*75)
print("Note: For Geometry Preservation, NEGATIVE diff = bilateral shows MORE change")
print("      For Selectivity Change, POSITIVE diff = bilateral shows MORE change")

# %% [markdown]
# ## Cell 12: Detailed Category-Level Results for Best Approach

# %%
print("="*70)
print("DETAILED CATEGORY-LEVEL RESULTS")
print("="*70)

# Determine best approach based on OTC significance
# Let's show results for multiple approaches for comparison

for approach_name in ['all_scramble', 'hybrid']:
    print(f"\n\n{'#'*70}")
    print(f"APPROACH: {approach_name}")
    print(f"{'#'*70}")
    
    if approach_name == 'hybrid':
        geom_df = hybrid_results['geometry']
        select_df = hybrid_results['selectivity']
        drift_df = hybrid_results['drift']
    else:
        geom_df = standard_results[approach_name]['geometry']
        select_df = standard_results[approach_name]['selectivity']
        drift_df = standard_results[approach_name]['drift']
    
    # Filter to intact hemisphere
    def filter_intact(df):
        filtered = []
        for _, row in df.iterrows():
            sid = row['subject']
            info = SUBJECTS[sid]
            if info['group'] == 'control':
                filtered.append(row)
            elif row['hemi'] == info['hemi']:
                filtered.append(row)
        return pd.DataFrame(filtered)
    
    geom_filt = filter_intact(geom_df)
    select_filt = filter_intact(select_df)
    drift_filt = filter_intact(drift_df)
    
    print("\n--- GEOMETRY PRESERVATION (higher = more stable) ---")
    print(f"{'Group':<10} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*50)
    for group in ['OTC', 'nonOTC', 'control']:
        gd = geom_filt[geom_filt['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['geometry_preservation']
            vals.append(f"{cd.mean():.2f}" if len(cd) > 0 else "--")
        print(f"{group:<10} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")
    
    print("\n--- SELECTIVITY CHANGE (higher = more change) ---")
    print(f"{'Group':<10} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*50)
    for group in ['OTC', 'nonOTC', 'control']:
        gd = select_filt[select_filt['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['selectivity_change']
            vals.append(f"{cd.mean():.2f}" if len(cd) > 0 else "--")
        print(f"{group:<10} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")
    
    print("\n--- SPATIAL DRIFT (mm) ---")
    print(f"{'Group':<10} {'Face':<10} {'Word':<10} {'Object':<10} {'House':<10}")
    print("-"*50)
    for group in ['OTC', 'nonOTC', 'control']:
        gd = drift_filt[drift_filt['group'] == group]
        vals = []
        for cat in CATEGORIES:
            cd = gd[gd['category'] == cat]['spatial_drift_mm']
            vals.append(f"{cd.mean():.1f}" if len(cd) > 0 else "--")
        print(f"{group:<10} {vals[0]:<10} {vals[1]:<10} {vals[2]:<10} {vals[3]:<10}")

# %% [markdown]
# ## Cell 13: Bootstrap Analysis for Robust Inference

# %%
def bootstrap_group_comparison(df, metric_col, n_boot=10000, seed=42):
    """Bootstrap test for OTC bilateral advantage vs other groups"""
    np.random.seed(seed)
    
    # Filter to intact hemisphere
    filtered = []
    for _, row in df.iterrows():
        sid = row['subject']
        info = SUBJECTS[sid]
        if info['group'] == 'control':
            filtered.append(row)
        elif row['hemi'] == info['hemi']:
            filtered.append(row)
    
    filtered_df = pd.DataFrame(filtered)
    filtered_df['cat_type'] = filtered_df['category'].apply(
        lambda x: 'Bilateral' if x in BILATERAL else 'Unilateral'
    )
    
    # Calculate subject-level bilateral advantage (gap)
    subject_gaps = {}
    for group in ['OTC', 'nonOTC', 'control']:
        gd = filtered_df[filtered_df['group'] == group]
        gaps = []
        for sid in gd['subject'].unique():
            sd = gd[gd['subject'] == sid]
            bil = sd[sd['cat_type'] == 'Bilateral'][metric_col].mean()
            uni = sd[sd['cat_type'] == 'Unilateral'][metric_col].mean()
            if pd.notna(bil) and pd.notna(uni):
                gaps.append(bil - uni)
        subject_gaps[group] = np.array(gaps)
    
    results = []
    
    # Compare OTC vs each other group
    for comp_group in ['nonOTC', 'control']:
        g1 = subject_gaps['OTC']
        g2 = subject_gaps[comp_group]
        
        if len(g1) < 2 or len(g2) < 2:
            continue
        
        observed_diff = np.mean(g1) - np.mean(g2)
        
        boot_diffs = []
        for _ in range(n_boot):
            s1 = np.random.choice(g1, size=len(g1), replace=True)
            s2 = np.random.choice(g2, size=len(g2), replace=True)
            boot_diffs.append(np.mean(s1) - np.mean(s2))
        
        boot_diffs = np.array(boot_diffs)
        ci_low = np.percentile(boot_diffs, 2.5)
        ci_high = np.percentile(boot_diffs, 97.5)
        
        # Two-sided p-value
        if observed_diff > 0:
            p_val = 2 * np.mean(boot_diffs <= 0)
        else:
            p_val = 2 * np.mean(boot_diffs >= 0)
        
        results.append({
            'comparison': f'OTC vs {comp_group}',
            'observed_diff': observed_diff,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'p_value': p_val
        })
    
    return pd.DataFrame(results), subject_gaps


print("="*70)
print("BOOTSTRAP ANALYSIS: OTC BILATERAL ADVANTAGE")
print("="*70)

for approach_name in ['all_scramble', 'hybrid']:
    print(f"\n--- {approach_name.upper()} ---")
    
    if approach_name == 'hybrid':
        select_df = hybrid_results['selectivity']
        geom_df = hybrid_results['geometry']
    else:
        select_df = standard_results[approach_name]['selectivity']
        geom_df = standard_results[approach_name]['geometry']
    
    print("\nSelectivity Change (bilateral advantage = more reorganization):")
    boot_results, gaps = bootstrap_group_comparison(select_df, 'selectivity_change')
    print(f"  Subject gaps - OTC: {gaps['OTC'].mean():.3f}, nonOTC: {gaps['nonOTC'].mean():.3f}, control: {gaps['control'].mean():.3f}")
    for _, row in boot_results.iterrows():
        sig = '***' if row['p_value'] < 0.001 else '**' if row['p_value'] < 0.01 else '*' if row['p_value'] < 0.05 else ''
        print(f"  {row['comparison']}: diff={row['observed_diff']:.3f}, 95%CI=[{row['ci_low']:.3f}, {row['ci_high']:.3f}], p={row['p_value']:.4f} {sig}")
    
    print("\nGeometry Preservation (bilateral disadvantage = more reorganization):")
    boot_results, gaps = bootstrap_group_comparison(geom_df, 'geometry_preservation')
    print(f"  Subject gaps - OTC: {gaps['OTC'].mean():.3f}, nonOTC: {gaps['nonOTC'].mean():.3f}, control: {gaps['control'].mean():.3f}")
    for _, row in boot_results.iterrows():
        sig = '***' if row['p_value'] < 0.001 else '**' if row['p_value'] < 0.01 else '*' if row['p_value'] < 0.05 else ''
        print(f"  {row['comparison']}: diff={row['observed_diff']:.3f}, 95%CI=[{row['ci_low']:.3f}, {row['ci_high']:.3f}], p={row['p_value']:.4f} {sig}")

# %% [markdown]
# ## Cell 14: Decision and Final Export

# %%
print("="*70)
print("DECISION: WHICH CONTRAST SCHEME TO USE?")
print("="*70)

print("""
Based on the comparisons above, evaluate:

1. ROI YIELD: Does the scheme find meaningful ROIs for all categories?
   - Check Word ROI voxel counts (liu_mixed may have very few)
   
2. THEORETICAL CONSISTENCY: Same baseline across categories for RSA?
   - all_scramble and all_vs_others are consistent
   - liu_mixed and current_differential mix baselines
   
3. STATISTICAL POWER: Does the scheme detect the hypothesized effect?
   - OTC bilateral > unilateral change

4. HYBRID APPROACH: Does separating localization from patterns help?
   - May get better ROI localization (liu_mixed) 
   - While maintaining RSA consistency (all_scramble patterns)

RECOMMENDATION: Review the summary tables above and decide.
""")

# Show final recommendation based on results
print("\n" + "-"*70)
print("QUANTITATIVE COMPARISON:")
print("-"*70)

comparison_data = []
for approach in ['liu_mixed', 'all_scramble', 'current_differential', 'hybrid']:
    if approach not in approach_tests:
        continue
    
    row = {'approach': approach}
    
    # Get OTC p-values
    if 'selectivity' in approach_tests[approach]:
        sel_df = approach_tests[approach]['selectivity']
        otc_sel = sel_df[sel_df['group'] == 'OTC']
        if len(otc_sel) > 0:
            row['selectivity_p'] = otc_sel['p'].values[0]
            row['selectivity_diff'] = otc_sel['difference'].values[0]
    
    if 'geometry' in approach_tests[approach]:
        geo_df = approach_tests[approach]['geometry']
        otc_geo = geo_df[geo_df['group'] == 'OTC']
        if len(otc_geo) > 0:
            row['geometry_p'] = otc_geo['p'].values[0]
            row['geometry_diff'] = otc_geo['difference'].values[0]
    
    comparison_data.append(row)

comp_df = pd.DataFrame(comparison_data)
print(comp_df.to_string(index=False))

# %% [markdown]
# ## Cell 15: Export Final Results with Chosen Approach

# %%
# Set the chosen approach here after reviewing results
CHOSEN_APPROACH = 'hybrid'  # Change this based on Cell 14 analysis

print(f"="*70)
print(f"EXPORTING FINAL RESULTS: {CHOSEN_APPROACH}")
print(f"="*70)

if CHOSEN_APPROACH == 'hybrid':
    final_results = hybrid_results
else:
    final_results = standard_results[CHOSEN_APPROACH]

# Build comprehensive results DataFrame
export_data = []

for _, row in final_results['selectivity'].iterrows():
    sid = row['subject']
    info = SUBJECTS[sid]
    
    # Get matching geometry and drift
    geom_match = final_results['geometry'][
        (final_results['geometry']['subject'] == sid) & 
        (final_results['geometry']['hemi'] == row['hemi']) &
        (final_results['geometry']['category'] == row['category'])
    ]
    
    drift_match = final_results['drift'][
        (final_results['drift']['subject'] == sid) & 
        (final_results['drift']['hemi'] == row['hemi']) &
        (final_results['drift']['category'] == row['category'])
    ]
    
    export_row = {
        'Subject': row['code'],
        'Group': info['group'],
        'Surgery_Side': info.get('surgery_side', 'na'),
        'Intact_Hemisphere': 'left' if info['hemi'] == 'l' else 'right',
        'Sex': info.get('sex', 'na'),
        'nonpt_hemi': row['hemi'].upper() if info['group'] == 'control' else 'na',
        'Category': row['category'].title(),
        'Category_Type': 'Bilateral' if row['category'] in BILATERAL else 'Unilateral',
        'age_1': info.get('age_1', np.nan),
        'age_2': info.get('age_2', np.nan),
        'yr_gap': info.get('age_2', 0) - info.get('age_1', 0) if info.get('age_1') and info.get('age_2') else np.nan,
        'Selectivity_Change': row['selectivity_change'],
        'Spatial_Relocation_mm': drift_match['spatial_drift_mm'].values[0] if len(drift_match) > 0 else np.nan,
        'Geometry_Preservation_6mm': geom_match['geometry_preservation'].values[0] if len(geom_match) > 0 else np.nan,
    }
    
    export_data.append(export_row)

export_df = pd.DataFrame(export_data)

# Add MDS shift (averaged across ROI categories for each measured category)
mds_df = final_results['mds']
mds_summary = mds_df.groupby(['subject', 'measured_category'])['mds_shift'].mean().reset_index()

for i, row in export_df.iterrows():
    sid = [s for s in SUBJECTS if SUBJECTS[s]['code'] == row['Subject']][0]
    cat = row['Category'].lower()
    
    mds_match = mds_summary[
        (mds_summary['subject'] == sid) & 
        (mds_summary['measured_category'] == cat)
    ]
    
    if len(mds_match) > 0:
        export_df.loc[i, 'MDS_Shift'] = mds_match['mds_shift'].values[0]

# Save
output_file = OUTPUT_DIR / 'results_final_corrected.csv'
export_df.to_csv(output_file, index=False)
print(f"\n✓ Saved to: {output_file}")
print(f"  Shape: {export_df.shape}")

# Final summary
print("\n" + "-"*70)
print("FINAL SUMMARY BY GROUP AND CATEGORY:")
print("-"*70)

print(f"\n{'Measure':<25} {'Group':<10} {'Face':<8} {'Word':<8} {'Object':<8} {'House':<8}")
print("-"*70)

for measure in ['Selectivity_Change', 'Geometry_Preservation_6mm', 'Spatial_Relocation_mm', 'MDS_Shift']:
    for group in ['OTC', 'nonOTC', 'control']:
        gd = export_df[export_df['Group'] == group]
        vals = []
        for cat in ['Face', 'Word', 'Object', 'House']:
            cd = gd[gd['Category'] == cat][measure]
            vals.append(f"{cd.mean():.2f}" if len(cd) > 0 and cd.notna().any() else "--")
        print(f"{measure:<25} {group:<10} {vals[0]:<8} {vals[1]:<8} {vals[2]:<8} {vals[3]:<8}")
    print()

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)

✓ Configuration loaded
  Excluding: ['sub-025', 'sub-027', 'sub-045', 'sub-072']
✓ Loaded 20 subjects (after exclusions)
  OTC: 6
  nonOTC: 7
  control: 7
✓ ROI extraction functions defined
✓ RSA metric functions defined
EXTRACTING ROIs WITH ALL CONTRAST SCHEMES

liu_mixed...
  ✓ 20 subjects, 108 ROIs

all_scramble...
  ✓ 20 subjects, 108 ROIs

all_vs_others...
  ✓ 20 subjects, 108 ROIs

current_differential...
  ✓ 20 subjects, 108 ROIs

ROI YIELD COMPARISON BY CATEGORY

ROI counts by scheme and category:
------------------------------------------------------------
category              face  house  object  word
scheme                                         
all_scramble            27     27      27    26
all_vs_others           27     26      27    24
current_differential    27     27      27    25
liu_mixed               27     27      27    21


Mean voxel count by scheme and category:
------------------------------------------------------------
category                face   house