# Longitudinal RSA Analysis: Bilateral vs Unilateral Visual Categories
## Streamlined Version - VOTC Resection Study

In [22]:
"""
LONGITUDINAL RSA ANALYSIS - CLEAN VERSION
Bilateral vs Unilateral Visual Categories in VOTC Resection Patients

Analysis Pipeline:
1. Setup & Configuration
2. ROI Extraction (functional clusters)
3. RSA Analysis (RDMs + Liu distinctiveness)
4. Spatial Analysis (drift + hemisphere effects)
5. Group Comparisons & Statistics
6. Visualization
"""

# ============================================================================
# CELL 1: SETUP & CONFIGURATION
# ============================================================================

import numpy as np
import nibabel as nib
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.ndimage import center_of_mass, label
import pandas as pd
from scipy.stats import pearsonr, mannwhitneyu, ttest_ind
import seaborn as sns
from matplotlib.patches import Circle
import warnings
warnings.filterwarnings('ignore')

# Configuration
# NOTE(review): absolute, user-specific paths; anyone else must edit these
# before running the notebook.
# Root folder that contains one sub-directory per subject (see
# load_subjects_by_group, which looks for BASE_DIR / <subject_id>).
BASE_DIR = Path("/user_data/csimmon2/long_pt")
# Where the pickle and CSV results are written; created when this cell runs.
OUTPUT_DIR = BASE_DIR / "analyses" / "rsa_corrected"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Subject table. load_subjects_by_group reads the columns 'sub', 'patient',
# 'group', 'intact_hemi' and 'age_1' from it.
CSV_FILE = Path('/user_data/csimmon2/git_repos/long_pt/long_pt_sub_info.csv')
df = pd.read_csv(CSV_FILE)

# Session overrides
# Per-subject minimum session number; sessions numbered below this value are
# dropped when subjects are loaded (subjects not listed start at session 1).
SESSION_START = {'sub-010': 2, 'sub-018': 2, 'sub-068': 2}

# Category mappings
# category name -> cope number, used to build the `cope{N}.feat` folder name.
COPE_MAP = {'face': 1, 'word': 12, 'object': 3, 'house': 2}
# Grouping used for the 'Category_Type' column of the results table.
BILATERAL_CATEGORIES = ['object', 'house']
UNILATERAL_CATEGORIES = ['face', 'word']

def load_subjects_by_group(group_filter=None, patient_only=True):
    """Build the subject lookup from the module-level subject table ``df``.

    Parameters
    ----------
    group_filter : str | list[str] | None
        Keep only rows whose 'group' value is in this list. A single string
        is treated as a one-element list. Falsy means "no group filter".
    patient_only : bool | None
        ``True`` keeps rows with ``patient == 1``, ``False`` keeps rows with
        ``patient == 0``; any other value (e.g. ``None``) keeps both.

    Returns
    -------
    dict
        ``{subject_id: {'code', 'sessions', 'hemi', 'group',
        'patient_status', 'age_1'}}``. Subjects are omitted when their
        directory under ``BASE_DIR`` is missing or when no usable session
        remains after applying ``SESSION_START``.

    Notes
    -----
    * ``ses-*`` directories whose suffix is not an integer are ignored
      instead of aborting the whole load with a ``ValueError``.
    * 'intact_hemi' is compared case-insensitively and with surrounding
      whitespace removed; a missing value defaults to 'left'. Anything that
      is not 'left' maps to ``'r'``.
    """
    filtered_df = df.copy()

    if patient_only is True:
        filtered_df = filtered_df[filtered_df['patient'] == 1]
    elif patient_only is False:
        filtered_df = filtered_df[filtered_df['patient'] == 0]

    if group_filter:
        if isinstance(group_filter, str):
            group_filter = [group_filter]
        filtered_df = filtered_df[filtered_df['group'].isin(group_filter)]

    subjects = {}

    for _, row in filtered_df.iterrows():
        subject_id = row['sub']

        subj_dir = BASE_DIR / subject_id
        if not subj_dir.exists():
            continue

        # Collect (numeric value, label) pairs; the label keeps its original
        # spelling because it is reused to rebuild `ses-<label>` paths.
        numbered_sessions = []
        for ses_dir in subj_dir.glob('ses-*'):
            if not ses_dir.is_dir():
                continue
            session_label = ses_dir.name.replace('ses-', '')
            try:
                numbered_sessions.append((int(session_label), session_label))
            except ValueError:
                # e.g. a stray "ses-backup" folder: not a real session
                continue

        if not numbered_sessions:
            continue

        # Numeric (not lexicographic) order, so session 10 follows session 2.
        numbered_sessions.sort(key=lambda pair: pair[0])
        start_session = SESSION_START.get(subject_id, 1)
        available_sessions = [
            session_label
            for session_number, session_label in numbered_sessions
            if session_number >= start_session
        ]

        if not available_sessions:
            continue

        intact_hemi = row.get('intact_hemi', None)
        hemisphere_full = intact_hemi if pd.notna(intact_hemi) else 'left'
        hemisphere = 'l' if str(hemisphere_full).strip().lower() == 'left' else 'r'

        subjects[subject_id] = {
            'code': f"{row['group']}{subject_id.split('-')[1]}",
            'sessions': available_sessions,
            'hemi': hemisphere,
            'group': row['group'],
            'patient_status': 'patient' if row['patient'] == 1 else 'control',
            'age_1': row['age_1'] if pd.notna(row['age_1']) else None
        }

    return subjects

# Load all subjects
# Patients and controls are loaded separately (no group filter) and then
# merged; insertion order of the merged dict is patients first, then controls.
ALL_PATIENTS = load_subjects_by_group(group_filter=None, patient_only=True)
ALL_CONTROLS = load_subjects_by_group(group_filter=None, patient_only=False)
ANALYSIS_SUBJECTS = {**ALL_PATIENTS, **ALL_CONTROLS}

print(f"✓ Loaded {len(ANALYSIS_SUBJECTS)} subjects")
print(f"  Patients: {len(ALL_PATIENTS)}, Controls: {len(ALL_CONTROLS)}")


✓ Loaded 25 subjects
  Patients: 16, Controls: 9


