In [None]:
# Longitudinal Category Selectivity Analysis
# OTC vs nonOTC vs Control - Bilateral vs Unilateral Categories


import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
# CSV file read by load_and_process().
FILE_PATH = '/user_data/csimmon2/git_repos/long_pt/results_final_corrected.csv'

# Measures to analyse: CSV column name -> short label shown in the printed tables.
MEASURES = {
    'Selectivity_Change': 'Selectivity',
    'Geometry_Preservation_6mm': 'Geometry',
    'Spatial_Relocation_mm': 'Spatial Drift',
    'MDS_Shift': 'MDS'
}

# Values of the 'Category' column that the analyses iterate over / filter on.
CATEGORIES = ['Face', 'Word', 'Object', 'House']

# ============================================================
# DATA LOADING & PROCESSING
# ============================================================
def load_and_process(filepath):
    """Read the results CSV and return patients and controls in one table.

    Rows whose 'Group' is 'control' are collapsed to a single row per
    (Subject, Category) by averaging the numeric columns; all other rows
    are passed through unchanged.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        DataFrame with columns Subject, Group, Category, Category_Type,
        every column named in MEASURES, age_1 and yr_gap. Non-control rows
        come first, followed by the aggregated control rows.
    """
    raw = pd.read_csv(filepath)

    control_rows = raw.loc[raw['Group'] == 'control'].copy()
    patient_rows = raw.loc[raw['Group'] != 'control'].copy()

    keys = ['Subject', 'Category']
    value_cols = [*MEASURES, 'age_1', 'yr_gap']

    # Mean over each control's repeated rows per category (presumably one row
    # per hemisphere -- the CSV layout is not visible here, confirm).
    per_subject_category = control_rows.groupby(keys)
    averaged = per_subject_category[value_cols].mean().reset_index()
    labels = per_subject_category[['Group', 'Category_Type']].first().reset_index()
    control_table = averaged.merge(labels, on=keys)

    out_cols = ['Subject', 'Group', 'Category', 'Category_Type', *value_cols]
    return pd.concat(
        [patient_rows[out_cols], control_table[out_cols]],
        ignore_index=True,
    )


def compute_instability_gap(df, measure):
    """Compute the per-subject Bilateral - Unilateral gap for one measure.

    Rows are restricted to the categories in CATEGORIES, averaged per
    (Subject, Group, Category_Type), and pivoted so each subject has one
    column per category type.

    Args:
        df: Table with 'Subject', 'Group', 'Category', 'Category_Type' and
            the ``measure`` column.
        measure: Name of the numeric column to analyse.

    Returns:
        DataFrame with one row per (Subject, Group) and the columns
        'Subject', 'Group', one column per category type, and
        'Gap' (= Bilateral - Unilateral). Subjects lacking either type are
        dropped. If no rows qualify, or one of the two types never occurs,
        an empty frame carrying the columns 'Subject', 'Group', 'Bilateral',
        'Unilateral' and 'Gap' is returned so that callers can still filter
        on 'Group' / 'Gap' without hitting a KeyError.
    """
    empty = pd.DataFrame(
        columns=['Subject', 'Group', 'Bilateral', 'Unilateral', 'Gap']
    )

    in_scope = df[df['Category'].isin(CATEGORIES)]
    subj_means = (
        in_scope.groupby(['Subject', 'Group', 'Category_Type'])[measure]
        .mean()
        .reset_index()
    )
    if subj_means.empty:
        return empty

    pivot = subj_means.pivot(
        index=['Subject', 'Group'], columns='Category_Type', values=measure
    ).reset_index()

    # Both types must exist as columns before the gap can be formed.
    if not {'Bilateral', 'Unilateral'}.issubset(pivot.columns):
        return empty

    pivot['Gap'] = pivot['Bilateral'] - pivot['Unilateral']
    return pivot.dropna(subset=['Gap'])


