# stripped_03_rsa+rdm..

In [None]:
# CELL 1: Setup
import pandas as pd
from pathlib import Path
import nibabel as nib
import numpy as np
from scipy.ndimage import label, center_of_mass
from scipy.stats import pearsonr, ttest_ind

# Subject metadata table and result locations.
CSV_FILE = Path('/user_data/csimmon2/git_repos/long_pt/long_pt_sub_info.csv')
RESULTS_CSV = '/user_data/csimmon2/git_repos/long_pt/B_analyses/results.csv'
OUTPUT_CSV = '/user_data/csimmon2/git_repos/long_pt/B_analyses/results_final.csv'

df = pd.read_csv(CSV_FILE)

# Root of the per-subject imaging tree: BASE_DIR / <sub> / ses-<N> / ...
BASE_DIR = Path("/user_data/csimmon2/long_pt")

# Subjects whose usable data begins at a later session (everyone else: session 1).
SESSION_START = dict.fromkeys(('sub-010', 'sub-018', 'sub-068'), 2)

# category -> (FSL cope number, sign multiplier applied to that cope's z-map)
COPE_MAP_DIFFERENTIAL = dict(
    face=(10, 1),
    word=(13, -1),
    object=(3, 1),
    house=(11, 1),
)

COPE_MAP_SCRAMBLE = dict(
    face=(10, 1),
    word=(12, 1),
    object=(3, 1),
    house=(11, 1),
)

def create_sphere(center_coord, affine, brain_shape, radius=6):
    """Return a boolean mask of voxels whose centres lie inside a world-space sphere.

    Parameters
    ----------
    center_coord : array-like of 3 floats
        Sphere centre in the world space defined by ``affine``.
    affine : array-like, shape (4, 4)
        Voxel-to-world affine of the image the mask will index.
    brain_shape : tuple of int
        Shape of the returned mask; its first three axes define the voxel grid.
    radius : float, default 6
        Sphere radius in world units (mm for standard NIfTI affines).

    Returns
    -------
    numpy.ndarray of bool with shape ``brain_shape``
        True where the voxel centre is within ``radius`` of ``center_coord``.
    """
    affine = np.asarray(affine, dtype=float)
    # (n_voxels, 3) integer voxel indices in C order, same as meshgrid(..., indexing='ij').
    grid_coords = np.indices(tuple(brain_shape[:3])).reshape(3, -1).T
    # Voxel -> world; identical to nibabel.affines.apply_affine for a 4x4 affine.
    grid_world = grid_coords @ affine[:3, :3].T + affine[:3, 3]
    distances = np.linalg.norm(grid_world - np.asarray(center_coord, dtype=float), axis=1)

    mask_3d = np.zeros(brain_shape, dtype=bool)
    within = grid_coords[distances <= radius]
    # One vectorised fancy-index write instead of a Python loop over voxels.
    mask_3d[tuple(within.T)] = True
    return mask_3d


print("✓ Cell 1 complete")

✓ Cell 1 complete


In [2]:
# CELL 2: Load Subjects
def load_subjects_by_group(group_filter=None, patient_only=True):
    """Collect per-subject analysis info from the module-level metadata table ``df``.

    Parameters
    ----------
    group_filter : str or iterable of str, optional
        Keep only rows whose ``group`` is one of these; a falsy value applies no filter.
    patient_only : bool or None, default True
        ``True`` keeps rows with ``patient == 1``, ``False`` keeps rows with
        ``patient == 0``; any other value (e.g. ``None``) applies no filter.

    Returns
    -------
    dict
        subject id -> {'code', 'sessions', 'hemi', 'group', 'patient_status',
        'surgery_side'}. Subjects are omitted when they have no directory under
        ``BASE_DIR`` or no ``ses-*`` directory at/after their ``SESSION_START``
        entry (default 1). ``sessions`` holds the session labels as strings,
        sorted numerically.

    NOTE(review): ``row.get('intact_hemi', 'left')`` only falls back to 'left'
    when the column is absent; a missing value (NaN) in an existing column
    compares unequal to 'left' and yields 'r'. Confirm this is intended
    (e.g. for controls).
    """
    if patient_only is True:
        wanted_status = 1
    elif patient_only is False:
        wanted_status = 0
    else:
        wanted_status = None

    selected = df.copy()
    if wanted_status is not None:
        selected = selected[selected['patient'] == wanted_status]
    if group_filter:
        wanted_groups = [group_filter] if isinstance(group_filter, str) else group_filter
        selected = selected[selected['group'].isin(wanted_groups)]

    subjects = {}
    for _, row in selected.iterrows():
        sid = row['sub']
        subject_path = BASE_DIR / sid
        if not subject_path.exists():
            continue

        labels = sorted(
            (entry.name.replace('ses-', '') for entry in subject_path.glob('ses-*') if entry.is_dir()),
            key=int,
        )
        earliest = SESSION_START.get(sid, 1)
        usable = [s for s in labels if int(s) >= earliest]
        if not usable:
            continue

        intact_is_left = row.get('intact_hemi', 'left') == 'left'
        subjects[sid] = {
            'code': f"{row['group']}{sid.split('-')[1]}",
            'sessions': usable,
            'hemi': 'l' if intact_is_left else 'r',
            'group': row['group'],
            'patient_status': 'patient' if row['patient'] == 1 else 'control',
            'surgery_side': row.get('SurgerySide', None),
        }
    return subjects

ALL_PATIENTS = load_subjects_by_group(patient_only=True)
ALL_CONTROLS = load_subjects_by_group(patient_only=False)
# Patients first, then controls (a control id would override a duplicate patient id).
ANALYSIS_SUBJECTS = dict(ALL_PATIENTS)
ANALYSIS_SUBJECTS.update(ALL_CONTROLS)

print(f"✓ Loaded {len(ANALYSIS_SUBJECTS)} subjects")
_group_labels = [v['group'] for v in ANALYSIS_SUBJECTS.values()]
for g in ('OTC', 'nonOTC', 'control'):
    n = _group_labels.count(g)
    print(f"  {g}: {n}")