In [2]:
# CELL 2: ROI Extraction

def extract_rois(subject_id, subjects_dict, threshold_z=2.3, min_voxels=50):
    """Extract the largest supra-threshold cluster per category and session.

    For every category in ``COPE_MAP`` the category search mask of the
    subject's first session is loaded, each session's z-statistic map is
    thresholded inside that mask, and the largest connected cluster is kept
    as the ROI.

    Parameters
    ----------
    subject_id : str
        Key into ``subjects_dict``.
    subjects_dict : dict
        Subject lookup providing 'code', 'hemi', 'sessions', 'group' and
        'patient_status' for ``subject_id``.
    threshold_z : float
        Voxels must exceed this z value (strictly) to be considered.
    min_voxels : int
        Minimum number of supra-threshold voxels inside the mask; sessions
        with fewer are skipped. Defaults to the previous hard-coded 50.

    Returns
    -------
    dict
        ``{category: {session: {'n_voxels', 'peak_z', 'centroid',
        'roi_mask'}}}``. 'centroid' is the cluster's centre of mass mapped
        through the search-mask affine. A category whose mask is missing is
        present with an empty dict; an unknown ``subject_id`` returns ``{}``.

    Notes
    -----
    The peak is searched only inside the ROI (values outside are replaced
    by ``-inf``). Multiplying the z-map by the mask instead would let a NaN
    outside the ROI, or a non-positive threshold, select the wrong voxel.
    NOTE(review): the z-map is assumed to share the search mask's voxel
    grid; shapes are not checked here.
    """
    if subject_id not in subjects_dict:
        return {}

    info = subjects_dict[subject_id]
    code = info['code']
    hemi = info['hemi']
    sessions = info['sessions']
    first_session = sessions[0]

    print(f"{code} - Extracting ROIs [{info['group']} {info['patient_status']}, hemi={hemi}]")

    all_results = {}

    for category, cope_num in COPE_MAP.items():
        all_results[category] = {}

        # Load category-specific mask (always taken from the first session)
        mask_file = BASE_DIR / subject_id / f'ses-{first_session}' / 'ROIs' / f'{hemi}_{category}_searchmask.nii.gz'
        if not mask_file.exists():
            print(f"  ⚠️  {category}: mask not found")
            continue

        # Read the image once and reuse it for both data and affine.
        mask_img = nib.load(mask_file)
        mask = mask_img.get_fdata() > 0
        affine = mask_img.affine

        # Process each session
        for session in sessions:
            feat_dir = BASE_DIR / subject_id / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'

            # Later sessions use a file tagged with the first session's label.
            zstat_file = 'zstat1.nii.gz' if session == first_session else f'zstat1_ses{first_session}.nii.gz'
            cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / zstat_file

            if not cope_file.exists():
                continue

            # Load functional activation
            zstat = nib.load(cope_file).get_fdata()
            suprathresh = (zstat > threshold_z) & mask

            if suprathresh.sum() < min_voxels:
                continue

            # Find largest cluster
            labeled, n_clusters = label(suprathresh)
            if n_clusters == 0:
                continue

            # One pass over the label volume; index 0 is background.
            cluster_sizes = np.bincount(labeled.ravel())[1:]
            largest_idx = int(np.argmax(cluster_sizes)) + 1
            roi_mask = (labeled == largest_idx)

            # Extract metrics
            peak_idx = np.unravel_index(
                np.argmax(np.where(roi_mask, zstat, -np.inf)), zstat.shape
            )
            peak_z = zstat[peak_idx]
            centroid = nib.affines.apply_affine(affine, center_of_mass(roi_mask))

            all_results[category][session] = {
                'n_voxels': cluster_sizes[largest_idx - 1],
                'peak_z': peak_z,
                'centroid': centroid,
                'roi_mask': roi_mask
            }

    return all_results

print("\nEXTRACTING FUNCTIONAL ROIs")
print("="*70)

# Patients first, then controls, each in the hemisphere recorded for the
# subject in ANALYSIS_SUBJECTS. A failure for one subject is reported and
# leaves an empty entry so later cells can still iterate over everyone.
functional_rois = {}
for cohort in (ALL_PATIENTS, ALL_CONTROLS):
    for subject_id in cohort.keys():
        try:
            subject_rois = extract_rois(subject_id, ANALYSIS_SUBJECTS, threshold_z=2.3)
        except Exception as e:
            print(f"❌ {subject_id} failed: {e}")
            subject_rois = {}
        functional_rois[subject_id] = subject_rois

print(f"\n✓ Extracted {len(functional_rois)} subjects")

# Second pass for controls only, forcing the left hemisphere.
print("\nEXTRACTING CONTROLS LEFT HEMISPHERE")
print("="*70)

controls_left_functional = {}
for subject_id in ALL_CONTROLS.keys():
    # Single-subject lookup: same info as ANALYSIS_SUBJECTS, with 'hemi' set to 'l'.
    temp_subjects = {subject_id: dict(ANALYSIS_SUBJECTS[subject_id], hemi='l')}
    try:
        left_rois = extract_rois(subject_id, temp_subjects, threshold_z=2.3)
    except Exception as e:
        print(f"❌ {subject_id} failed: {e}")
        left_rois = {}
    controls_left_functional[subject_id] = left_rois

print(f"\n✓ Extracted left hemisphere for {len(controls_left_functional)} controls")



EXTRACTING FUNCTIONAL ROIs
OTC004 - Extracting ROIs [OTC patient, hemi=l]
nonOTC007 - Extracting ROIs [nonOTC patient, hemi=r]
OTC008 - Extracting ROIs [OTC patient, hemi=l]
OTC010 - Extracting ROIs [OTC patient, hemi=r]
OTC017 - Extracting ROIs [OTC patient, hemi=r]
OTC021 - Extracting ROIs [OTC patient, hemi=r]
nonOTC045 - Extracting ROIs [nonOTC patient, hemi=r]
nonOTC047 - Extracting ROIs [nonOTC patient, hemi=l]
nonOTC049 - Extracting ROIs [nonOTC patient, hemi=l]
nonOTC070 - Extracting ROIs [nonOTC patient, hemi=r]
nonOTC072 - Extracting ROIs [nonOTC patient, hemi=l]
nonOTC073 - Extracting ROIs [nonOTC patient, hemi=l]
OTC079 - Extracting ROIs [OTC patient, hemi=r]
nonOTC081 - Extracting ROIs [nonOTC patient, hemi=r]
nonOTC086 - Extracting ROIs [nonOTC patient, hemi=l]
OTC108 - Extracting ROIs [OTC patient, hemi=r]
  ⚠️  face: mask not found
  ⚠️  word: mask not found
  ⚠️  object: mask not found
  ⚠️  house: mask not found
