# 4. MCI Conversion Split (Improved)

This notebook splits MCI subjects into pMCI and sMCI groups using ADNIMERGE.csv data:
- **pMCI (progressive MCI)**: Subjects diagnosed with EMCI/LMCI/MCI at baseline who later converted to Dementia
- **sMCI (stable MCI)**: Subjects diagnosed with EMCI/LMCI/MCI at baseline who remained MCI at every follow-up, observed over **at least 24 months (730 days)**

**Key improvements:**
- Includes the generic 'MCI' baseline label (used for ADNI1 subjects, who have the longest follow-up times)
- Enforces a minimum 24-month (730-day) stability period for sMCI classification
- Excludes subjects who revert to CN (ambiguous pathology)
- Excludes subjects with insufficient follow-up data

For subjects with multiple scans, only the earliest scan (lowest image ID, used as a proxy for acquisition order) is selected.

In [None]:
import pandas as pd
import os
import shutil
from pathlib import Path
from tqdm.notebook import tqdm
from collections import defaultdict

### Define Paths

In [None]:
# Placeholder locations: replace each "PATH_TO_DATA" before running.
mci_path = Path("PATH_TO_DATA")        # folder whose .nii.gz scans are split
output_path = Path("PATH_TO_DATA")     # root that receives the pMCI/ and sMCI/ folders
adnimerge_file = Path("PATH_TO_DATA")  # ADNIMERGE.csv read by the loading cell

pmci_path, smci_path = (output_path / group for group in ("pMCI", "sMCI"))

# Create both group folders up front; re-running is harmless.
for group_dir in (pmci_path, smci_path):
    group_dir.mkdir(parents=True, exist_ok=True)

### Helper Functions

In [None]:
def extract_subject_and_image_id(filename):
    """Extract the subject ID and numeric image ID from a scan filename.

    Expects names shaped like ``<site>_S_<subject>_I<imageid>[_...].nii[.gz]``,
    e.g. ``002_S_0295_I12345.nii.gz`` -> ``("002_S_0295", 12345)``.

    Args:
        filename: scan filename (no directory part).

    Returns:
        tuple: ``(subject_id, image_id)``. ``image_id`` is ``float('inf')`` when
        it cannot be parsed, so such files sort after every real ID. If the
        name has fewer than four underscore-separated parts, the unmodified
        ``filename`` is returned as the subject ID.
    """
    # Strip the NIfTI extension only when it is the actual suffix.
    stem = filename
    for ext in ('.nii.gz', '.nii'):
        if stem.endswith(ext):
            stem = stem[:-len(ext)]
            break

    parts = stem.split('_')
    if len(parts) < 4:
        return filename, float('inf')

    subject_id = '_'.join(parts[:3])
    image_token = parts[3]
    # Remove the 'I' prefix only if present; dropping the first character
    # unconditionally would turn an unprefixed '12345' into 2345.
    if image_token[:1] in ('I', 'i'):
        image_token = image_token[1:]
    try:
        return subject_id, int(image_token)
    except ValueError:
        return subject_id, float('inf')

def get_earliest_scan_per_subject(file_list):
    """Pick one scan per subject: the file with the lowest image ID.

    Files whose image ID cannot be parsed get ``inf`` and therefore lose to
    any real ID. Ties (equal or unparseable IDs) are broken by filename, so
    the choice does not depend on the order of ``file_list`` (e.g. the
    arbitrary order returned by ``os.listdir``).

    Args:
        file_list: iterable of scan filenames.

    Returns:
        dict: ``{subject_id: selected_filename}`` in first-seen subject order.
        For subjects with several scans, the selection is printed.
    """
    subject_files = defaultdict(list)
    for filename in file_list:
        subject_id, image_id = extract_subject_and_image_id(filename)
        subject_files[subject_id].append((image_id, filename))

    earliest_scans = {}
    for subject_id, candidates in subject_files.items():
        # Tuples compare by image ID first, then filename as a deterministic tie-break.
        _, earliest_file = min(candidates)
        earliest_scans[subject_id] = earliest_file

        if len(candidates) > 1:
            print(f"Subject {subject_id}: Selected {earliest_file} from {len(candidates)} scans")

    return earliest_scans