✓ Loaded 25 subjects
  OTC: 7
  nonOTC: 9
  control: 9


In [3]:
# CELL 3: EXTRACTION - Top 20% ROI Extraction (Both Contrast Sets)
# ============================================================

def extract_top20_rois(subject_id, cope_map, percentile=80, min_cluster_size=20):
    """Define per-session functional ROIs from the top voxels inside each search mask.

    For each hemisphere ('l', 'r') and each category in ``cope_map`` with a
    ``<hemi>_<category>_searchmask.nii.gz`` in the subject's first-session
    ``ROIs`` directory, every session's z-map is multiplied by the category's
    sign multiplier. It is then thresholded at the ``percentile``-th percentile
    of its positive in-mask voxels, but never below z = 1.64. The largest
    connected suprathreshold cluster inside the mask becomes that session's ROI.

    The default ``percentile=80`` keeps the top 20% of voxels (hence the
    name); ``percentile=90`` keeps the top 10%.

    Parameters
    ----------
    subject_id : str
        Key into ``ANALYSIS_SUBJECTS``.
    cope_map : dict
        category -> (cope number, sign multiplier).
    percentile : float, default 80
        Percentile of the positive in-mask z-values used as the threshold.
    min_cluster_size : int, default 20
        Minimum number of positive in-mask voxels, and minimum size of the
        retained cluster, for a session to yield an ROI.

    Returns
    -------
    dict
        ``{'<hemi>_<category>': {session: {'n_voxels', 'peak_z', 'centroid',
        'threshold'}}}`` with ``centroid`` in world coordinates. A
        hemi/category key is present (possibly empty) whenever its search mask
        loaded. Returns ``{}`` when the first-session ``ROIs`` directory is missing.
    """
    info = ANALYSIS_SUBJECTS[subject_id]
    first_session = info['sessions'][0]
    roi_dir = BASE_DIR / subject_id / f'ses-{first_session}' / 'ROIs'
    if not roi_dir.exists():
        return {}

    all_results = {}

    for hemi in ['l', 'r']:
        for category, (cope_num, multiplier) in cope_map.items():
            mask_file = roi_dir / f'{hemi}_{category}_searchmask.nii.gz'
            if not mask_file.exists():
                continue

            try:
                search_mask_img = nib.load(mask_file)
                search_mask = search_mask_img.get_fdata() > 0
            except Exception as e:
                # Skip an unreadable mask, but report it (a bare ``except`` hid
                # these failures and also swallowed KeyboardInterrupt).
                print(f"Error loading {mask_file}: {e}")
                continue
            affine = search_mask_img.affine

            hemi_key = f'{hemi}_{category}'
            all_results[hemi_key] = {}

            for session in info['sessions']:
                feat_dir = BASE_DIR / subject_id / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                # Non-first sessions read the '_ses<first>' copy of the z-map
                # (presumably registered to the first session's space).
                z_name = 'zstat1.nii.gz' if session == first_session else f'zstat1_ses{first_session}.nii.gz'
                cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                if not cope_file.exists():
                    continue

                try:
                    z_full = nib.load(cope_file).get_fdata() * multiplier
                    pos_voxels = z_full[search_mask & (z_full > 0)]
                    if len(pos_voxels) < min_cluster_size:
                        continue

                    dynamic_thresh = max(np.percentile(pos_voxels, percentile), 1.64)
                    suprathresh = (z_full > dynamic_thresh) & search_mask
                    labeled, n_clusters = label(suprathresh)
                    if n_clusters == 0:
                        continue

                    # Voxel count per cluster label in a single pass; drop label 0 (background).
                    cluster_sizes = np.bincount(labeled.ravel())[1:]
                    # argmax returns the first maximum, so ties go to the
                    # lowest-numbered cluster, as with the previous strict '>' scan.
                    best_idx = int(np.argmax(cluster_sizes)) + 1
                    if cluster_sizes[best_idx - 1] < min_cluster_size:
                        continue

                    roi_mask = labeled == best_idx
                    # Use -inf outside the ROI rather than multiplying by the mask:
                    # NaN * 0 is NaN, and argmax would pick a NaN voxel outside the ROI.
                    peak_idx = np.unravel_index(
                        np.argmax(np.where(roi_mask, z_full, -np.inf)), z_full.shape
                    )

                    all_results[hemi_key][session] = {
                        'n_voxels': int(np.sum(roi_mask)),
                        'peak_z': z_full[peak_idx],
                        'centroid': nib.affines.apply_affine(affine, center_of_mass(roi_mask)),
                        'threshold': dynamic_thresh,
                    }
                except Exception as e:
                    print(f"Error {subject_id} {hemi_key} ses-{session}: {e}")

    return all_results

# Run the ROI extraction once per contrast set; subjects yielding nothing are dropped.
print("Extracting Top 20% ROIs - DIFFERENTIAL...")
top20_differential = {}
for sub in ANALYSIS_SUBJECTS:
    res = extract_top20_rois(sub, COPE_MAP_DIFFERENTIAL)
    if not res:
        continue  # no ROIs directory, or no search mask could be loaded
    top20_differential[sub] = res
print(f"✓ Differential: {len(top20_differential)} subjects")

print("\nExtracting Top 20% ROIs - SCRAMBLE...")
top20_scramble = {}
for sub in ANALYSIS_SUBJECTS:
    res = extract_top20_rois(sub, COPE_MAP_SCRAMBLE)
    if not res:
        continue
    top20_scramble[sub] = res
print(f"✓ Scramble: {len(top20_scramble)} subjects")

Extracting Top 20% ROIs - DIFFERENTIAL...
✓ Differential: 24 subjects

Extracting Top 20% ROIs - SCRAMBLE...
✓ Scramble: 24 subjects


# Average Activity

# Sum Selectivity

# Spatial Drift

# Geometry

In [None]:

# CELL 5: GEOMETRY PRESERVATION - RDM Stability (6mm sphere)
# ============================================================

def compute_geometry_preservation(functional_results, cope_map, subjects_dict, radius=6):
    """Correlate each ROI's 4-category RDM between its first and last session.

    For every ROI with data in at least two sessions, a sphere of ``radius``
    is centred on that session's ROI centroid. The four sign-adjusted category
    z-maps are sampled inside it, and an RDM (1 - Pearson r between category
    patterns) is built. The upper triangles of the first- and last-session
    RDMs are then Pearson-correlated.

    Lower values = more representational change.

    Parameters
    ----------
    functional_results : dict
        ``extract_top20_rois`` output keyed by subject id.
    cope_map : dict
        category -> (cope number, sign multiplier); must contain 'face',
        'word', 'object' and 'house'.
    subjects_dict : dict
        Subject info; needs 'sessions', and 'code'/'group' are copied to the output.
    radius : float, default 6
        Sphere radius passed to ``create_sphere``.

    Returns
    -------
    pandas.DataFrame
        One row per ROI with columns subject, code, group, hemi, category,
        category_type, geometry_preservation.
    """
    categories = ['face', 'word', 'object', 'house']
    triu_idx = np.triu_indices(len(categories), k=1)
    results = []

    for sid, rois in functional_results.items():
        info = subjects_dict.get(sid, {})
        if not info:
            continue

        # The subject's first session reads 'zstat1.nii.gz'; every other
        # session reads 'zstat1_ses<first_session>.nii.gz'.
        first_session = info['sessions'][0]

        # Any first-session search mask supplies the affine/shape for the spheres.
        roi_dir = BASE_DIR / sid / f'ses-{first_session}' / 'ROIs'
        ref_file = None
        for cat in ['face', 'object', 'house', 'word']:
            for h in ['l', 'r']:
                candidate = roi_dir / f"{h}_{cat}_searchmask.nii.gz"
                if candidate.exists():
                    ref_file = candidate
                    break
            if ref_file:
                break
        if not ref_file:
            continue

        ref_img = nib.load(ref_file)
        affine = ref_img.affine
        brain_shape = ref_img.shape

        for roi_key, sessions_data in rois.items():
            # Session labels are strings: sort numerically so '10' follows '9'.
            sessions = sorted(sessions_data.keys(), key=int)
            if len(sessions) < 2:
                continue
            first_ses, last_ses = sessions[0], sessions[-1]

            rdms = {}
            for ses in (first_ses, last_ses):
                # Dynamic sphere at this session's own ROI centroid.
                sphere = create_sphere(sessions_data[ses]['centroid'], affine, brain_shape, radius)
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                # Compare with the SUBJECT's first session (as extract_top20_rois
                # does), not the ROI's first available session. Otherwise an ROI
                # without session-1 data read the non-'_ses<first>' file.
                z_name = 'zstat1.nii.gz' if ses == first_session else f'zstat1_ses{first_session}.nii.gz'

                patterns = []
                for cat in categories:
                    cope_num, mult = cope_map[cat]
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    if not cope_file.exists():
                        break
                    pattern = nib.load(cope_file).get_fdata()[sphere] * mult
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        break
                    patterns.append(pattern)

                # All four patterns share the sphere, so corrcoef cannot fail on shape.
                if len(patterns) == len(categories):
                    rdms[ses] = 1 - np.corrcoef(patterns)

            if len(rdms) != 2:
                continue

            r, _ = pearsonr(rdms[first_ses][triu_idx], rdms[last_ses][triu_idx])
            hemi, category = roi_key.split('_')[:2]

            results.append({
                'subject': sid,
                'code': info.get('code', sid),
                'group': info.get('group', 'unknown'),
                'hemi': hemi,
                'category': category,
                'category_type': 'Bilateral' if category in ['object', 'house'] else 'Unilateral',
                'geometry_preservation': r,
            })

    return pd.DataFrame(results)

# Sphere radii (mm) to evaluate; currently only 6 mm.
print("Computing Geometry Preservation...")
geometry_results = {}
for radius in (6,):
    geometry_results[radius] = compute_geometry_preservation(
        top20_differential,
        COPE_MAP_DIFFERENTIAL,
        ANALYSIS_SUBJECTS,
        radius=radius,
    )
    n_rois = len(geometry_results[radius])
    print(f"  {radius}mm: {n_rois} ROIs")

print("✓ Done")

Computing Geometry Preservation...
  6mm: 132 ROIs
✓ Done


In [None]:
# Summarise 6 mm geometry preservation: mean/std/count of the RDM correlation
# for each group x category type.
if len(geometry_results[6]) == 0:
    print("\nNo results generated.")
else:
    df_geom = geometry_results[6]
    print("\nSummary by Group and Category Type:")
    summary = (
        df_geom
        .groupby(['group', 'category_type'])['geometry_preservation']
        .agg(['mean', 'std', 'count'])
    )
    print(summary)


Summary by Group and Category Type:
                           mean       std  count
group   category_type                           
OTC     Bilateral      0.423894  0.349007     12
        Unilateral     0.712803  0.197306     12
control Bilateral      0.659483  0.358840     36
        Unilateral     0.755123  0.212617     36
nonOTC  Bilateral      0.726002  0.257030     18
        Unilateral     0.761896  0.196306     18


# MDS

In [8]:
# CELL 6: MDS EMBEDDING SHIFT (Nordt Approach)
# ============================================================
from scipy.spatial.distance import squareform
from scipy.linalg import orthogonal_procrustes