# ============================================================
# BOOTSTRAP STATISTICS
# ============================================================
def bootstrap_compare(group_a, group_b, n_boot=100000, seed=42):
    """
    Non-parametric bootstrap comparison of the means of two groups.

    Each group is resampled independently with replacement ``n_boot`` times
    and the difference of resampled means (a - b) is used to build a 95%
    percentile interval.

    Args:
        group_a: 1-D array-like of values for the first group.
        group_b: 1-D array-like of values for the second group.
        n_boot: Number of bootstrap resamples.
        seed: Seed for the random generator. A private generator is used, so
            NumPy's global random state is left untouched; the draws are the
            same ones ``np.random.seed(seed)`` followed by
            ``np.random.choice`` would produce.

    Returns:
        dict with keys:
            mean_a, mean_b: observed group means.
            diff: observed mean_a - mean_b.
            ci_low, ci_high: 2.5th / 97.5th percentiles of the bootstrap
                differences.
            prob_direction: share of bootstrap differences > 0 when the
                observed difference is positive, otherwise share < 0.
            is_significant: True when the interval lies strictly on one
                side of zero.
            n_a, n_b: group sizes.

    Raises:
        ValueError: if either group is empty.
    """
    values_a = np.asarray(group_a)
    values_b = np.asarray(group_b)
    n_a, n_b = len(values_a), len(values_b)
    if n_a == 0 or n_b == 0:
        raise ValueError("bootstrap_compare requires two non-empty groups")

    # Local legacy generator: reproducible without reseeding the global RNG.
    rng = np.random.RandomState(seed)

    mean_a = np.mean(values_a)
    mean_b = np.mean(values_b)
    obs_diff = mean_a - mean_b

    # Draw order (a first, then b) determines the random stream; keep it.
    sample_a = rng.choice(values_a, size=(n_boot, n_a), replace=True)
    sample_b = rng.choice(values_b, size=(n_boot, n_b), replace=True)
    boot_diffs = np.mean(sample_a, axis=1) - np.mean(sample_b, axis=1)

    ci_low, ci_high = np.percentile(boot_diffs, [2.5, 97.5])
    if obs_diff > 0:
        prob = np.mean(boot_diffs > 0)
    else:
        prob = np.mean(boot_diffs < 0)
    # ci_low <= ci_high, so the interval excludes zero iff it is entirely
    # above or entirely below it.
    is_sig = bool(ci_low > 0 or ci_high < 0)

    return {
        'mean_a': mean_a,
        'mean_b': mean_b,
        'diff': obs_diff,
        'ci_low': ci_low,
        'ci_high': ci_high,
        'prob_direction': prob,
        'is_significant': is_sig,
        'n_a': n_a,
        'n_b': n_b
    }


# ============================================================
# ANALYSIS FUNCTIONS
# ============================================================
def run_q1_instability_gap(df):
    """Q1: Test if OTC shows larger bilateral-unilateral gap than controls/nonOTC.

    For every measure in MEASURES the per-subject gap is computed with
    compute_instability_gap() and the OTC gaps are bootstrap-compared against
    the 'control' and 'nonOTC' gaps. Measures or comparisons with no data on
    either side are skipped.

    Args:
        df: Combined table as returned by load_and_process().

    Returns:
        DataFrame with one row per (measure, comparison) and the columns
        Measure, Comparison, OTC_Gap, Ref_Gap, Difference, CI_Low, CI_High,
        P_Direction and Significant ('*' when the 95% CI excludes zero,
        otherwise '').
    """
    results = []
    comparisons = (('vs Control', 'control'), ('vs nonOTC', 'nonOTC'))

    for col, name in MEASURES.items():
        gaps = compute_instability_gap(df, col)

        # The gap table may come back empty (possibly without a 'Group'
        # column); indexing it below would raise KeyError, so skip early.
        if gaps.empty:
            continue

        otc = gaps.loc[gaps['Group'] == 'OTC', 'Gap'].values
        if len(otc) == 0:
            continue

        for comp_name, ref_group in comparisons:
            ref = gaps.loc[gaps['Group'] == ref_group, 'Gap'].values
            if len(ref) == 0:
                continue

            stats = bootstrap_compare(otc, ref)
            results.append({
                'Measure': name,
                'Comparison': comp_name,
                'OTC_Gap': stats['mean_a'],
                'Ref_Gap': stats['mean_b'],
                'Difference': stats['diff'],
                'CI_Low': stats['ci_low'],
                'CI_High': stats['ci_high'],
                'P_Direction': stats['prob_direction'],
                'Significant': '*' if stats['is_significant'] else ''
            })

    return pd.DataFrame(results)