def determine_mci_conversion_status(adnimerge_df, min_stable_days=730):
    """
    Determine MCI conversion status from ADNIMERGE data.

    Subjects are selected by baseline diagnosis (``DX_bl`` in EMCI/LMCI/MCI,
    which includes the generic ADNI1 'MCI' label). Each subject is then
    classified from the non-missing ``DX`` values of its non-baseline
    (``VISCODE != 'bl'``) visits:

    1. any follow-up 'Dementia'               -> 'pMCI'
    2. otherwise any follow-up 'CN'           -> excluded (REVERTED_TO_CN)
    3. otherwise all follow-ups EMCI/LMCI/MCI -> 'sMCI' if the span from the
       first exam to the last follow-up that recorded a diagnosis is at least
       ``min_stable_days``; otherwise excluded (STABLE_BUT_SHORT)
    4. anything else                          -> excluded (INCONSISTENT_DX)

    Subjects with no non-missing follow-up diagnosis are excluded
    (NO_FOLLOWUP_DATA).

    Args:
        adnimerge_df: ADNIMERGE table; reads columns PTID, VISCODE, DX_bl,
            DX and EXAMDATE.
        min_stable_days: minimum confirmed stability span, in days, for sMCI.
            Defaults to 730 (~24 months), the threshold this notebook used.

    Returns:
        tuple: (conversion_status_dict, excluded_subjects_dict)
        - conversion_status_dict: {PTID: 'pMCI' or 'sMCI'}
        - excluded_subjects_dict: {PTID: reason_string}
    """
    mci_types = ['EMCI', 'LMCI', 'MCI']
    valid_mci_dx = set(mci_types)
    mci_baseline = adnimerge_df[adnimerge_df['DX_bl'].isin(mci_types)].copy()
    # Parse dates once so visits sort chronologically regardless of string format.
    mci_baseline['EXAMDATE'] = pd.to_datetime(mci_baseline['EXAMDATE'])

    conversion_status = {}
    excluded_subjects = {}  # Track excluded subjects with reasons

    stats = {"pMCI": 0, "sMCI": 0, "dropped_short_stable": 0, "dropped_reverter": 0,
             "dropped_no_data": 0, "dropped_inconsistent": 0}

    for ptid, subject_data in mci_baseline.groupby('PTID'):
        subject_data = subject_data.sort_values('EXAMDATE')

        follow_ups = subject_data[subject_data['VISCODE'] != 'bl']
        diagnosed = follow_ups[follow_ups['DX'].notna()]
        follow_up_dx = diagnosed['DX'].tolist()

        if not follow_up_dx:
            stats["dropped_no_data"] += 1
            excluded_subjects[ptid] = "NO_FOLLOWUP_DATA"
            continue

        # 1. Ever converted to Dementia -> pMCI
        if 'Dementia' in follow_up_dx:
            conversion_status[ptid] = 'pMCI'
            stats["pMCI"] += 1
            continue

        # 2. Reverted to Normal (CN) -> exclude (ambiguous pathology)
        if 'CN' in follow_up_dx:
            stats["dropped_reverter"] += 1
            excluded_subjects[ptid] = "REVERTED_TO_CN"
            continue

        # 3. Any non-MCI, non-CN, non-Dementia label -> exclude
        if not all(dx in valid_mci_dx for dx in follow_up_dx):
            stats["dropped_inconsistent"] += 1
            excluded_subjects[ptid] = f"INCONSISTENT_DX ({follow_up_dx})"
            continue

        # 4. Consistently MCI: stability counts only up to the last visit that
        #    recorded a diagnosis; later visits with missing DX confirm nothing.
        start_date = subject_data['EXAMDATE'].min()
        last_dx_date = diagnosed['EXAMDATE'].max()
        if pd.isna(start_date) or pd.isna(last_dx_date):
            stats["dropped_short_stable"] += 1
            excluded_subjects[ptid] = "STABLE_BUT_SHORT (unknown duration: missing EXAMDATE)"
            continue

        duration_days = (last_dx_date - start_date).days
        if duration_days >= min_stable_days:
            conversion_status[ptid] = 'sMCI'
            stats["sMCI"] += 1
        else:
            stats["dropped_short_stable"] += 1
            excluded_subjects[ptid] = (
                f"STABLE_BUT_SHORT ({duration_days} days < {min_stable_days} days)"
            )

    print("\n--- Final Dataset Stats ---")
    print(f"pMCI: {stats['pMCI']}")
    print(f"sMCI: {stats['sMCI']}")
    print(f"Dropped (Stable < {min_stable_days} days): {stats['dropped_short_stable']}")
    print(f"Dropped (Reverted to CN): {stats['dropped_reverter']}")
    print(f"Dropped (No follow-up data): {stats['dropped_no_data']}")
    print(f"Dropped (Inconsistent diagnoses): {stats['dropped_inconsistent']}")
    print(f"Total MCI baseline subjects: {mci_baseline['PTID'].nunique()}")
    print(f"Total classified: {len(conversion_status)}")
    print(f"Total excluded: {len(excluded_subjects)}")

    return conversion_status, excluded_subjects