control018 - Extracting ROIs [control control, hemi=r]


In [3]:
# CELL 3: RSA Analysis
def create_sphere(peak_coord, affine, brain_shape, radius=6):
    """Build a boolean spherical mask around a world-space coordinate.

    Parameters
    ----------
    peak_coord : array-like of 3 floats
        Sphere centre in world coordinates (same units as ``affine``).
    affine : array-like, shape (4, 4)
        Voxel-index -> world-coordinate transform of the target volume.
    brain_shape : sequence of int
        Volume shape; only the first three dimensions are used.
    radius : float
        Sphere radius in world units (default 6). Voxels whose centre lies
        at a distance ``<= radius`` from ``peak_coord`` are included.

    Returns
    -------
    numpy.ndarray of bool, shape ``brain_shape[:3]``
        ``True`` inside the sphere.
    """
    shape = tuple(int(n) for n in brain_shape[:3])
    affine = np.asarray(affine, dtype=float)
    centre = np.asarray(peak_coord, dtype=float)

    # (N, 3) voxel indices in C order, so reshaping a flat per-voxel result
    # back to `shape` restores the original layout.
    voxel_ijk = np.indices(shape).reshape(3, -1).T

    # Affine transform: rotation/zoom/shear part, then translation.
    world_xyz = voxel_ijk @ affine[:3, :3].T + affine[:3, 3]
    distances = np.linalg.norm(world_xyz - centre, axis=1)

    # Vectorised membership test instead of setting voxels one at a time.
    return (distances <= radius).reshape(shape)

def extract_betas(subject_id, session, sphere_mask, category_copes):
    """Collect one beta pattern per category from the voxels of a sphere.

    Parameters
    ----------
    subject_id : str
        Key into ``ANALYSIS_SUBJECTS`` (used to find the first session).
    session : str
        Session label used to build the ``ses-<session>`` path.
    sphere_mask : numpy.ndarray of bool
        Voxel selection applied to every cope volume.
    category_copes : dict
        ``{category: cope_number}``.

    Returns
    -------
    (beta_matrix, valid_categories) or (None, None)
        ``beta_matrix`` has one row per retained voxel and one column per
        entry of ``valid_categories``. ``(None, None)`` is returned when no
        category could be loaded or no voxel is finite in every category.

    Notes
    -----
    Voxels are kept only if they are finite in *all* retained categories.
    Filtering each category on its own and truncating to a common length
    would pair different voxels across categories and corrupt the pattern
    correlations. When every value is finite the result is unchanged.
    """
    info = ANALYSIS_SUBJECTS[subject_id]
    first_session = info['sessions'][0]

    feat_dir = BASE_DIR / subject_id / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'

    # Later sessions use a file tagged with the first session's label.
    if session == first_session:
        cope_name = 'cope1.nii.gz'
    else:
        cope_name = f'cope1_ses{first_session}.nii.gz'

    raw_patterns = []
    valid_categories = []

    for category, cope_num in category_copes.items():
        cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / cope_name

        if not cope_file.exists():
            continue

        roi_betas = nib.load(cope_file).get_fdata()[sphere_mask]

        # A category without a single finite voxel carries no usable pattern.
        if not np.isfinite(roi_betas).any():
            continue

        raw_patterns.append(roi_betas)
        valid_categories.append(category)

    if len(raw_patterns) == 0:
        return None, None

    beta_matrix = np.column_stack(raw_patterns)

    # Keep the same voxels for every category so rows stay aligned.
    finite_everywhere = np.isfinite(beta_matrix).all(axis=1)
    if not finite_everywhere.any():
        return None, None

    return beta_matrix[finite_everywhere], valid_categories

def compute_rdm(beta_matrix, fisher_transform=True):
    """Turn a voxel-by-category beta matrix into an RDM.

    Columns of ``beta_matrix`` are correlated with each other (Pearson).
    The first return value is always ``1 - r``. The second is the
    correlation matrix itself, or, when ``fisher_transform`` is true, its
    arctanh after clipping to [-0.999, 0.999] (which keeps the diagonal
    finite).
    """
    similarity = np.corrcoef(np.transpose(beta_matrix))
    dissimilarity = 1 - similarity

    if not fisher_transform:
        return dissimilarity, similarity

    bounded = np.clip(similarity, -0.999, 0.999)
    return dissimilarity, np.arctanh(bounded)