def run_q2_category_specificity(df):
    """Q2: Compare OTC against nonOTC separately for every category.

    For each measure in MEASURES and each category in CATEGORIES the raw
    values of the two groups are bootstrap-compared. Combinations where
    either group has no rows are skipped.

    Args:
        df: Combined table as returned by load_and_process().

    Returns:
        DataFrame with one row per (measure, category) and the columns
        Measure, Category, Category_Type, OTC, nonOTC, Difference, CI_Low,
        CI_High, P_Direction and Significant ('*' or '').
    """
    bilateral = ('Object', 'House')
    is_otc = df['Group'] == 'OTC'
    is_nonotc = df['Group'] == 'nonOTC'

    rows = []
    for column, label in MEASURES.items():
        for category in CATEGORIES:
            in_category = df['Category'] == category
            otc_vals = df.loc[in_category & is_otc, column].values
            nonotc_vals = df.loc[in_category & is_nonotc, column].values

            # Nothing to compare unless both groups contribute values.
            if not (len(otc_vals) and len(nonotc_vals)):
                continue

            outcome = bootstrap_compare(otc_vals, nonotc_vals)
            if category in bilateral:
                kind = 'Bilateral'
            else:
                kind = 'Unilateral'

            rows.append({
                'Measure': label,
                'Category': category,
                'Category_Type': kind,
                'OTC': outcome['mean_a'],
                'nonOTC': outcome['mean_b'],
                'Difference': outcome['diff'],
                'CI_Low': outcome['ci_low'],
                'CI_High': outcome['ci_high'],
                'P_Direction': outcome['prob_direction'],
                'Significant': '*' if outcome['is_significant'] else '',
            })

    return pd.DataFrame(rows)


def print_results(q1_df, q2_df):
    """Write the Q1 and Q2 result tables to stdout.

    Args:
        q1_df: Frame produced by run_q1_instability_gap().
        q2_df: Frame produced by run_q2_category_specificity().
    """
    rule = "=" * 95

    # ---- Q1 table ----
    print(rule)
    print("Q1: INSTABILITY GAP  [Mean(House,Object) - Mean(Face,Word)]")
    print(rule)
    print(f"{'Measure':<14} {'Comparison':<12} {'OTC':>7} {'Ref':>7} {'Diff':>7}  {'95% CI':<18}  Sig")
    print("-" * 95)

    for rec in q1_df.to_dict('records'):
        interval = f"[{rec['CI_Low']:>6.3f}, {rec['CI_High']:>6.3f}]"
        print(f"{rec['Measure']:<14} {rec['Comparison']:<12} {rec['OTC_Gap']:>7.3f} {rec['Ref_Gap']:>7.3f} "
              f"{rec['Difference']:>7.3f}  {interval}  {rec['Significant']}")

    # ---- Q2 listing, grouped by measure label ----
    print(f"\n{rule}")
    print("Q2: CATEGORY SPECIFICITY (OTC vs nonOTC)")
    print(rule)

    for label in MEASURES.values():
        print(f"\n--- {label} ---")
        subset = q2_df[q2_df['Measure'] == label]
        for rec in subset.to_dict('records'):
            interval = f"[{rec['CI_Low']:>6.3f},{rec['CI_High']:>6.3f}]"
            print(f"  {rec['Category']:<8} ({rec['Category_Type'][:3]})  "
                  f"OTC={rec['OTC']:.3f}  nonOTC={rec['nonOTC']:.3f}  "
                  f"diff={rec['Difference']:>6.3f}  {interval}  {rec['Significant']}")


