<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/Analyze_TargetRegions_Grouping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Analyze Target Regions & Group Participants by Stimulation Target

This notebook:
1. Loads the TMS-fMRI dataset
2. Extracts target regions for each subject's task-stim sessions
3. Groups participants by target region
4. Summarizes and plots group sizes and the distribution of target regions
5. Saves the grouping summary to a pickle for downstream notebooks

In [None]:
# --- Setup cell (Google Colab compatibility) ---

import os
import pickle
import subprocess
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

# Detect Colab by whether its helper package is importable. Only ImportError
# means "not in Colab"; any other failure (e.g. the Drive mount) must surface
# instead of silently falling back to the local placeholder path.
try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    drive.mount('/content/drive', force_remount=True)
    BASE = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"

    # Clone the project repo so its modules are importable.
    repo_dir = "/content/BrainStim_ANN_fMRI_HCP"
    if not os.path.exists(repo_dir):
        # subprocess (not `!git clone`) so a failed clone raises, and an explicit
        # destination so the result does not depend on the current directory.
        subprocess.run(
            ["git", "clone", "https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git", repo_dir],
            check=True,
        )
    else:
        print("Repo already exists ✅")
    # Guard so re-running this cell does not keep appending the same entry.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)
else:
    # Local path
    BASE = "/path/to/your/data"  # Update this if running locally

print(f"Running in Colab: {IN_COLAB}")
print(f"Data directory: {BASE}")

In [None]:
# --- Load dataset ---

DATASET_PKL = os.path.join(BASE, "TMS_fMRI", "dataset_tian50_schaefer400_allruns.pkl")

print(f"Loading dataset from: {DATASET_PKL}")
# Fail with an actionable message (typically BASE is still the local
# placeholder from the setup cell) rather than a bare error from open().
if not os.path.isfile(DATASET_PKL):
    raise FileNotFoundError(
        f"Dataset not found at {DATASET_PKL!r}; check BASE in the setup cell."
    )

# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(DATASET_PKL, "rb") as f:
    dataset = pickle.load(f)

print("✅ Dataset loaded successfully")
print(f"Number of subjects: {len(dataset)}")
print(f"Sample subject keys: {list(dataset.keys())[:3]}")

In [None]:
# --- Extract subject list and explore structure ---

SUBJECT_PREFIX = 'sub-'

# Strip only the leading prefix: str.replace would also delete any later
# occurrence of 'sub-' inside the identifier.
subjects_list = sorted(
    key[len(SUBJECT_PREFIX):] if key.startswith(SUBJECT_PREFIX) else key
    for key in dataset.keys()
)
if not subjects_list:
    raise ValueError("Dataset contains no subjects; nothing to analyze.")

print(f"Subjects: {subjects_list}")
print(f"Total subjects: {len(subjects_list)}")

# Check structure for first subject
first_subject = SUBJECT_PREFIX + subjects_list[0]
first_record = dataset[first_subject]
print(f"\nStructure for {first_subject}:")
print(f"  Conditions: {list(first_record.keys())}")
print(f"  Task-rest trials: {len(first_record.get('task-rest', []))}")
print(f"  Task-stim trials: {len(first_record.get('task-stim', []))}")

In [None]:
# --- Extract target regions for each subject ---

# subject index (position in subjects_list) -> list of target region IDs,
# one entry per task-stim session
subject_targets = {}

# target region ID -> list of subject indices, one entry per session
# (a subject appears once for every session that stimulated that target)
target_to_subjects = defaultdict(list)

# All unique target region IDs seen across subjects
all_target_regions = set()

for subject_id, subject_name in enumerate(subjects_list):
    full_subject_name = 'sub-' + subject_name

    # Get task-stim sessions for this subject
    task_stim_sessions = dataset[full_subject_name].get('task-stim', [])

    subject_targets[subject_id] = []

    for trial, session in enumerate(task_stim_sessions):
        # Best effort: report a malformed session and move on. The try covers
        # only the lookup, so bugs in the bookkeeping below are not swallowed.
        try:
            # presumably a one-hot vector over regions -- the region index is
            # where the entry equals 1. np.asarray so lists compare elementwise.
            hits = np.flatnonzero(np.asarray(session['target']) == 1)
        except Exception as e:
            print(f"Error extracting target for {full_subject_name} trial {trial}: {e}")
            continue

        if hits.size == 0:
            print(f"Error extracting target for {full_subject_name} trial {trial}: no entry equal to 1 in 'target'")
            continue
        if hits.size > 1:
            print(f"Warning: {full_subject_name} trial {trial} has {hits.size} target entries; using the first ({int(hits[0])})")

        # Plain int so printed lists and the pickled summary carry no numpy scalars.
        target_id = int(hits[0])
        subject_targets[subject_id].append(target_id)
        all_target_regions.add(target_id)
        target_to_subjects[target_id].append(subject_id)

print("✅ Target regions extracted")
print(f"\nTotal unique target regions: {len(all_target_regions)}")
print(f"Target region IDs: {sorted(all_target_regions)}")

In [None]:
# --- Summary: targets per subject ---

print("\n" + "="*60)
print("TARGET REGIONS PER SUBJECT")
print("="*60)
for subject_id, subject_name in enumerate(subjects_list):
    targets = subject_targets[subject_id]
    # Sorted plain ints: iterating a bare set gives no guaranteed order, and
    # numpy scalars would print as e.g. np.int64(5).
    unique_targets = sorted({int(t) for t in targets})
    print(f"Subject {subject_id:2d} ({subject_name:3s}): {unique_targets} (n_sessions={len(targets)})")

In [None]:
# --- Summary: participants per target region ---

banner = "=" * 60
print("\n" + banner)
print("PARTICIPANTS PER TARGET REGION")
print(banner)