def mds_2d(rdm, n_components=2):
    """Embed a dissimilarity matrix with classical (Torgerson) MDS.

    Parameters
    ----------
    rdm : array-like, shape (n, n)
        Symmetric dissimilarity matrix (here 1 - Pearson r between category patterns).
    n_components : int, default 2
        Number of embedding dimensions to return; the name reflects the default.

    Returns
    -------
    numpy.ndarray, shape (n, n_components)
        Leading eigenvectors of the double-centred squared dissimilarities,
        scaled by sqrt(eigenvalue). Negative eigenvalues (non-Euclidean input)
        are clipped to 0, which collapses those axes. Axis signs are arbitrary.
    """
    d = np.asarray(rdm, dtype=float)
    n = d.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    B = -0.5 * centering @ (d ** 2) @ centering
    eigvals, eigvecs = np.linalg.eigh(B)
    order = np.argsort(eigvals)[::-1][:n_components]
    return eigvecs[:, order] * np.sqrt(np.maximum(eigvals[order], 0))

def compute_mds_shift(functional_results, cope_map, subjects_dict, radius=6):
    """Per-category displacement in a 2-D MDS embedding between two sessions.

    For each ROI with data in at least two sessions:

    - a sphere of ``radius`` is centred on the first- and last-session ROI
      centroids, and a 4-category RDM (1 - Pearson r between the
      sign-adjusted z-map patterns) is built in each;
    - each RDM is embedded in 2-D with classical MDS (``mds_2d``);
    - the first-session embedding is rotated/reflected onto the last-session
      one with an orthogonal Procrustes fit (no scaling);
    - the Euclidean distance each category moved is recorded.

    Parameters
    ----------
    functional_results : dict
        ``extract_top20_rois`` output keyed by subject id.
    cope_map : dict
        category -> (cope number, sign multiplier); must contain 'face',
        'word', 'object' and 'house'.
    subjects_dict : dict
        Subject info; needs 'sessions', and 'code'/'group' are copied to the output.
    radius : float, default 6
        Sphere radius passed to ``create_sphere``.

    Returns
    -------
    pandas.DataFrame
        One row per ROI x measured category with columns subject, code, group,
        hemi, roi_category, measured_category, category_type, mds_shift.
    """
    results = []
    categories = ['face', 'word', 'object', 'house']

    for sid, rois in functional_results.items():
        info = subjects_dict.get(sid, {})
        if not info:
            continue

        # The subject's first session reads 'zstat1.nii.gz'; every other
        # session reads 'zstat1_ses<first_session>.nii.gz'.
        first_session = info['sessions'][0]

        # Any first-session search mask supplies the affine/shape for the spheres.
        roi_dir = BASE_DIR / sid / f'ses-{first_session}' / 'ROIs'
        ref_file = None
        for cat in categories:
            for h in ['l', 'r']:
                candidate = roi_dir / f"{h}_{cat}_searchmask.nii.gz"
                if candidate.exists():
                    ref_file = candidate
                    break
            if ref_file:
                break
        if not ref_file:
            continue

        ref_img = nib.load(ref_file)
        affine = ref_img.affine
        brain_shape = ref_img.shape

        for roi_key, sessions_data in rois.items():
            # Session labels are strings: sort numerically so '10' follows '9'.
            sessions = sorted(sessions_data.keys(), key=int)
            if len(sessions) < 2:
                continue

            first_ses, last_ses = sessions[0], sessions[-1]
            hemi = roi_key.split('_')[0]
            roi_category = roi_key.split('_')[1]

            rdms = {}
            for ses in (first_ses, last_ses):
                # Sphere at this session's own ROI centroid.
                sphere = create_sphere(sessions_data[ses]['centroid'], affine, brain_shape, radius)
                feat_dir = BASE_DIR / sid / f'ses-{ses}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
                # Compare with the SUBJECT's first session (as extract_top20_rois
                # does), not the ROI's first available session.
                z_name = 'zstat1.nii.gz' if ses == first_session else f'zstat1_ses{first_session}.nii.gz'

                patterns = []
                for cat in categories:
                    cope_num, mult = cope_map[cat]
                    cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / z_name
                    if not cope_file.exists():
                        break
                    pattern = nib.load(cope_file).get_fdata()[sphere] * mult
                    if len(pattern) == 0 or not np.all(np.isfinite(pattern)):
                        break
                    patterns.append(pattern)

                if len(patterns) == len(categories):
                    rdms[ses] = 1 - np.corrcoef(patterns)

            if len(rdms) != 2:
                continue

            try:
                coords_t1 = mds_2d(rdms[first_ses])
                coords_t2 = mds_2d(rdms[last_ses])
                # Rotate/reflect T1 onto T2. orthogonal_procrustes' second return
                # value is the sum of singular values, not a scale factor; unused.
                R, _ = orthogonal_procrustes(coords_t1, coords_t2)
                coords_t1_aligned = coords_t1 @ R
            except Exception as e:
                print(f"Error {sid} {roi_key}: {e}")
                continue

            # Euclidean distance each category moved between sessions.
            for i, cat in enumerate(categories):
                dist = np.linalg.norm(coords_t1_aligned[i] - coords_t2[i])
                results.append({
                    'subject': sid,
                    'code': info.get('code', sid),
                    'group': info.get('group', 'unknown'),
                    'hemi': hemi,
                    'roi_category': roi_category,
                    'measured_category': cat,
                    'category_type': 'Bilateral' if cat in ['object', 'house'] else 'Unilateral',
                    'mds_shift': dist,
                })

    return pd.DataFrame(results)

# Compute MDS embedding shift for each sphere radius (currently only 6 mm).
print("Computing MDS Embedding Shift...")
mds_results = {}
for radius in [6]:
    mds_results[radius] = compute_mds_shift(
        top20_differential, COPE_MAP_DIFFERENTIAL, ANALYSIS_SUBJECTS, radius
    )
    print(f"  {radius}mm: {len(mds_results[radius])} measurements")

if len(mds_results[6]) > 0:
    # Dedicated name: rebinding ``df`` here overwrote the subject metadata
    # table that load_subjects_by_group() reads, breaking any re-run of it.
    df_mds_6mm = mds_results[6]
    print("\n6mm Summary by Category Type:")
    print(df_mds_6mm.groupby('category_type')['mds_shift'].agg(['mean', 'std', 'count']))

print("✓ Done")