# ============================================================
# MAIN EXECUTION
# ============================================================
# Script entry point: load the table, run Q1 and Q2, print the tables and a
# short legend. The names bound here (df, q1_results, q2_results) are
# module-level; the Q4 cell further down reuses `df`.
if __name__ == "__main__":
    # Load and process data
    df = load_and_process(FILE_PATH)
    print(f"Loaded {len(df)} rows\n")
    
    # Run analyses
    q1_results = run_q1_instability_gap(df)
    q2_results = run_q2_category_specificity(df)
    
    # Print results
    print_results(q1_results, q2_results)
    
    # Summary (static legend text; not derived from the results)
    print("\n" + "=" * 95)
    print("SUMMARY")
    print("=" * 95)
    print("""
Q1 Instability Gap: Tests if bilateral categories show more instability 
   relative to unilateral categories in OTC patients.
   
Q2 Category Specificity: Tests which categories drive the OTC effect.
   Hypothesis: Object & House (bilateral, OTC-dependent) should show deficits.
   
* = 95% CI excludes zero (significant effect)
""")

Loaded 96 rows

Q1: INSTABILITY GAP  [Mean(House,Object) - Mean(Face,Word)]
Measure        Comparison       OTC     Ref    Diff  95% CI              Sig
-----------------------------------------------------------------------------------------------
Selectivity    vs Control     0.275   0.109   0.166  [ 0.019,  0.337]  *
Selectivity    vs nonOTC      0.275   0.008   0.267  [ 0.132,  0.428]  *
Geometry       vs Control    -0.289  -0.096  -0.193  [-0.340, -0.043]  *
Geometry       vs nonOTC     -0.289  -0.036  -0.253  [-0.404, -0.103]  *
Spatial Drift  vs Control    -5.535  -0.344  -5.191  [-13.376,  2.799]  
Spatial Drift  vs nonOTC     -5.535  -1.826  -3.709  [-11.513,  3.960]  
MDS            vs Control     0.077   0.020   0.056  [ 0.002,  0.116]  *
MDS            vs nonOTC      0.077   0.044   0.033  [-0.025,  0.095]  

Q2: CATEGORY SPECIFICITY (OTC vs nonOTC)

--- Selectivity ---
  Face     (Uni)  OTC=0.167  nonOTC=0.084  diff= 0.083  [-0.008, 0.179]  
  Word     (Uni)  OTC=0.120  no

In [9]:
# Q4: PAIRWISE CATEGORY COMPARISONS
# Run this cell AFTER analysis_final.py

# ============================================================
# Q4: PAIRWISE CATEGORY COMPARISONS
# ============================================================
def run_q4_pairwise_categories(df, group='OTC'):
    """Compare every bilateral category with every unilateral one in a group.

    Only rows of ``group`` are used. For each measure in MEASURES, each of
    Object/House is bootstrap-compared against each of Face/Word; pairs where
    either category has no rows are skipped.

    Args:
        df: Combined table as returned by load_and_process().
        group: Value of the 'Group' column to restrict the analysis to.

    Returns:
        DataFrame with one row per (measure, pair) and the columns Measure,
        Comparison, Bilateral_Cat, Unilateral_Cat, Bilateral, Unilateral,
        Difference, CI_Low, CI_High and Significant ('*' or '').
    """
    within_group = df[df['Group'] == group]

    def values_for(category, column):
        # Raw values of one measure column for one category within the group.
        return within_group.loc[within_group['Category'] == category, column].values

    rows = []
    for column, label in MEASURES.items():
        for bil in ('Object', 'House'):
            for uni in ('Face', 'Word'):
                bil_values = values_for(bil, column)
                uni_values = values_for(uni, column)
                if not (len(bil_values) and len(uni_values)):
                    continue

                outcome = bootstrap_compare(bil_values, uni_values)
                rows.append({
                    'Measure': label,
                    'Comparison': f"{bil} vs {uni}",
                    'Bilateral_Cat': bil,
                    'Unilateral_Cat': uni,
                    'Bilateral': outcome['mean_a'],
                    'Unilateral': outcome['mean_b'],
                    'Difference': outcome['diff'],
                    'CI_Low': outcome['ci_low'],
                    'CI_High': outcome['ci_high'],
                    'Significant': '*' if outcome['is_significant'] else '',
                })
    return pd.DataFrame(rows)