### Load ADNIMERGE Data and Determine Conversions

In [None]:
# Load ADNIMERGE data and derive the conversion status of every MCI-baseline subject.
# If the file is missing, fall back to empty results -- including an empty
# ADNIMERGE frame with the columns later cells read -- so the diagnostic and
# split cells still run instead of failing with NameError on adnimerge_df.
print("Loading ADNIMERGE.csv...")
try:
    adnimerge_df = pd.read_csv(adnimerge_file, low_memory=False)
except FileNotFoundError:
    print(f"Error: ADNIMERGE file not found at {adnimerge_file}")
    adnimerge_df = pd.DataFrame(columns=['PTID', 'VISCODE', 'DX_bl', 'DX', 'EXAMDATE'])
    conversion_status = {}
    excluded_subjects = {}
else:
    print(f"Loaded {len(adnimerge_df)} records from ADNIMERGE")

    # Determine conversion status
    print("Determining MCI conversion status...")
    conversion_status, excluded_subjects = determine_mci_conversion_status(adnimerge_df)

pmci_subjects = [ptid for ptid, status in conversion_status.items() if status == 'pMCI']
smci_subjects = [ptid for ptid, status in conversion_status.items() if status == 'sMCI']

if adnimerge_df is not None and len(adnimerge_df) > 0:
    print(f"\nConversion analysis results:")
    print(f"pMCI (progressive): {len(pmci_subjects)} subjects")
    print(f"sMCI (stable): {len(smci_subjects)} subjects")
    print(f"Total classified: {len(conversion_status)} subjects")
    if conversion_status:
        print(f"Conversion rate: {len(pmci_subjects) / len(conversion_status) * 100:.1f}%")

### Process MCI Files and Split


In [None]:
# Diagnostic: Check which subjects from files are not in ADNIMERGE
print("=== DIAGNOSTIC: Checking subjects not found in ADNIMERGE ===\n")

# Get all subject IDs from files
mci_files = [f for f in os.listdir(mci_path) if f.endswith('.nii.gz')]
file_subject_ids = {extract_subject_and_image_id(f)[0] for f in mci_files}

print(f"Total unique subjects in TAU MCI files: {len(file_subject_ids)}")

# Get all PTIDs from ADNIMERGE
adnimerge_ptids = set(adnimerge_df['PTID'].unique())
print(f"Total unique PTIDs in ADNIMERGE: {len(adnimerge_ptids)}")

# Find missing subjects
missing_subjects = file_subject_ids - adnimerge_ptids
found_subjects = file_subject_ids & adnimerge_ptids

print(f"\nSubjects found in ADNIMERGE: {len(found_subjects)}")
print(f"Subjects NOT in ADNIMERGE: {len(missing_subjects)}")

if missing_subjects:
    # Sets have no stable order; sort once so every sample below is reproducible.
    missing_sorted = sorted(missing_subjects)

    print("\n--- First 20 subjects NOT in ADNIMERGE ---")
    for subject_id in missing_sorted[:20]:
        print(f"  {subject_id}")

    # Check if there's a pattern - maybe they're formatted differently?
    print("\n--- Checking for format variations ---")
    for subject_id in missing_sorted[:5]:
        parts = subject_id.split('_')
        if len(parts) != 3:
            continue
        site, s, num = parts
        variations = [subject_id]
        # Stripping leading zeros needs numeric parts; int() on anything else
        # would raise and abort the rest of this diagnostic.
        try:
            variations.append(f"{int(site)}_{s}_{int(num)}")
        except ValueError:
            pass
        variations.append(f"{site}_{s}_{num.zfill(4)}")  # Add more zeros
        found_variations = [var for var in variations if var in adnimerge_ptids]
        if found_variations:
            print(f"  {subject_id} -> Found as: {found_variations}")
        else:
            print(f"  {subject_id} -> Not found in any variation")

    # Check if subjects exist but with different baseline diagnoses
    print("\n--- Checking if missing subjects exist with non-MCI baseline ---")
    for subject_id in missing_sorted[:10]:
        # Partial match on the last ID part; regex=False matches it literally.
        subject_number = subject_id.split('_')[-1]
        matching_rows = adnimerge_df[
            adnimerge_df['PTID'].str.contains(subject_number, na=False, regex=False)
        ]
        if len(matching_rows) == 0:
            continue
        unique_ptids = matching_rows['PTID'].unique()
        print(f"  {subject_id} -> Found similar PTIDs: {list(unique_ptids)[:3]}")
        # Check baseline diagnoses
        for ptid in unique_ptids[:1]:
            baseline_rows = adnimerge_df[(adnimerge_df['PTID'] == ptid) & (adnimerge_df['VISCODE'] == 'bl')]
            if len(baseline_rows) > 0:
                dx_bl = baseline_rows.iloc[0]['DX_bl']
                print(f"      Baseline DX: {dx_bl}")