# To view the summary for the 6mm radius
df_geom = geometry_results[6]
print(df_geom.groupby(['group', 'category_type'])['geometry_preservation'].agg(['mean', 'std', 'count']))

Computing MDS Embedding Shift...
  6mm: 528 measurements

6mm Summary by Category Type:
                   mean       std  count
category_type                           
Bilateral      0.280994  0.178664    264
Unilateral     0.243945  0.157929    264
✓ Done
                           mean       std  count
group   category_type                           
OTC     Bilateral      0.423894  0.349007     12
        Unilateral     0.712803  0.197306     12
control Bilateral      0.659483  0.358840     36
        Unilateral     0.755123  0.212617     36
nonOTC  Bilateral      0.726002  0.257030     18
        Unilateral     0.761896  0.196306     18


In [None]:
# CELL 7: MDS SHIFT STATISTICS
# ============================================================
from scipy.stats import ttest_ind, mannwhitneyu, permutation_test

# NOTE(review): compute_mds_shift emits one row per measured category per ROI,
# so each subject contributes many rows. Every test below treats rows as
# independent observations. Consider averaging to one value per subject (and
# per category type) before testing.
df_mds = mds_results[6]

# Separate by group
otc = df_mds[df_mds['group'] == 'OTC']
non_otc = df_mds[df_mds['group'] == 'nonOTC']
control = df_mds[df_mds['group'] == 'control']

print("=== BILATERAL vs UNILATERAL WITHIN EACH GROUP ===\n")

for name, grp in [('OTC', otc), ('nonOTC', non_otc), ('Control', control)]:
    bil = grp[grp['category_type'] == 'Bilateral']['mds_shift']
    uni = grp[grp['category_type'] == 'Unilateral']['mds_shift']

    t_stat, t_p = ttest_ind(bil, uni)  # two-sided
    u_stat, u_p = mannwhitneyu(bil, uni, alternative='greater')  # bilateral > unilateral

    print(f"{name}:")
    print(f"  Bilateral: {bil.mean():.3f} ± {bil.std():.3f} (n={len(bil)})")
    print(f"  Unilateral: {uni.mean():.3f} ± {uni.std():.3f} (n={len(uni)})")
    print(f"  Difference: {bil.mean() - uni.mean():.3f}")
    print(f"  t-test: t={t_stat:.2f}, p={t_p:.3f}")
    print(f"  Mann-Whitney (bil>uni): U={u_stat:.0f}, p={u_p:.3f}\n")

print("=== OTC vs CONTROLS (BILATERAL ONLY) ===\n")
otc_bil = otc[otc['category_type'] == 'Bilateral']['mds_shift']
ctrl_bil = control[control['category_type'] == 'Bilateral']['mds_shift']
t_stat, t_p = ttest_ind(otc_bil, ctrl_bil)
print(f"OTC Bilateral: {otc_bil.mean():.3f} ± {otc_bil.std():.3f}")
print(f"Control Bilateral: {ctrl_bil.mean():.3f} ± {ctrl_bil.std():.3f}")
print(f"t-test: t={t_stat:.2f}, p={t_p:.3f}")

# Label shuffling builds a PERMUTATION null. The old "BOOTSTRAP" label was wrong;
# the variable names n_boot/boot_diffs are kept for code that may reference them.
print("\n=== PERMUTATION TEST: OTC Bilateral vs Unilateral ===")
np.random.seed(42)
n_boot = 10000  # number of label permutations
observed_diff = otc[otc['category_type']=='Bilateral']['mds_shift'].mean() - \
                otc[otc['category_type']=='Unilateral']['mds_shift'].mean()

combined = otc['mds_shift'].values
labels = otc['category_type'].values
boot_diffs = np.empty(n_boot)
for k in range(n_boot):
    shuffled = np.random.permutation(labels)
    boot_diffs[k] = combined[shuffled == 'Bilateral'].mean() - combined[shuffled == 'Unilateral'].mean()

# One-sided p-value with the add-one correction: the observed labelling is
# itself one permutation, so p is never reported as exactly 0.
p_perm = (np.sum(boot_diffs >= observed_diff) + 1) / (n_boot + 1)

print(f"Observed difference: {observed_diff:.3f}")
print(f"Permutation p-value (bil > uni): {p_perm:.3f}")
print(f"95% CI of null: [{np.percentile(boot_diffs, 2.5):.3f}, {np.percentile(boot_diffs, 97.5):.3f}]")


# CELL 6d: PRINT MDS RESULTS
# ============================================================
# Dump the complete 6 mm MDS-shift table, then a group x category-type summary.
heavy_rule = "=" * 80
light_rule = "-" * 80

print(heavy_rule)
print("MDS SHIFT RESULTS (6mm radius)")
print(heavy_rule)

df_mds_display = mds_results[6].copy()
n_rows = len(df_mds_display)
n_subjects = df_mds_display['subject'].nunique()

print(f"\nTotal measurements: {n_rows}")
print(f"Subjects: {n_subjects}")
print("\nBy Group:")
print(df_mds_display.groupby('group').size())

print(f"\n{light_rule}")
print("FULL MDS RESULTS:")
print(f"{light_rule}\n")
print(df_mds_display.to_string())

print(f"\n{light_rule}")
print("SUMMARY BY GROUP AND CATEGORY TYPE:")
print(f"{light_rule}\n")
summary = (
    df_mds_display
    .groupby(['group', 'category_type'])['mds_shift']
    .agg(['mean', 'std', 'count'])
)
print(summary)

=== BILATERAL vs UNILATERAL WITHIN EACH GROUP ===

OTC:
  Bilateral: 0.375 ± 0.217 (n=48)
  Unilateral: 0.299 ± 0.166 (n=48)
  Difference: 0.077
  t-test: t=1.94, p=0.055
  Mann-Whitney (bil>uni): U=1376, p=0.051

nonOTC:
  Bilateral: 0.247 ± 0.155 (n=72)
  Unilateral: 0.203 ± 0.133 (n=72)
  Difference: 0.044
  t-test: t=1.82, p=0.071
  Mann-Whitney (bil>uni): U=2997, p=0.053