def extract_rdms(functional_results, analysis_subjects):
    """Compute RDMs from spheres centred on each session's ROI centroid.

    Parameters
    ----------
    functional_results : dict
        ``{subject_id: {category: {session: {'centroid': ...}}}}`` as
        returned by ``extract_rois``.
    analysis_subjects : dict
        Subject lookup providing 'code', 'sessions' and 'hemi'.

    Returns
    -------
    dict
        ``{subject_id: {roi_name: {...}}}`` where each ROI entry holds
        per-session dicts 'rdms', 'correlation_matrices' (Fisher-z),
        'beta_patterns', 'session_peaks', 'session_n_voxels' and
        'session_valid_categories', plus 'valid_categories' (labels of the
        most recently processed session, kept for existing callers).

    Notes
    -----
    * Subjects without a first-session face search mask are skipped: that
      file supplies the affine and volume shape used to build the spheres.
    * 'session_peaks' stores the cluster *centroid*, not the peak voxel.
    * 'session_valid_categories' records the column labels of each
      session's matrix. 'valid_categories' alone is overwritten by every
      session, so it cannot describe a session whose category set differed.
    """
    all_rdms = {}

    for subject_id in analysis_subjects.keys():
        if subject_id not in functional_results:
            continue

        info = analysis_subjects[subject_id]
        code = info['code']
        sessions = info['sessions']
        first_session = sessions[0]

        ref_file = BASE_DIR / subject_id / f'ses-{first_session}' / 'ROIs' / f"{info['hemi']}_face_searchmask.nii.gz"
        if not ref_file.exists():
            continue

        ref_img = nib.load(ref_file)
        affine = ref_img.affine
        brain_shape = ref_img.shape

        print(f"{code}: RSA Analysis")

        all_rdms[subject_id] = {}

        for roi_name in COPE_MAP.keys():
            if roi_name not in functional_results[subject_id]:
                continue

            roi_entry = {
                'rdms': {},
                'correlation_matrices': {},
                'beta_patterns': {},
                'valid_categories': None,
                'session_valid_categories': {},
                'session_peaks': {},
                'session_n_voxels': {}
            }
            all_rdms[subject_id][roi_name] = roi_entry

            for session in sessions:
                if session not in functional_results[subject_id][roi_name]:
                    continue

                centroid = functional_results[subject_id][roi_name][session]['centroid']
                sphere_mask = create_sphere(centroid, affine, brain_shape, radius=6)
                n_voxels = sphere_mask.sum()

                # Recorded even if no betas can be extracted below.
                roi_entry['session_peaks'][session] = centroid
                roi_entry['session_n_voxels'][session] = n_voxels

                beta_matrix, valid_cats = extract_betas(subject_id, session, sphere_mask, COPE_MAP)

                if beta_matrix is None:
                    continue

                rdm, corr_matrix_fisher = compute_rdm(beta_matrix, fisher_transform=True)

                roi_entry['rdms'][session] = rdm
                roi_entry['correlation_matrices'][session] = corr_matrix_fisher
                roi_entry['beta_patterns'][session] = beta_matrix
                roi_entry['valid_categories'] = valid_cats
                roi_entry['session_valid_categories'][session] = valid_cats

    return all_rdms

def compute_liu_metrics(all_rdms, analysis_subjects):
    """Compute the Liu distinctiveness of each ROI for every session.

    For an ROI, the metric is the mean of the stored correlation-matrix
    entries between the ROI's preferred category and all other categories.

    Parameters
    ----------
    all_rdms : dict
        ``{subject_id: {roi_name: {'correlation_matrices': {session: M},
        'valid_categories': [...], ...}}}``. An optional
        'session_valid_categories' dict (``{session: [...]}``) takes
        precedence over 'valid_categories' for the sessions it lists.
    analysis_subjects : dict
        Only subjects present in this dict are processed.

    Returns
    -------
    dict
        ``{subject_id: {roi_name: {session: {'liu_distinctiveness',
        'individual_correlations'}}}}``.

    Notes
    -----
    A session is skipped when its labels are missing, list fewer than four
    categories, lack the preferred category, or do not match the matrix
    shape. Applying one label list to a matrix built from a different
    category set would otherwise raise an IndexError or read the wrong rows.
    ROIs with no usable session are omitted.
    """
    distinctiveness_results = {}
    roi_preferred = {'face': 'face', 'word': 'word', 'object': 'object', 'house': 'house'}
    n_required = len(roi_preferred)  # all four categories must be present

    for subject_id, categories in all_rdms.items():
        if subject_id not in analysis_subjects:
            continue

        distinctiveness_results[subject_id] = {}

        for roi_name, roi_data in categories.items():
            if not roi_data['correlation_matrices']:
                continue

            preferred_cat = roi_preferred.get(roi_name)
            if preferred_cat is None:
                continue

            default_cats = roi_data['valid_categories']
            session_cats = roi_data.get('session_valid_categories') or {}

            roi_results = {}

            for session, corr_matrix in roi_data['correlation_matrices'].items():
                cats = session_cats.get(session, default_cats)
                if cats is None or len(cats) < n_required or preferred_cat not in cats:
                    continue

                # Labels must describe exactly this matrix.
                if np.shape(corr_matrix) != (len(cats), len(cats)):
                    continue

                pref_idx = cats.index(preferred_cat)
                nonpref_indices = [i for i, cat in enumerate(cats) if cat != preferred_cat]

                pref_vs_nonpref = corr_matrix[pref_idx, nonpref_indices]

                roi_results[session] = {
                    'liu_distinctiveness': np.mean(pref_vs_nonpref),
                    'individual_correlations': pref_vs_nonpref
                }

            if roi_results:
                distinctiveness_results[subject_id][roi_name] = roi_results

    return distinctiveness_results

print("\nEXTRACTING RSA DATA", "="*70, sep="\n")

# Every subject, in the hemisphere recorded in ANALYSIS_SUBJECTS.
all_rdms = extract_rdms(functional_rois, ANALYSIS_SUBJECTS)
liu_distinctiveness = compute_liu_metrics(all_rdms, ANALYSIS_SUBJECTS)

# Same pipeline on the left-hemisphere ROIs extracted for controls.
print("\nCONTROLS LEFT HEMISPHERE RSA", "="*70, sep="\n")
controls_left_rdms = extract_rdms(controls_left_functional, ALL_CONTROLS)
controls_left_distinctiveness = compute_liu_metrics(controls_left_rdms, ALL_CONTROLS)

print("\n✓ RSA analysis complete!")


EXTRACTING RSA DATA
OTC004: RSA Analysis
nonOTC007: RSA Analysis
OTC008: RSA Analysis
OTC010: RSA Analysis
OTC017: RSA Analysis
OTC021: RSA Analysis
nonOTC045: RSA Analysis
nonOTC047: RSA Analysis
nonOTC049: RSA Analysis
nonOTC070: RSA Analysis
nonOTC072: RSA Analysis
nonOTC073: RSA Analysis
OTC079: RSA Analysis
nonOTC081: RSA Analysis
nonOTC086: RSA Analysis
control018: RSA Analysis
control022: RSA Analysis
control025: RSA Analysis
control027: RSA Analysis
control052: RSA Analysis
control058: RSA Analysis
control062: RSA Analysis
control064: RSA Analysis
control068: RSA Analysis

