<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/Analyze_TargetRegions_Grouping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TMS-fMRI: Participant Grouping by Target Stimulation Regions

This notebook analyzes the TMS target regions across all participants and groups the participants accordingly. It will:
- Identify how many unique target regions exist
- Count how many participants are stimulated in each region
- Create groupings for subsequent model training

## 1️⃣ Setup & Imports

In [None]:
# --- Setup cell ---

# 1Ô∏è‚É£ Mount Google Drive (for data)
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# 2Ô∏è‚É£ Clone GitHub repository (for code)
import os, sys, subprocess

repo_dir = "/content/BrainStim_ANN_fMRI_HCP"
if not os.path.exists(repo_dir):
    !git clone https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git
else:
    print("Repo already exists ‚úÖ")

# 3Ô∏è‚É£ Define paths (TMS-fMRI)
data_dir = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"
preproc_dir = os.path.join(data_dir, "preprocessed_subjects_tms_fmri")

# 4Ô∏è‚É£ Add repo to import path + imports
sys.path.append(repo_dir)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import glob
from collections import defaultdict

# Set style for plots
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (10, 6)

print("‚úÖ Setup complete | Data directory:", preproc_dir)

## 2️⃣ Load Participant Metadata

In [None]:
# --- Load all available subjects and their sessions ---
# Inventories the preprocessed folder: per-subject directories, per-subject
# signal files (from which subject IDs are derived), and any metadata or
# task-stim files that might carry target-region information.

# First, get all subject directories.
# Only real directories are counted: the flat "sub-*_signals.npy" files globbed
# below also start with "sub-" and would otherwise inflate this number.
subject_dirs = sorted(
    entry
    for entry in os.listdir(preproc_dir)
    if entry.startswith('sub-') and os.path.isdir(os.path.join(preproc_dir, entry))
)
print(f"Found {len(subject_dirs)} subject directories")

# No dedicated metadata file (e.g., participants.tsv or sessions.json) is read
# here; subject IDs are inferred from the signal file names instead.
signal_files = sorted(glob.glob(os.path.join(preproc_dir, "sub-*_signals.npy")))
# "sub-XYZ_signals.npy" -> "sub-XYZ"; the set removes duplicates before sorting.
subject_ids = sorted({os.path.basename(path).split('_signals')[0] for path in signal_files})

print(f"\nSubjects with preprocessed signals: {len(subject_ids)}")
print("Subject IDs:", subject_ids[:5], "..." if len(subject_ids) > 5 else "")

# Check for metadata files that might contain target region info
metadata_files = glob.glob(os.path.join(preproc_dir, "*metadata*"))
print(f"\nMetadata files found: {metadata_files}")

# Also check for task-specific files or logs
task_stim_files = glob.glob(os.path.join(preproc_dir, "sub-*_task-stim*"))
print(f"Task-stim files found: {len(task_stim_files)}")

In [None]:
import pickle

# --- Load the TMS-fMRI dataset ---
# Reads the pickled dataset from Drive and reports which subjects it contains.

BASE = data_dir
DATASET_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")

# SECURITY NOTE: unpickling can execute arbitrary code. Only load this file if
# it comes from a source you trust (e.g., your own preprocessing output).
with open(DATASET_PKL, "rb") as f:
    dataset = pickle.load(f)

print(f"✅ Loaded dataset from: {DATASET_PKL}")
print(f"   Dataset keys (sample): {list(dataset.keys())[:3]}")

# Get list of all subjects in the dataset.
# Strip only the leading "sub-" prefix; str.replace would also remove any later
# occurrence of "sub-" inside an ID.
subject_prefix = 'sub-'
subjects_in_dataset = [
    key[len(subject_prefix):]
    for key in dataset.keys()
    if isinstance(key, str) and key.startswith(subject_prefix)
]
print(f"   Total subjects in dataset: {len(subjects_in_dataset)}")

## 3️⃣ Extract Target Regions from Metadata

**Note:** The next cell derives each participant's target region(s) from the one-hot `target` vector stored with every `task-stim` trial in the loaded dataset. If your metadata is stored differently (e.g., JSON files, TSV metadata, or embedded in filenames), adjust the extraction accordingly.

In [None]:
# --- Extract target regions from dataset ---
# Builds participant_targets from the `dataset` loaded in the previous cell,
# which is accessed as:
#   dataset['sub-<subject_id>']['task-stim'][trial_index]['target']
# where 'target' is a one-hot encoding of the stimulated region.
# NOTE(review): adjust the lookup below if your dataset stores the target
# differently.