# Run the within-OTC pairwise comparisons (requires `df`, MEASURES and the
# analysis functions from the earlier cell to be defined).
q4_results = run_q4_pairwise_categories(df, group='OTC')

# A result frame with no rows has no columns either; guard column lookups.
q4_has_rows = not q4_results.empty

# Print
print("=" * 95)
print("Q4: PAIRWISE CATEGORY COMPARISONS (within OTC)")
print("    Tests which bilateral-unilateral pairs drive the effect")
print("=" * 95)

for measure in MEASURES.values():
    print(f"\n--- {measure} ---")
    measure_df = q4_results[q4_results['Measure'] == measure] if q4_has_rows else q4_results
    for _, row in measure_df.iterrows():
        print(f"  {row['Comparison']:<18}  {row['Bilateral']:.3f} vs {row['Unilateral']:.3f}  "
              f"diff={row['Difference']:>7.3f}  [{row['CI_Low']:>6.3f},{row['CI_High']:>6.3f}]  {row['Significant']}")

# Summary: the denominator is the number of comparisons actually run for the
# category, so it stays correct when a pair was skipped or MEASURES changes.
print("\n" + "=" * 95)
print("SUMMARY: Significant comparisons per category")
print("=" * 95)
for cat in ['Object', 'House']:
    cat_rows = q4_results[q4_results['Bilateral_Cat'] == cat] if q4_has_rows else q4_results
    n = int((cat_rows['Significant'] == '*').sum()) if q4_has_rows else 0
    print(f"  {cat:<8} (Bil): {n}/{len(cat_rows)} significant")
for cat in ['Face', 'Word']:
    cat_rows = q4_results[q4_results['Unilateral_Cat'] == cat] if q4_has_rows else q4_results
    n = int((cat_rows['Significant'] == '*').sum()) if q4_has_rows else 0
    print(f"  {cat:<8} (Uni): {n}/{len(cat_rows)} significant")

Q4: PAIRWISE CATEGORY COMPARISONS (within OTC)
    Tests which bilateral-unilateral pairs drive the effect

--- Selectivity ---
  Object vs Face      0.490 vs 0.167  diff=  0.323  [ 0.067, 0.576]  *
  Object vs Word      0.490 vs 0.120  diff=  0.369  [ 0.115, 0.619]  *
  House vs Face       0.348 vs 0.167  diff=  0.181  [ 0.009, 0.333]  *
  House vs Word       0.348 vs 0.120  diff=  0.228  [ 0.060, 0.375]  *

--- Geometry ---
  Object vs Face      0.636 vs 0.751  diff= -0.115  [-0.332, 0.119]  
  Object vs Word      0.636 vs 0.674  diff= -0.039  [-0.246, 0.182]  
  House vs Face       0.212 vs 0.751  diff= -0.539  [-0.806,-0.235]  *
  House vs Word       0.212 vs 0.674  diff= -0.462  [-0.722,-0.166]  *

--- Spatial Drift ---
  Object vs Face      5.869 vs 6.543  diff= -0.673  [-5.198, 3.997]  
  Object vs Word      5.869 vs 18.430  diff=-12.560  [-24.284,-1.745]  *
  House vs Face       8.033 vs 6.543  diff=  1.490  [-5.032, 8.358]  
  House vs Word       8.033 vs 18.430  diff=-10.397 