Control:
  Bilateral: 0.267 ± 0.166 (n=144)
  Unilateral: 0.246 ± 0.162 (n=144)
  Difference: 0.020
  t-test: t=1.06, p=0.289
  Mann-Whitney (bil>uni): U=11142, p=0.137

=== OTC vs CONTROLS (BILATERAL ONLY) ===

OTC Bilateral: 0.375 ± 0.217
Control Bilateral: 0.267 ± 0.166
t-test: t=3.63, p=0.000

=== BOOTSTRAP TEST: OTC Bilateral vs Unilateral ===
Observed difference: 0.077
Permutation p-value (bil > uni): 0.031
95% CI of null: [-0.079, 0.081]


In [17]:
# CELL FINAL: PRODUCTION RUN (Strict Scramble - Top 10%)
# ============================================================
# METHODOLOGY:
# Every ROI is defined by the Category > Scramble contrast (Liu et al., 2013).
# The threshold is raised to the top 10% of in-mask voxels (percentile=90) to
# isolate category-selective cores from broad retinotopic activation
# (z-scores in these maps reach ~10.9).

print("RUNNING FINAL ANALYSIS: Scramble Map (Top 10%)...")

# 1. ROIs at the 90th percentile
print("1. Extracting ROIs...")
top10_scramble = {}
for sub in ANALYSIS_SUBJECTS:
    res = extract_top20_rois(sub, COPE_MAP_SCRAMBLE, percentile=90)
    if not res:
        continue
    top10_scramble[sub] = res

# 2. Stability metrics in 6 mm spheres
print("2. Computing Stability Metrics...")
metric_inputs = (top10_scramble, COPE_MAP_SCRAMBLE, ANALYSIS_SUBJECTS)
# A. Geometry preservation: correlation of first- vs last-session RDMs.
geom_final = compute_geometry_preservation(*metric_inputs, radius=6)
# B. MDS shift: distance each category moves in the Procrustes-aligned 2-D
#    MDS embedding between sessions (a representational, not spatial, measure).
mds_final = compute_mds_shift(*metric_inputs, radius=6)

# 3. Report
print("\n" + "="*50)
print("FINAL RESULTS SUMMARY")
print("="*50)

if len(mds_final) > 0:
    print("\nMDS Shift by Group & Category Type:")
    summary = (
        mds_final
        .groupby(['group', 'category_type'])['mds_shift']
        .agg(['mean', 'std', 'count'])
    )
    print(summary)

    # Key check: within OTC, do bilateral categories shift more than unilateral ones?
    otc = mds_final[mds_final['group'] == 'OTC']
    otc_type = otc['category_type']
    bil = otc[otc_type == 'Bilateral']['mds_shift'].mean()
    uni = otc[otc_type == 'Unilateral']['mds_shift'].mean()
    print("\nOTC Drift Pattern:")
    print(f"  Bilateral (House/Obj): {bil:.3f}")
    print(f"  Unilateral (Face/Word): {uni:.3f}")
    print(f"  Difference: {bil - uni:.3f} (Positive = Bilateral drifts more)")

    # Writing is currently disabled; uncomment to save both tables.
    # mds_final.to_csv(BASE_DIR / 'results_drift_scramble_top10.csv', index=False)
    # geom_final.to_csv(BASE_DIR / 'results_geometry_scramble_top10.csv', index=False)
    print("\n✓ CSVs ready to save.")

RUNNING FINAL ANALYSIS: Scramble Map (Top 10%)...
1. Extracting ROIs...
2. Computing Stability Metrics...

FINAL RESULTS SUMMARY

MDS Shift by Group & Category Type:
                           mean       std  count
group   category_type                           
OTC     Bilateral      0.337964  0.184367     48
        Unilateral     0.294501  0.179656     48
control Bilateral      0.228929  0.144479    144
        Unilateral     0.234028  0.151370    144
nonOTC  Bilateral      0.192105  0.151270     72
        Unilateral     0.199345  0.156793     72

OTC Drift Pattern:
  Bilateral (House/Obj): 0.338
  Unilateral (Face/Word): 0.295
  Difference: 0.043 (Positive = Bilateral drifts more)

✓ CSVs ready to save.


In [20]:
# CELL 16 (CORRECTED): STRICT SCRAMBLE RSA (Top 10% Centroids)
# ============================================================
import numpy as np
import pandas as pd
import nibabel as nib

# --- A. REDEFINE RSA FUNCTIONS (With Tuple Support) ---

def create_sphere(peak_coord, affine, brain_shape, radius=6):
    """Return a boolean mask of voxels within ``radius`` of a world-space point.

    Parameters
    ----------
    peak_coord : array-like of 3 floats
        Sphere centre (here an ROI centroid) in the world space of ``affine``.
    affine : array-like, shape (4, 4)
        Voxel-to-world affine of the image the mask will index.
    brain_shape : tuple of int
        Shape of the returned mask; its first three axes define the voxel grid.
    radius : float, default 6
        Sphere radius in world units (mm for standard NIfTI affines).

    Returns
    -------
    numpy.ndarray of bool with shape ``brain_shape``
    """
    affine = np.asarray(affine, dtype=float)
    # (n_voxels, 3) integer voxel indices in C order.
    grid_coords = np.indices(tuple(brain_shape[:3])).reshape(3, -1).T
    # Voxel -> world; identical to nibabel.affines.apply_affine for a 4x4 affine.
    grid_world = grid_coords @ affine[:3, :3].T + affine[:3, 3]
    distances = np.linalg.norm(grid_world - np.asarray(peak_coord, dtype=float), axis=1)

    mask_3d = np.zeros(brain_shape, dtype=bool)
    within = grid_coords[distances <= radius]
    # Single vectorised write instead of a per-voxel Python loop.
    mask_3d[tuple(within.T)] = True
    return mask_3d