participant_targets = {}  # {subject_id: sorted list of unique target-region indices}

for sid in subject_ids:
    # subject_ids may or may not carry the "sub-" prefix; dataset keys do.
    subject_key = sid if sid.startswith('sub-') else f'sub-{sid}'

    # Try to access task-stim trials for this subject
    try:
        task_stim_data = dataset[subject_key]['task-stim']

        # Extract target regions from all trials for this subject
        targets = []
        for trial_idx in range(len(task_stim_data)):
            trial_data = task_stim_data[trial_idx]

            # target is one-hot encoded, find which region was stimulated
            if 'target' in trial_data:
                # Coerce to an ndarray first: comparing a plain list/tuple with
                # `== 1` is not element-wise, so active entries would be missed.
                target_array = np.asarray(trial_data['target'])
                target_id = np.where(target_array == 1)[0]

                # Only the first active index is recorded per trial.
                # NOTE(review): assumes a 1-D one-hot vector — confirm.
                if len(target_id) > 0:
                    targets.append(int(target_id[0]))

        # Store unique targets for this subject (subjects without any are omitted)
        if targets:
            participant_targets[sid] = sorted(set(targets))

    except KeyError:
        print(f"Warning: No task-stim data found for {subject_key}")
        continue

print(f"Extracted target regions for {len(participant_targets)} subjects")
print("\nExample entries:")
for sid in list(participant_targets.keys())[:3]:
    print(f"  {sid}: {participant_targets[sid]}")

## 4️⃣ Count Unique Target Regions

In [None]:
# --- Count unique target regions and participants per region ---
# Produces: unique_targets (sorted region list), target_to_participants
# ({region: [subject ids]}), sorted_targets (regions ordered by group size,
# largest first) and target_counts (group sizes).
# A participant with several target regions appears in every matching group,
# so group sizes can add up to more than the number of participants.

# Flatten all targets to find unique ones
all_targets = []
for targets in participant_targets.values():
    if isinstance(targets, list):
        all_targets.extend(targets)
    else:
        all_targets.append(targets)

unique_targets = sorted(set(all_targets))

print(f"Total unique target regions: {len(unique_targets)}")
print(f"Target regions: {unique_targets}\n")

# Count participants per target region
target_to_participants = defaultdict(list)
for sid, targets in participant_targets.items():
    if isinstance(targets, list):
        for target in targets:
            target_to_participants[target].append(sid)
    else:
        target_to_participants[targets].append(sid)

# Sort by number of participants (descending)
sorted_targets = sorted(target_to_participants.items(), key=lambda x: len(x[1]), reverse=True)

print("Participants per target region:")
print("-" * 50)
for target, participants in sorted_targets:
    # Targets are integer region indices; the "s" format code only accepts
    # strings, so convert explicitly before padding to 20 characters.
    print(f"{str(target):20s}: {len(participants):3d} participants")
    print(f"  {participants}")

print("-" * 50)
print(f"Total: {len(participant_targets)} participants")

# Summary statistics
target_counts = [len(p) for p in target_to_participants.values()]
print("\nGroup size statistics:")
if target_counts:
    print(f"  Mean group size: {np.mean(target_counts):.2f}")
    print(f"  Median group size: {np.median(target_counts):.1f}")
    print(f"  Min group size: {np.min(target_counts)}")
    print(f"  Max group size: {np.max(target_counts)}")
else:
    # np.min / np.max raise on empty input; report the situation instead.
    print("  No target regions found; nothing to summarize")

## 5️⃣ Visualize Participant Distribution

In [None]:
# --- Bar plot: participants per target region ---
# One bar per target region, in the order of sorted_targets (largest group
# first), with the participant count written above each bar.

targets_list = [region for region, _ in sorted_targets]
counts_list = [len(members) for _, members in sorted_targets]

x_positions = range(len(targets_list))
axis_label_style = dict(fontsize=12, fontweight='bold')

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(x_positions, counts_list, color='steelblue', edgecolor='black', alpha=0.7)

# Add value labels on bars: horizontally centred, slightly above the bar top
for rect, n_participants in zip(bars, counts_list):
    label_x = rect.get_x() + rect.get_width()/2
    label_y = rect.get_height() + 0.1
    ax.text(label_x, label_y, str(n_participants),
            ha='center', va='bottom', fontsize=10, fontweight='bold')