In [None]:
# Get MCI files and select earliest per subject
mci_files = [f for f in os.listdir(mci_path) if f.endswith('.nii.gz')]
print(f"Found {len(mci_files)} MCI files")

earliest_scans = get_earliest_scan_per_subject(mci_files)
print(f"After selecting earliest scans: {len(earliest_scans)} unique subjects")

# Lookups used to explain unclassified subjects, built once instead of
# re-filtering the whole ADNIMERGE frame for every subject in the loop.
adnimerge_ptid_set = set(adnimerge_df['PTID'].unique())
baseline_dx_by_ptid = (
    adnimerge_df[adnimerge_df['VISCODE'] == 'bl']
    .drop_duplicates(subset='PTID')  # first baseline row per PTID
    .set_index('PTID')['DX_bl']
)

# Split into pMCI and sMCI based on conversion status
pmci_count = smci_count = unclassified_count = stale_removed_count = 0
unclassified_reasons = defaultdict(int)  # reason code -> count
unclassified_examples = []  # first 10 (subject_id, full reason) pairs for display

for subject_id, filename in tqdm(earliest_scans.items(), desc="Splitting MCI subjects"):
    source_path = mci_path / filename

    if subject_id in conversion_status:
        if conversion_status[subject_id] == 'pMCI':
            dest_dir, other_dir = pmci_path, smci_path
            pmci_count += 1
        else:  # sMCI
            dest_dir, other_dir = smci_path, pmci_path
            smci_count += 1

        # An earlier run with different criteria may have put this scan in the
        # other group's folder; remove it so the subject sits in one group only.
        stale_path = other_dir / filename
        if stale_path.exists():
            stale_path.unlink()
            stale_removed_count += 1

        # Copy file if it doesn't exist
        dest_path = dest_dir / filename
        if not dest_path.exists():
            shutil.copy(source_path, dest_path)
        continue

    unclassified_count += 1
    if subject_id in excluded_subjects:
        reason = excluded_subjects[subject_id]
        # Count by reason code only: the parenthesized details (day counts,
        # diagnosis lists) differ per subject and would make every row unique.
        reason_key = reason.split(' (', 1)[0]
    elif subject_id not in adnimerge_ptid_set:
        reason = reason_key = "NOT_IN_ADNIMERGE"
    elif subject_id not in baseline_dx_by_ptid.index:
        reason = reason_key = "NO_BASELINE_IN_ADNIMERGE"
    else:
        reason = reason_key = f"BASELINE_NOT_MCI ({baseline_dx_by_ptid[subject_id]})"

    unclassified_reasons[reason_key] += 1
    if len(unclassified_examples) < 10:
        unclassified_examples.append((subject_id, reason))

print(f"\nSplitting completed:")
print(f"pMCI (progressive): {pmci_count} subjects")
print(f"sMCI (stable): {smci_count} subjects")
print(f"Unclassified: {unclassified_count} subjects")
if stale_removed_count:
    print(f"Removed {stale_removed_count} stale copies from the opposite group folder")
if pmci_count + smci_count > 0:
    print(f"Conversion rate: {pmci_count / (pmci_count + smci_count) * 100:.1f}%")

print(f"\n--- Unclassified Subjects Breakdown ---")
print("(Subjects from files that are not in conversion_status)")
for reason, count in sorted(unclassified_reasons.items(), key=lambda x: -x[1]):
    print(f"{reason}: {count}")

if unclassified_examples:
    print(f"\n--- Example Unclassified Subjects (first 10) ---")
    for subject_id, reason in unclassified_examples:
        print(f"  {subject_id}: {reason}")


### Summary Statistics


In [None]:
# Verify the splits: count the .nii.gz files that ended up in each group folder.
pmci_files_created, smci_files_created = (
    sum(1 for name in os.listdir(folder) if name.endswith('.nii.gz'))
    for folder in (pmci_path, smci_path)
)

print(f"\nFinal verification:")
print(f"pMCI directory: {pmci_files_created} files")
print(f"sMCI directory: {smci_files_created} files")
print(f"Total processed: {pmci_files_created + smci_files_created} files")

# Preview up to five PTIDs per group from the ADNIMERGE-based classification.
for heading, ptids in (("\nExample pMCI subjects", pmci_subjects),
                       ("Example sMCI subjects", smci_subjects)):
    if ptids:
        print(f"{heading}: {ptids[:5]}")