CONTROLS LEFT HEMISPHERE RSA
control018: RSA Analysis
control022: RSA Analysis
control025: RSA Analysis
control027: RSA Analysis
control052: RSA Analysis
control058: RSA Analysis
control062: RSA Analysis
control064: RSA Analysis
control068: RSA Analysis

✓ RSA analysis complete!


In [10]:
# CELL 4: SPATIAL ANALYSIS

# NOTE(review): these two lists repeat the values already assigned in the
# configuration cell; keep both places in sync (or drop this copy) so that
# 'Category_Type' cannot silently diverge depending on which cell ran last.
BILATERAL_CATEGORIES = ['object', 'house']
UNILATERAL_CATEGORIES = ['face', 'word']

def get_bootstrapped_error_radius(pair_peaks, n_bootstraps=1000, seed=None):
    """Estimate a measurement-error radius by bootstrapping peak coordinates.

    Parameters
    ----------
    pair_peaks : list of dict
        Each item has a 'coord' entry; only its first two components are
        used. NOTE(review): the third component is ignored here, whereas
        calc_drift measures full-vector distances. Confirm this is intended.
    n_bootstraps : int
        Number of resamples drawn with replacement.
    seed : int | None
        When given, resampling uses ``numpy.random.default_rng(seed)`` so
        the result is reproducible. ``None`` keeps the previous behaviour
        of drawing from NumPy's global random state.

    Returns
    -------
    float
        ``1.0`` when fewer than two peaks are supplied. Otherwise the mean
        over resamples of ``sqrt(std_x**2 + std_y**2)``, where a resample
        with fewer than two distinct values on either axis counts as 0. If
        that mean is NaN or not positive, the statistic of the
        un-resampled data is returned instead (which may itself be 0).
    """
    if not pair_peaks or len(pair_peaks) < 2:
        return 1.0

    data = np.array([p['coord'][:2] for p in pair_peaks])

    # Dedicated generator when seeded; legacy global state otherwise.
    rng = np.random.default_rng(seed) if seed is not None else np.random

    def stat_func(coords):
        if len(np.unique(coords[:, 0])) < 2 or len(np.unique(coords[:, 1])) < 2:
            return 0.0
        return np.sqrt(np.std(coords[:, 0])**2 + np.std(coords[:, 1])**2)

    n_points = len(data)
    bootstrapped_stats = [
        stat_func(data[rng.choice(n_points, n_points, replace=True)])
        for _ in range(n_bootstraps)
    ]

    final_radius = np.mean(bootstrapped_stats)
    return final_radius if not np.isnan(final_radius) and final_radius > 0 else stat_func(data)

def calc_error_radii(functional_results, analysis_subjects):
    """Compute one measurement-error radius per subject and category.

    Subjects of ``analysis_subjects`` that are absent from
    ``functional_results`` are left out. A category with fewer than two
    sessions gets the fixed radius ``1.0``; otherwise the session centroids
    are handed to ``get_bootstrapped_error_radius``.

    Returns ``{subject_id: {category: radius}}``.
    """
    radii = {}

    for subject_id in analysis_subjects.keys():
        if subject_id not in functional_results:
            continue

        per_category = {}
        for category, sessions_data in functional_results[subject_id].items():
            if len(sessions_data) >= 2:
                pair_peaks = [
                    {'coord': entry['centroid'], 'session': session_label}
                    for session_label, entry in sessions_data.items()
                ]
                per_category[category] = get_bootstrapped_error_radius(pair_peaks)
            else:
                # Not enough sessions to estimate scatter.
                per_category[category] = 1.0

        radii[subject_id] = per_category

    return radii

def calc_drift(functional_results, radii, analysis_subjects):
    """Measure how far each ROI centroid moves away from its baseline session.

    Parameters
    ----------
    functional_results : dict
        ``{subject_id: {category: {session: {'centroid': array}}}}``.
    radii : dict
        ``{subject_id: {category: error_radius}}``; a missing subject or
        category falls back to a radius of ``1.0``.
    analysis_subjects : dict
        Only subjects present in this dict are processed.

    Returns
    -------
    dict
        ``{subject_id: {category: {'baseline_session', 'baseline_centroid',
        'error_radius', 'from_baseline_drift': [{'session', 'distance_mm',
        'relative_to_error'}, ...]}}}``. Categories with fewer than two
        sessions are omitted.

    Notes
    -----
    Sessions are ordered numerically when their labels are integers, so
    '2' precedes '10' and the baseline is the earliest session. A plain
    string sort would place '10' before '2'. Non-numeric labels sort after
    numeric ones, alphabetically.
    """
    def _session_order(session):
        try:
            return (0, int(session), str(session))
        except (TypeError, ValueError):
            return (1, 0, str(session))

    drift_results = {}

    for subject_id, categories in functional_results.items():
        if subject_id not in analysis_subjects:
            continue

        subject_radii = radii.get(subject_id, {})
        drift_results[subject_id] = {}

        for category, sessions_data in categories.items():
            if len(sessions_data) < 2:
                continue

            sessions = sorted(sessions_data.keys(), key=_session_order)
            baseline_session = sessions[0]
            baseline_centroid = sessions_data[baseline_session]['centroid']
            error_radius = subject_radii.get(category, 1.0)

            from_baseline = []
            for session in sessions[1:]:
                current_centroid = sessions_data[session]['centroid']
                drift_distance = np.linalg.norm(
                    np.asarray(current_centroid) - np.asarray(baseline_centroid)
                )
                from_baseline.append({
                    'session': session,
                    'distance_mm': drift_distance,
                    'relative_to_error': drift_distance / error_radius
                })

            drift_results[subject_id][category] = {
                'baseline_session': baseline_session,
                'baseline_centroid': baseline_centroid,
                'error_radius': error_radius,
                'from_baseline_drift': from_baseline
            }

    return drift_results

def _session_order_key(session):
    """Sort key that orders integer-like session labels numerically ('2' < '10')."""
    try:
        return (0, int(session), str(session))
    except (TypeError, ValueError):
        return (1, 0, str(session))


def _mean_drift_by_category(drift_source):
    """Average the from-baseline drift distances of each category that has any."""
    return {
        category: np.mean([d['distance_mm'] for d in drift_info['from_baseline_drift']])
        for category, drift_info in drift_source.items()
        if drift_info.get('from_baseline_drift')
    }