ax.set_title("Participant Distribution Across Target Regions", fontsize=14, fontweight='bold')
ax.set_xlabel("Target Region", **axis_label_style)
ax.set_ylabel("Number of Participants", **axis_label_style)
ax.set_xticks(x_positions)
ax.set_xticklabels(targets_list, rotation=45, ha='right')
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# --- Pie chart: proportion of participants per target region ---
# Uses targets_list / counts_list from the bar-plot cell. Slices are shares of
# the summed group sizes.

n_slices = len(targets_list)
colors = plt.cm.Set3(np.linspace(0, 1, n_slices))

fig, ax = plt.subplots(figsize=(8, 8))

wedges, texts, autotexts = ax.pie(
    counts_list,
    labels=targets_list,
    autopct='%1.1f%%',
    colors=colors,
    startangle=90,
    textprops={'fontsize': 10},
)

# Make percentage text bold
for pct_label in autotexts:
    pct_label.set_fontsize(9)
    pct_label.set_fontweight('bold')
    pct_label.set_color('black')

ax.set_title("Proportion of Participants per Target Region", fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

## 6️⃣ Generate Summary Table

In [None]:
# --- Create summary DataFrame ---
# One row per target region (largest group first): region, group size,
# comma-separated member IDs, and the group's share of all participants.

n_total_participants = len(participant_targets)

summary_data = [
    {
        'Target Region': region,
        'N Participants': len(members),
        'Participants': ', '.join(members),
        'Percentage': f"{100 * len(members) / n_total_participants:.1f}%",
    }
    for region, members in sorted_targets
]

df_summary = pd.DataFrame(summary_data)

rule = "=" * 80
print("\n" + rule)
print("SUMMARY: TARGET REGIONS & PARTICIPANT GROUPING")
print(rule + "\n")
print(df_summary.to_string(index=False))
print("\n" + rule)
print(f"Total participants: {n_total_participants}")
print(f"Total target regions: {len(unique_targets)}")
print(rule)

## 7️⃣ Export Grouping Information for Model Training

Save the grouping information for use in subsequent training notebooks.

In [None]:
# --- Save grouping information as JSON ---
# Writes target_region_grouping.json and target_region_summary.csv into
# preproc_dir, using the results computed in the counting and summary cells.

# Fail early with a clear message: the NumPy reductions below raise an opaque
# error on empty input.
if not target_counts:
    raise ValueError("No target regions were extracted; nothing to export.")

# Create a clean grouping dictionary.
# NOTE: JSON object keys are always strings, so the integer region keys of
# 'participant_groups' are written as strings, while the entries of
# 'target_regions' stay integers. Convert accordingly when reading the file.
grouping = {
    'target_regions': unique_targets,
    'participant_groups': dict(target_to_participants),
    'participant_to_target': participant_targets,
    'summary': {
        'total_participants': len(participant_targets),
        'total_target_regions': len(unique_targets),
        # Cast NumPy scalars to built-in types so json can serialize them.
        'mean_group_size': float(np.mean(target_counts)),
        'median_group_size': float(np.median(target_counts)),
        'min_group_size': int(np.min(target_counts)),
        'max_group_size': int(np.max(target_counts))
    }
}

# Save to JSON
grouping_path = os.path.join(preproc_dir, "target_region_grouping.json")
with open(grouping_path, 'w') as f:
    json.dump(grouping, f, indent=2)

print(f"✅ Saved grouping information to: {grouping_path}")

# Also save the summary table as CSV
csv_path = os.path.join(preproc_dir, "target_region_summary.csv")
df_summary.to_csv(csv_path, index=False)
print(f"✅ Saved summary table to: {csv_path}")

In [None]:
# --- Display final summary ---
# Prints a closing banner, the suggested next steps, and the paths of the
# files written by the export cell.

# Banner emoji restored from its mis-encoded form.
banner = "🎯 " * 20
print("\n" + banner)
print("GROUPING READY FOR MODEL TRAINING")
print(banner + "\n")

print("Next steps:")
print("1. Use the grouping information to train separate models per target region")
print("2. For each target region group:")
print("   - Use REST sessions for training data")
print("   - Use STIM sessions for validation/evaluation")
print("3. Compare models across different target regions\n")

print("Key files generated:")
print(f"  - {grouping_path}")
print(f"  - {csv_path}\n")