# target region ID -> {'subject_ids', 'subject_names', 'n_subjects', 'n_sessions'}
target_region_summary = {}
for region in sorted(all_target_regions):
    session_subjects = target_to_subjects[region]  # one entry per session
    unique_ids = sorted(set(session_subjects))     # each participant listed once
    names = [subjects_list[idx] for idx in unique_ids]
    entry = dict(
        subject_ids=unique_ids,
        subject_names=names,
        n_subjects=len(unique_ids),
        n_sessions=len(session_subjects),
    )
    target_region_summary[region] = entry
    print(f"Target region {region:3d}: {entry['n_subjects']} subjects {names} ({entry['n_sessions']} sessions total)")

print(f"\nTotal groups: {len(target_region_summary)}")

In [None]:
# --- Visualization: Histogram of group sizes ---

group_sizes = [info['n_subjects'] for info in target_region_summary.values()]

if not group_sizes:
    # max() below would raise on an empty list
    print("No target regions found - nothing to plot.")
else:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Histogram: one bin per integer group size. Edges at k +/- 0.5 center each
    # bar on its integer tick instead of starting the bar at the tick.
    max_size = max(group_sizes)
    ax1.hist(group_sizes, bins=np.arange(0.5, max_size + 1.5), edgecolor='black', alpha=0.7, color='steelblue')
    ax1.set_xticks(range(1, max_size + 1))
    ax1.set_xlabel('Number of participants per target region')
    ax1.set_ylabel('Number of target regions')
    ax1.set_title('Distribution of Group Sizes')
    ax1.grid(axis='y', alpha=0.3)

    # Bar plot: participants per target region (categorical positions labelled by region ID)
    target_ids = sorted(target_region_summary.keys())
    n_subjects_per_target = [target_region_summary[tid]['n_subjects'] for tid in target_ids]

    # Blue: enough participants (>= 2) to form a group; red: singleton target.
    colors = ['steelblue' if n >= 2 else 'lightcoral' for n in n_subjects_per_target]
    ax2.bar(range(len(target_ids)), n_subjects_per_target, color=colors, edgecolor='black', alpha=0.7)
    ax2.set_xlabel('Target Region ID')
    ax2.set_ylabel('Number of participants')
    ax2.set_title('Participants per Target Region')
    ax2.set_xticks(range(len(target_ids)))
    ax2.set_xticklabels([str(tid) for tid in target_ids], rotation=45)
    ax2.grid(axis='y', alpha=0.3)
    ax2.axhline(y=2, color='red', linestyle='--', linewidth=1.5, alpha=0.5, label='Min for group')
    ax2.legend()

    plt.tight_layout()
    plt.show()

    print("\n💡 Note: Blue bars = targets with ≥2 participants (can form groups)")
    print("           Red bars = targets with 1 participant (singleton groups)")

In [None]:
# --- Statistical summary ---

group_sizes = np.array([info['n_subjects'] for info in target_region_summary.values()], dtype=int)
is_multi = group_sizes >= 2  # targets that can form a trainable group

print("\n" + "="*60)
print("STATISTICAL SUMMARY")
print("="*60)
print(f"Total target regions: {len(target_region_summary)}")
# mean/median/min/max are undefined (or raise) for an empty array
if group_sizes.size:
    print(f"Mean participants per group: {group_sizes.mean():.2f}")
    # .1f rather than .0f: the median of an even number of groups can be x.5
    print(f"Median participants per group: {np.median(group_sizes):.1f}")
    print(f"Min participants per group: {group_sizes.min()}")
    print(f"Max participants per group: {group_sizes.max()}")
print(f"\nGroups with ≥2 participants: {is_multi.sum()}")
print(f"Groups with 1 participant: {(group_sizes == 1).sum()}")
print(f"\nTotal groups to train: {is_multi.sum()}")

# A participant stimulated at several targets belongs to several groups, so the
# plain sum of group sizes counts them repeatedly; report distinct subjects too.
multi_group_subjects = {
    sid
    for info in target_region_summary.values()
    if info['n_subjects'] >= 2
    for sid in info['subject_ids']
}
print(f"Participants in multi-subject groups (counting repeats): {group_sizes[is_multi].sum()}")
print(f"Unique participants in multi-subject groups: {len(multi_group_subjects)}")

In [None]:
# --- Create grouped participant list for downstream analysis ---

# Keep only targets with ≥2 participants (ready for training); singleton
# targets cannot form a group.
multi_subject_groups = {
    target_id: info
    for target_id, info in target_region_summary.items()
    if info['n_subjects'] >= 2
}

print("\n🎯 Multi-subject groups ready for training:")
print("=" * 60)
for target_id in sorted(multi_subject_groups):
    info = multi_subject_groups[target_id]
    print(f"Target {target_id}: {info['n_subjects']} participants {info['subject_names']}")

print(f"\n✅ Total trainable groups: {len(multi_subject_groups)}")

In [None]:
# --- Optional: Save summary to pickle for downstream notebooks ---

SAVE_SUMMARY = True  # set to False to skip writing the pickle

summary = {
    'subjects_list': subjects_list,
    'subject_targets': subject_targets,
    'all_target_regions': sorted(all_target_regions),
    'target_region_summary': target_region_summary,
    'multi_subject_groups': multi_subject_groups,
}

summary_pkl = os.path.join(BASE, "TMS_fMRI", "target_regions_grouping_summary.pkl")
if SAVE_SUMMARY:
    with open(summary_pkl, "wb") as f:
        pickle.dump(summary, f)
    print(f"💾 Saved grouping summary to: {summary_pkl}")
else:
    print("Skipping save (SAVE_SUMMARY is False)")