def _repr_change_by_category(distinct_source):
    """Absolute first-to-last-session change in Liu distinctiveness per category."""
    changes = {}
    for category, sessions in distinct_source.items():
        session_keys = sorted(sessions.keys(), key=_session_order_key)
        if len(session_keys) >= 2:
            baseline = sessions[session_keys[0]]['liu_distinctiveness']
            final = sessions[session_keys[-1]]['liu_distinctiveness']
            changes[category] = abs(final - baseline)
    return changes


def calc_hemisphere_effects(drift_data, distinctiveness_data, analysis_subjects,
                           controls_left_drift=None, controls_left_distinct=None,
                           missing_repr_value=0):
    """Build the long-format results table (one row per subject/hemisphere/category).

    Parameters
    ----------
    drift_data, distinctiveness_data : dict
        Outputs of ``calc_drift`` and ``compute_liu_metrics`` for the main
        analysis, keyed by subject id.
    analysis_subjects : dict
        Subject lookup providing 'code', 'group', 'patient_status', 'hemi'
        and 'sessions'.
    controls_left_drift, controls_left_distinct : dict | None
        Left-hemisphere results for controls. When ``controls_left_drift``
        is non-empty, each control contributes two sets of rows, labelled
        ``<code>_R`` (from the main data) and ``<code>_L``. Otherwise
        controls are handled like patients (single hemisphere).
    missing_repr_value : number
        Value written to 'Representational_Change' when a category has a
        drift value but no representational change. The default ``0``
        reproduces the previous output; note that it pulls group means
        toward zero, whereas ``numpy.nan`` leaves such rows out of pandas
        means.

    Returns
    -------
    pandas.DataFrame
        Columns: Subject, Group, Status, Hemisphere, Category,
        Category_Type, Spatial_Drift_mm, Representational_Change, Sessions.
        A row is emitted only for categories that have a drift value.

    Notes
    -----
    Baseline and final sessions are chosen by numeric session order, so
    session '10' is later than session '2'.
    """
    # Subject codes excluded from the table (control068 is no longer excluded).
    subjects_to_skip = ['OTC079', 'OTC108']

    table_data = []

    for subject_id, info in analysis_subjects.items():
        code = info['code']

        if code in subjects_to_skip:
            continue

        if info['patient_status'] == 'control' and controls_left_drift:
            # Controls: right hemisphere from the main data, left from the extra pass.
            sources = [
                (code + '_R', 'r',
                 drift_data.get(subject_id, {}),
                 distinctiveness_data.get(subject_id, {})),
                (code + '_L', 'l',
                 controls_left_drift.get(subject_id, {}),
                 controls_left_distinct.get(subject_id, {}) if controls_left_distinct else {}),
            ]
        else:
            # Patients (and controls without left-hemisphere data): one hemisphere.
            sources = [
                (code, info['hemi'],
                 drift_data.get(subject_id, {}),
                 distinctiveness_data.get(subject_id, {})),
            ]

        n_sessions = len(info['sessions'])

        for subject_label, hemi_label, drift_source, distinct_source in sources:
            spatial_data = _mean_drift_by_category(drift_source)
            repr_data = _repr_change_by_category(distinct_source)

            for category in COPE_MAP.keys():
                if category not in spatial_data:
                    continue
                table_data.append({
                    'Subject': subject_label,
                    'Group': info['group'],
                    'Status': info['patient_status'],
                    'Hemisphere': hemi_label,
                    'Category': category.title(),
                    'Category_Type': 'Bilateral' if category in BILATERAL_CATEGORIES else 'Unilateral',
                    'Spatial_Drift_mm': round(spatial_data[category], 2),
                    'Representational_Change': round(repr_data.get(category, missing_repr_value), 3),
                    'Sessions': n_sessions
                })

    return pd.DataFrame(table_data)

print("\nCALCULATING SPATIAL METRICS", "="*70, sep="\n")

# Main analysis: error radii first, because drift is expressed relative to them.
error_radii = calc_error_radii(functional_rois, ANALYSIS_SUBJECTS)
drift_data = calc_drift(functional_rois, error_radii, ANALYSIS_SUBJECTS)

# Same two steps for the controls' left-hemisphere ROIs.
error_radii_left = calc_error_radii(controls_left_functional, ALL_CONTROLS)
controls_left_drift = calc_drift(controls_left_functional, error_radii_left, ALL_CONTROLS)

# Long-format table combining drift and representational change.
results_table = calc_hemisphere_effects(
    drift_data,
    liu_distinctiveness,
    ANALYSIS_SUBJECTS,
    controls_left_drift,
    controls_left_distinctiveness,
)

print(f"\n✓ Analysis complete: {len(results_table)} data points")




CALCULATING SPATIAL METRICS

✓ Analysis complete: 128 data points


In [11]:
# CELL 5: MAIN ANALYSIS