def extract_betas(subject_id, session, sphere_mask, category_copes):
    """Sample each category's cope (parameter-estimate) map inside a sphere mask.

    Parameters
    ----------
    subject_id : str
        Key into ``ANALYSIS_SUBJECTS``.
    session : str
        Session label. The subject's first session reads ``cope1.nii.gz``;
        later sessions read ``cope1_ses<first>.nii.gz``.
    sphere_mask : numpy.ndarray of bool
        Voxel mask matching the cope images (e.g. from ``create_sphere``).
    category_copes : dict
        category -> cope number, or category -> (cope number, sign multiplier).
        When a multiplier is given it is applied to the betas, as the z-maps
        are signed elsewhere in this analysis.

    Returns
    -------
    (numpy.ndarray, list of str) or (None, None)
        ``beta_matrix`` of shape (n_voxels, n_categories), one column per entry
        of ``valid_categories``, in that order. Returns ``(None, None)`` when no
        category has usable data or no voxel is finite in every category.
    """
    info = ANALYSIS_SUBJECTS[subject_id]
    first_session = info['sessions'][0]

    feat_dir = BASE_DIR / subject_id / f'ses-{session}' / 'derivatives' / 'fsl' / 'loc' / 'HighLevel.gfeat'
    cope_name = 'cope1.nii.gz' if session == first_session else f'cope1_ses{first_session}.nii.gz'

    raw_patterns = []
    valid_categories = []

    for category, cope_def in category_copes.items():
        # Accept both the (cope, multiplier) tuples of the COPE_MAP_* dicts and bare ints.
        if isinstance(cope_def, tuple):
            cope_num = cope_def[0]
            multiplier = cope_def[1] if len(cope_def) > 1 else 1
        else:
            cope_num, multiplier = cope_def, 1

        cope_file = feat_dir / f'cope{cope_num}.feat' / 'stats' / cope_name
        if not cope_file.exists():
            continue

        pattern = nib.load(cope_file).get_fdata()[sphere_mask] * multiplier
        # Skip a category only when none of its voxels is finite.
        if np.isfinite(pattern).any():
            raw_patterns.append(pattern)
            valid_categories.append(category)

    if not raw_patterns:
        return None, None

    stacked = np.column_stack(raw_patterns)  # (n_sphere_voxels, n_categories)
    # Keep voxels that are finite in EVERY category so each row is the same
    # voxel across columns. Filtering each category separately and truncating
    # to the shortest pattern misaligned voxels whenever NaNs differed.
    beta_matrix = stacked[np.all(np.isfinite(stacked), axis=1)]
    if beta_matrix.shape[0] == 0:
        return None, None

    return beta_matrix, valid_categories

def compute_rdm(beta_matrix, fisher_transform=True):
    """Return ``(rdm, similarity)`` for a voxels x conditions beta matrix.

    ``rdm`` is ``1 - r``, where ``r`` is the Pearson correlation between
    columns (conditions). ``similarity`` is ``r`` itself or, when
    ``fisher_transform`` is true, ``arctanh(r)`` with ``r`` clipped to
    [-0.999, 0.999] so the unit diagonal stays finite.
    """
    r = np.corrcoef(beta_matrix.T)
    dissimilarity = 1 - r
    if not fisher_transform:
        return dissimilarity, r
    return dissimilarity, np.arctanh(np.clip(r, -0.999, 0.999))

def extract_rdms(functional_results, analysis_subjects, cope_map, radius=6):
    """Build per-session RDMs in spheres centred on each ROI's session centroid.

    Parameters
    ----------
    functional_results : dict
        ``extract_top20_rois`` output keyed by subject id (needs 'centroid').
    analysis_subjects : dict
        Subject info; needs 'sessions' and 'hemi'.
    cope_map : dict
        Passed to ``extract_betas`` (category -> cope number or (cope, multiplier)).
    radius : float, default 6
        Sphere radius passed to ``create_sphere``.

    Returns
    -------
    dict
        ``{subject: {roi_name: {'rdms': {session: rdm},
        'correlation_matrices': {session: fisher_z_corr},
        'valid_categories': list or None,
        'valid_categories_by_session': {session: list}}}}``.
        ``valid_categories`` is the category order of the last session
        processed (kept for compatibility). ``valid_categories_by_session``
        gives the order that matches each session's matrices.
    """
    all_rdms = {}
    print(f"Extracting RSA data for {len(functional_results)} subjects...")

    for subject_id, info in analysis_subjects.items():
        if subject_id not in functional_results:
            continue

        sessions = info['sessions']
        first_session = sessions[0]
        roi_dir = BASE_DIR / subject_id / f'ses-{first_session}' / 'ROIs'

        # Reference geometry for the spheres. Prefer the 'hemi' face mask, but
        # fall back to any other first-session search mask (as the geometry/MDS
        # functions do) instead of silently dropping the subject.
        candidates = [roi_dir / f"{info['hemi']}_face_searchmask.nii.gz"]
        candidates += [
            roi_dir / f"{h}_{cat}_searchmask.nii.gz"
            for cat in ['face', 'object', 'house', 'word']
            for h in ['l', 'r']
        ]
        ref_file = next((p for p in candidates if p.exists()), None)
        if ref_file is None:
            continue

        ref_img = nib.load(ref_file)
        affine = ref_img.affine
        brain_shape = ref_img.shape

        all_rdms[subject_id] = {}

        for roi_name, roi_data in functional_results[subject_id].items():
            entry = {
                'rdms': {},
                'correlation_matrices': {},
                'valid_categories': None,
                'valid_categories_by_session': {},
            }
            all_rdms[subject_id][roi_name] = entry

            for session in sessions:
                if session not in roi_data:
                    continue

                # Sphere at this session's ROI centroid.
                peak = roi_data[session]['centroid']
                sphere_mask = create_sphere(peak, affine, brain_shape, radius=radius)

                beta_matrix, valid_cats = extract_betas(subject_id, session, sphere_mask, cope_map)
                if beta_matrix is None:
                    continue

                rdm, corr_matrix_fisher = compute_rdm(beta_matrix, fisher_transform=True)

                entry['rdms'][session] = rdm
                entry['correlation_matrices'][session] = corr_matrix_fisher
                entry['valid_categories'] = valid_cats
                entry['valid_categories_by_session'][session] = valid_cats

    return all_rdms