def analyze_groups(results_table):
    """Print and return mean representational change for OTC, nonOTC and controls.

    Control rows are first averaged per subject and category across their
    '_L' / '_R' entries; patient rows are used as they are. For each group
    the mean 'Representational_Change' of Bilateral and Unilateral
    categories and their difference are printed, followed by a
    per-hemisphere breakdown of the (un-averaged) control rows.

    Returns ``{group_name: {'bilateral_repr', 'unilateral_repr',
    'repr_difference'}}`` for 'OTC', 'nonOTC' and 'Controls'.
    """
    def _bilateral_unilateral(frame):
        # Mean change for each category type within `frame`.
        by_type = {
            kind: frame.loc[frame['Category_Type'] == kind, 'Representational_Change'].mean()
            for kind in ('Bilateral', 'Unilateral')
        }
        return by_type['Bilateral'], by_type['Unilateral']

    print("THREE-GROUP COMPARISON: OTC vs nonOTC vs Controls")
    print("="*70)

    clean_data = results_table.loc[results_table['Category_Type'] != 'Summary'].copy()

    # Controls: collapse the two hemisphere entries of each subject.
    control_data = clean_data.loc[clean_data['Status'] == 'control'].copy()
    control_data['Subject_Base'] = control_data['Subject'].str.replace('_L|_R', '', regex=True)
    averaged_columns = {'Spatial_Drift_mm': 'mean', 'Representational_Change': 'mean'}
    control_averaged = (
        control_data
        .groupby(['Subject_Base', 'Category', 'Category_Type'])
        .agg(averaged_columns)
        .reset_index()
    )

    patient_data = clean_data.loc[clean_data['Status'] == 'patient']
    cohorts = [
        ('OTC', patient_data[patient_data['Group'] == 'OTC']),
        ('nonOTC', patient_data[patient_data['Group'] == 'nonOTC']),
        ('Controls', control_averaged),
    ]

    print(f"\nREPRESENTATIONAL CHANGE:")
    print(f"{'Group':<15} {'Bilateral':<12} {'Unilateral':<12} {'Difference':<12}")
    print("-" * 52)

    group_results = {}
    for name, data in cohorts:
        bil, uni = _bilateral_unilateral(data)
        diff = bil - uni
        print(f"{name:<15} {bil:<12.3f} {uni:<12.3f} {diff:<12.3f}")
        group_results[name] = {'bilateral_repr': bil, 'unilateral_repr': uni, 'repr_difference': diff}

    # Hemisphere breakdown uses the raw control rows, not the averaged ones.
    print(f"\n  Controls by hemisphere:")
    for hemi, hemi_label in (('l', 'Left'), ('r', 'Right')):
        bil, uni = _bilateral_unilateral(control_data[control_data['Hemisphere'] == hemi])
        diff = bil - uni
        print(f"    {hemi_label:<6} {bil:<12.3f} {uni:<12.3f} {diff:<12.3f}")

    return group_results

# Three-group summary of the results table; the returned dict is kept for later use.
print("\nMAIN GROUP ANALYSIS", "="*50, sep="\n")
final_results = analyze_groups(results_table)

print("\n✓ Analysis complete!")


MAIN GROUP ANALYSIS
THREE-GROUP COMPARISON: OTC vs nonOTC vs Controls

REPRESENTATIONAL CHANGE:
Group           Bilateral    Unilateral   Difference  
----------------------------------------------------
OTC             0.368        0.141        0.228       
nonOTC          0.140        0.139        0.001       
Controls        0.259        0.172        0.087       

  Controls by hemisphere:
    Left   0.276        0.205        0.071       
    Right  0.243        0.139        0.104       

✓ Analysis complete!


In [12]:
# ============================================================================
# VERIFICATION CELL: Check Results Match Previous Analysis
# ============================================================================
# Prints sanity checks on `results_table`; nothing is modified or saved here.

# Subject codes that are left out of the results table. This must mirror the
# `subjects_to_skip` list inside calc_hemisphere_effects; it is defined once
# here so every check below applies the same exclusions.
VERIFY_EXCLUDED_CODES = ['OTC079', 'OTC108']

# Control rows are labelled "<code>_L" / "<code>_R"; strip only that trailing
# suffix so a code that happens to contain "_L" or "_R" elsewhere is untouched.
HEMI_SUFFIX_REGEX = r'_[LR]$'

print("="*70)
print("VERIFICATION: CHECKING RESULTS TABLE")
print("="*70)

# 1. Basic counts
print(f"\n1. BASIC COUNTS:")
print(f"   Total rows: {len(results_table)}")
print(f"   Unique subjects: {results_table['Subject'].nunique()}")

# 2. Check control hemisphere suffixes
print(f"\n2. CONTROL HEMISPHERE SUFFIXES:")
controls = results_table[results_table['Status'] == 'control']
print(f"   Total control rows: {len(controls)}")
print(f"   Unique control entries: {controls['Subject'].nunique()}")

unique_controls = sorted(controls['Subject'].unique())
print(f"\n   Control subject names:")
for subj in unique_controls:
    n_rows = len(controls[controls['Subject'] == subj])
    print(f"     {subj}: {n_rows} categories")

has_L = any('_L' in str(s) for s in unique_controls)
has_R = any('_R' in str(s) for s in unique_controls)
print(f"\n   ✓ Has _L suffixes: {has_L}")
print(f"   ✓ Has _R suffixes: {has_R}")

if has_L and has_R:
    print("   ✓✓ CONTROLS HAVE BOTH HEMISPHERES!")
else:
    print("   ⚠️  WARNING: Missing hemisphere suffixes")

# 3. Expected vs actual counts
# Expectation assumes every included subject contributes all 4 categories
# (and controls contribute them for both hemispheres).
print(f"\n3. EXPECTED VS ACTUAL:")
included_subjects = [s for s in ANALYSIS_SUBJECTS.values() if s['code'] not in VERIFY_EXCLUDED_CODES]
n_otc = len([s for s in included_subjects if s['group'] == 'OTC'])
n_nonotc = len([s for s in included_subjects if s['group'] == 'nonOTC'])
n_controls = len([s for s in included_subjects if s['patient_status'] == 'control'])

expected_otc = n_otc * 4
expected_nonotc = n_nonotc * 4
expected_controls = n_controls * 2 * 4  # Both hemispheres

actual_otc = len(results_table[results_table['Group'] == 'OTC'])
actual_nonotc = len(results_table[results_table['Group'] == 'nonOTC'])
actual_controls = len(results_table[results_table['Status'] == 'control'])

print(f"   OTC:     Expected {expected_otc}, Got {actual_otc}")
print(f"   nonOTC:  Expected {expected_nonotc}, Got {actual_nonotc}")
print(f"   Controls: Expected {expected_controls}, Got {actual_controls}")
print(f"   TOTAL:   Expected {expected_otc + expected_nonotc + expected_controls}, Got {len(results_table)}")

# 4. Key summary statistics (compare to your previous output)
print(f"\n4. GROUP SUMMARY STATISTICS:")
print(f"\n{'Group':<12} {'Category':<12} {'Mean Drift':<12} {'Mean Change':<12}")
print("-"*50)

for group in ['OTC', 'nonOTC', 'control']:
    # Patients are selected by 'Group', controls by 'Status'.
    group_data = results_table[results_table['Group'] == group] if group != 'control' else results_table[results_table['Status'] == 'control']

    for cat_type in ['Bilateral', 'Unilateral']:
        cat_data = group_data[group_data['Category_Type'] == cat_type]
        if len(cat_data) > 0:
            mean_drift = cat_data['Spatial_Drift_mm'].mean()
            mean_change = cat_data['Representational_Change'].mean()
            print(f"{group:<12} {cat_type:<12} {mean_drift:<12.1f} {mean_change:<12.3f}")

# 5. Control hemisphere breakdown
print(f"\n5. CONTROLS BY HEMISPHERE:")
print(f"{'Hemisphere':<12} {'Bilateral':<12} {'Unilateral':<12} {'Difference':<12}")
print("-"*50)

controls_data = results_table[results_table['Status'] == 'control']
for hemi in ['l', 'r']:
    hemi_data = controls_data[controls_data['Hemisphere'] == hemi]
    bil = hemi_data[hemi_data['Category_Type'] == 'Bilateral']['Representational_Change'].mean()
    uni = hemi_data[hemi_data['Category_Type'] == 'Unilateral']['Representational_Change'].mean()
    diff = bil - uni
    hemi_label = 'Left' if hemi == 'l' else 'Right'
    print(f"{hemi_label:<12} {bil:<12.3f} {uni:<12.3f} {diff:<12.3f}")

# 6. Check for missing subjects
print(f"\n6. SUBJECT COVERAGE:")
print(f"   Expected subjects (excluding {', '.join(VERIFY_EXCLUDED_CODES)}):")

expected_subjects = []
for subject_id, info in ANALYSIS_SUBJECTS.items():
    if info['code'] not in VERIFY_EXCLUDED_CODES:
        expected_subjects.append(info['code'])

actual_subjects = results_table['Subject'].str.replace(HEMI_SUFFIX_REGEX, '', regex=True).unique()

print(f"   Expected: {sorted(expected_subjects)}")
print(f"   Got:      {sorted(actual_subjects)}")

missing = set(expected_subjects) - set(actual_subjects)
extra = set(actual_subjects) - set(expected_subjects)

if missing:
    print(f"   ⚠️  MISSING: {missing}")
if extra:
    print(f"   ⚠️  EXTRA: {extra}")
if not missing and not extra:
    print(f"   ✓✓ ALL SUBJECTS ACCOUNTED FOR!")

# 7. Compare to your previous key finding
print(f"\n7. KEY FINDING REPLICATION CHECK:")
print(f"   (Should match your previous analysis)")
print(f"\n{'Group':<15} {'Bilateral':<12} {'Unilateral':<12} {'Bil>Uni?':<10}")
print("-"*50)

# Average controls across hemispheres
control_data = results_table[results_table['Status'] == 'control'].copy()
control_data['Subject_Base'] = control_data['Subject'].str.replace(HEMI_SUFFIX_REGEX, '', regex=True)
control_averaged = control_data.groupby(['Subject_Base', 'Category_Type']).agg({
    'Representational_Change': 'mean'
}).reset_index()

for group_name, group_filter in [('OTC', 'OTC'), ('nonOTC', 'nonOTC'), ('Controls', None)]:
    if group_filter:
        data = results_table[results_table['Group'] == group_filter]
    else:
        data = control_averaged

    bil = data[data['Category_Type'] == 'Bilateral']['Representational_Change'].mean()
    uni = data[data['Category_Type'] == 'Unilateral']['Representational_Change'].mean()
    status = '✓' if bil > uni else '✗'

    print(f"{group_name:<15} {bil:<12.3f} {uni:<12.3f} {status:<10}")

print("\n" + "="*70)
print("VERIFICATION COMPLETE")
print("="*70)
print("\nDoes this match your previous analysis?")
print("If YES -> we can save the pickle")
print("If NO  -> we need to debug what changed")

VERIFICATION: CHECKING RESULTS TABLE

1. BASIC COUNTS:
   Total rows: 128
   Unique subjects: 32

2. CONTROL HEMISPHERE SUFFIXES:
   Total control rows: 72
   Unique control entries: 18

   Control subject names:
     control018_L: 4 categories
     control018_R: 4 categories
     control022_L: 4 categories
     control022_R: 4 categories
     control025_L: 4 categories
     control025_R: 4 categories
     control027_L: 4 categories
     control027_R: 4 categories
     control052_L: 4 categories
     control052_R: 4 categories
     control058_L: 4 categories
     control058_R: 4 categories
     control062_L: 4 categories
     control062_R: 4 categories
     control064_L: 4 categories
     control064_R: 4 categories
     control068_L: 4 categories
     control068_R: 4 categories

   ✓ Has _L suffixes: True
   ✓ Has _R suffixes: True
   ✓✓ CONTROLS HAVE BOTH HEMISPHERES!

3. EXPECTED VS ACTUAL:
   OTC:     Expected 20, Got 20
   nonOTC:  Expected 36, Got 36
   Controls: Expected 72, Got 

In [13]:
# CELL 6: SAVE RESULTS

import pickle

# Everything a later notebook needs to rebuild the figures. Note that
# functional_rois carries a full boolean 'roi_mask' volume per ROI/session,
# so the pickle can be large.
results_to_save = dict(
    functional_rois=functional_rois,
    liu_distinctiveness=liu_distinctiveness,
    drift_data=drift_data,
    results_table=results_table,
    controls_left_drift=controls_left_drift,
    controls_left_distinctiveness=controls_left_distinctiveness,
)

pickle_file = OUTPUT_DIR / "rsa_results.pkl"
with pickle_file.open('wb') as handle:
    pickle.dump(results_to_save, handle)

# Flat copy of the results table for use outside Python.
results_table.to_csv(OUTPUT_DIR / "results_table.csv", index=False)

print(f"\n✓ Results saved to: {OUTPUT_DIR}")
print(f"  - rsa_results.pkl")
print(f"  - results_table.csv")

print("\n" + "="*70)
print("ANALYSIS COMPLETE")
print("="*70)


✓ Results saved to: /user_data/csimmon2/long_pt/analyses/rsa_corrected
  - rsa_results.pkl
  - results_table.csv

ANALYSIS COMPLETE