def compute_liu_metrics(all_rdms, analysis_subjects):
    """Mean (Fisher-z) correlation between an ROI's preferred category and the rest.

    For each ROI whose name identifies a preferred category, and whose sessions
    have all four categories, the per-session value is the mean of the
    preferred category's row at the non-preferred columns of that session's
    correlation matrix. Lower values = more distinctive preferred pattern.

    Parameters
    ----------
    all_rdms : dict
        ``extract_rdms`` output.
    analysis_subjects : dict
        Unused; kept for interface compatibility.

    Returns
    -------
    dict
        ``{subject: {roi_name: {session: {'liu_distinctiveness': float}}}}``.
        Every subject in ``all_rdms`` appears (possibly with no ROIs).
    """
    distinctiveness_results = {}
    roi_preferred = {
        'l_face': 'face', 'r_face': 'face',
        'l_word': 'word', 'r_word': 'word',
        'l_object': 'object', 'r_object': 'object',
        'l_house': 'house', 'r_house': 'house'
    }

    for subject_id, rois in all_rdms.items():
        distinctiveness_results[subject_id] = {}
        for roi_name, roi_data in rois.items():
            if not roi_data['correlation_matrices']:
                continue

            valid_cats = roi_data['valid_categories']
            if valid_cats is None or len(valid_cats) < 4:
                continue

            # Exact ROI name first; otherwise accept names containing a known key
            # (e.g. 'l_face_top10'). Such names used to pass a substring check
            # and then raise KeyError on an exact-key lookup.
            preferred_cat = roi_preferred.get(roi_name)
            if preferred_cat is None:
                matches = [k for k in roi_preferred if k in roi_name]
                if not matches:
                    continue
                preferred_cat = roi_preferred[matches[0]]

            if preferred_cat not in valid_cats:
                continue

            # Category order can differ between sessions (e.g. a missing cope
            # file); use the per-session order when extract_rdms recorded one.
            cats_by_session = roi_data.get('valid_categories_by_session') or {}

            session_scores = {}
            for session, corr_matrix in roi_data['correlation_matrices'].items():
                cats = cats_by_session.get(session, valid_cats)
                if len(cats) < 4 or preferred_cat not in cats:
                    continue
                pref_idx = cats.index(preferred_cat)
                nonpref_indices = [i for i, cat in enumerate(cats) if cat != preferred_cat]
                mean_corr = np.mean(corr_matrix[pref_idx, nonpref_indices])
                session_scores[session] = {'liu_distinctiveness': mean_corr}

            if session_scores:
                distinctiveness_results[subject_id][roi_name] = session_scores

    return distinctiveness_results

# --- B. EXECUTE ANALYSIS ---

print("TESTING DISTINCTIVENESS ON STRICT SCRAMBLE (Top 10%)...")

# 1. Reuse the Top 10% ROIs from the production cell; rebuild them only when
#    this cell runs without it.
if 'top10_scramble' not in locals():
    print("Regenerating Top 10% ROIs...")
    top10_scramble = {}
    for sub in ANALYSIS_SUBJECTS:
        res = extract_top20_rois(sub, COPE_MAP_SCRAMBLE, percentile=90)
        if not res:
            continue
        top10_scramble[sub] = res

# 2. RDMs in spheres at the Top 10% centroids, then the Liu distinctiveness metric.
#    COPE_MAP_SCRAMBLE is passed explicitly so the category > scramble copes are used.
rdms_strict = extract_rdms(top10_scramble, ANALYSIS_SUBJECTS, COPE_MAP_SCRAMBLE)
dist_strict = compute_liu_metrics(rdms_strict, ANALYSIS_SUBJECTS)

# 3. Report
banner = "=" * 60
print("\n" + banner)
print("STRICT SCRAMBLE DISTINCTIVENESS (Mean Correlation with Non-Preferred)")
print("Lower Value = MORE Distinctive (Values < 0.5 are typically good)")
print(banner)

# One row per subject x ROI: the metric averaged over that ROI's sessions.
dist_stats = [
    {
        'group': ANALYSIS_SUBJECTS[sub]['group'],
        'roi': roi,
        'val': np.mean([d['liu_distinctiveness'] for d in data.values()]),
    }
    for sub, rois in dist_strict.items()
    if sub in ANALYSIS_SUBJECTS
    for roi, data in rois.items()
]

df_dist = pd.DataFrame(dist_stats)

if df_dist.empty:
    print("No distinctiveness data found.")
else:
    print("\nMean Correlation by Group & ROI:")
    print(df_dist.groupby(['group', 'roi'])['val'].agg(['mean', 'std', 'count']))

TESTING DISTINCTIVENESS ON STRICT SCRAMBLE (Top 10%)...
Extracting RSA data for 24 subjects...

STRICT SCRAMBLE DISTINCTIVENESS (Mean Correlation with Non-Preferred)
Lower Value = MORE Distinctive (Values < 0.5 are typically good)

Mean Correlation by Group & ROI:
                      mean       std  count
group   roi                                
OTC     l_face    0.348755  0.005619      2
        l_house   0.029002  0.161859      2
        l_object  0.459991  0.595400      2
        l_word    0.309910  0.174248      2
        r_face    0.763184  0.131565      4
        r_house   0.301878  0.512554      4
        r_object  0.590208  0.411475      4
        r_word    0.589901  0.146641      4
control l_face    0.869115  0.313620      9
        l_house   0.166016  0.425518      9
        l_object  0.721747  0.369585      9
        l_word    0.630655  0.212636      9
        r_face    0.639489  0.215762      9
        r_house   0.168906  0.308753      9
        r_object  0.639142  0.3