# SRPBS Multi-Disorder MRI Dataset Exploration

This notebook explores the SRPBS (Strategic Research Program for Brain Sciences) dataset for normative modeling of dementia.

**Dataset**: SRPBS Multi-disorder MRI (Japanese)  
**Focus**: Healthy controls aged ≥45 years with T1w MRI  
**Data Type**: Cross-sectional  
**Study Design**: Multi-site, multi-disorder study across Japan

In [None]:
# Standard library
import os
import warnings
from pathlib import Path

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# NOTE(review): this hides *every* warning, including pandas/seaborn deprecation
# notices; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# All data paths in this notebook are relative ('data/SRPBS/...'), so the working
# directory must be the folder that contains `data/`. Step up to the parent only
# when the data folder is not already visible from here: an unconditional chdir
# climbs one more level every time this cell is re-run and breaks the paths below.
# (Presumably the notebook lives one level below the project root -- confirm.)
if not Path('data/SRPBS').is_dir():
    os.chdir(Path().absolute().parent)

## 1. Load SRPBS Participants Data

In [None]:
# Read the tab-separated SRPBS participant table and report its dimensions
print("Loading SRPBS participants data...")
participants = pd.read_csv('data/SRPBS/participants.tsv', sep='\t')

n_records, n_columns = participants.shape
column_names = participants.columns.tolist()

print(f"Total records: {n_records:,}")
print(f"Total columns: {n_columns}")
print(f"\nColumns: {column_names}")
print("\nFirst 5 rows:")
# Last expression of the cell -> rendered as a rich table
participants.head()

## 2. Dataset Structure Analysis

In [None]:
# Diagnosis distribution: print the code legend, then the per-code counts
diag_legend = {
    0: "Healthy Control",
    1: "Autistic Spectrum Disorders",
    2: "Major Depressive Disorder",
    3: "Obsessive Compulsive Disorder",
    4: "Schizophrenia",
    5: "Pain",
    6: "Stroke",
    7: "Bipolar Disorder",
    8: "Dysthymia",
    99: "Others",
}

print("Diagnosis distribution:")
print("Legend:")
for code, label in diag_legend.items():
    print(f"  {code}: {label}")
print()

# Counts ordered by diagnosis code; percentages are relative to all records
n_participants = len(participants)
diag_counts = participants['diag'].value_counts().sort_index()
for diag, count in diag_counts.items():
    print(f"  {diag}: {count:4d} ({count/n_participants*100:.1f}%)")

# Keep an independent copy of the healthy-control rows (diag == 0)
is_healthy = participants['diag'].eq(0)
healthy = participants.loc[is_healthy].copy()
print(f"\n>>> Total Healthy Controls: {len(healthy):,} <<<")

In [None]:
# Site distribution for healthy controls: legend of site codes, then counts
site_names = {
    'SWA': 'Showa University',
    'HUH': 'Hiroshima University Hospital',
    'HRC': 'Hiroshima Rehabilitation Center',
    'HKH': 'Hiroshima Kajikawa Hospital',
    'COI': 'Hiroshima COI',
    'KUT': 'Kyoto University TimTrio',
    'KTT': 'Kyoto University Trio',
    'UTO': 'University of Tokyo Hospital',
    'ATT': 'ATR Trio',
    'ATV': 'ATR Verio',
    'CIN': 'CiNet',
    'NKN': 'Nishinomiya Kyouritsu Hospital',
}

print("Site distribution (Healthy Controls only):")
print("Site codes:")
for code, name in site_names.items():
    print(f"  {code}: {name}")
print()

# Counts come back sorted by frequency (most common site first)
n_healthy = len(healthy)
site_counts = healthy['site'].value_counts()
for site, count in site_counts.items():
    print(f"  {site}: {count:3d} ({count/n_healthy*100:.1f}%)")

n_sites = healthy['site'].nunique()
print(f"\nTotal sites: {n_sites}")

## 3. Demographics of Healthy Controls

In [None]:
# Age summary over all healthy controls
n_healthy = len(healthy)
ages_all = healthy['age']
age_q1 = ages_all.quantile(0.25)
age_q3 = ages_all.quantile(0.75)

print("Age distribution (all healthy controls):")
print(f"  Range: {ages_all.min():.0f} - {ages_all.max():.0f} years")
print(f"  Mean: {ages_all.mean():.1f} ± {ages_all.std():.1f}")
print(f"  Median: {ages_all.median():.0f}")
print(f"  Q1-Q3: {age_q1:.0f} - {age_q3:.0f}")
print(f"  Missing: {ages_all.isna().sum()}")

# Sex: recode the numeric column to letters (1 -> 'M', 2 -> 'F'); any other
# value maps to NaN and is left out of the counts below
print("\nSex distribution:")
sex_map = {1: 'M', 2: 'F'}
healthy['sex_mapped'] = healthy['sex'].map(sex_map)
sex_counts = healthy['sex_mapped'].value_counts()
for sex_label, n_sex in sex_counts.items():
    print(f"  {sex_label}: {n_sex} ({n_sex/n_healthy*100:.1f}%)")
print(f"  Missing: {healthy['sex'].isna().sum()}")

# Handedness: same recoding approach (1 -> 'R', 2 -> 'L')
print("\nHandedness distribution:")
hand_map = {1: 'R', 2: 'L'}
healthy['hand_mapped'] = healthy['hand'].map(hand_map)
hand_counts = healthy['hand_mapped'].value_counts()
for hand_label, n_hand in hand_counts.items():
    print(f"  {hand_label}: {n_hand} ({n_hand/n_healthy*100:.1f}%)")
print(f"  Missing: {healthy['hand'].isna().sum()}")

## 4. Filter for Age ≥45 Years

In [None]:
# Filter for age >= 45. Rows whose age is missing fail the comparison and are
# therefore excluded as well.
age_45plus = healthy[healthy['age'] >= 45].copy()
n_45plus = len(age_45plus)

print("Filtering criteria:")
print(f"  Diagnosis = 0 (Healthy):    {len(healthy):,} subjects")
print(f"  Age ≥45:                    {n_45plus:,} subjects")
print()
print(f">>> TOTAL HEALTHY CONTROLS (AGE ≥45): {n_45plus:,} subjects <<<")
print()

# Age distribution after filtering
print("Age distribution (healthy controls ≥45):")
print(f"  Range: {age_45plus['age'].min():.0f} - {age_45plus['age'].max():.0f} years")
print(f"  Mean: {age_45plus['age'].mean():.1f} ± {age_45plus['age'].std():.1f}")
print(f"  Median: {age_45plus['age'].median():.0f}")

# Age decade distribution.
# Bins are left-closed (right=False): [45, 50), [50, 60), [60, 70), [70, 80), [80, inf).
# The last edge is infinite so that '80+' really covers every age from 80 upward;
# with a finite upper edge (previously 100) anyone aged >= 100 would fall outside
# all bins, become NaN and silently vanish from the counts.
print("\nAge decade distribution:")
age_bins = [45, 50, 60, 70, 80, np.inf]
age_labels = ['45-49', '50-59', '60-69', '70-79', '80+']
age_45plus['age_decade'] = pd.cut(age_45plus['age'], bins=age_bins, labels=age_labels, right=False)
# sort_index orders the categorical by bin order rather than by frequency
age_decade_counts = age_45plus['age_decade'].value_counts().sort_index()
for decade, count in age_decade_counts.items():
    print(f"  {decade}: {count:3d} ({count/n_45plus*100:.1f}%)")

## 5. Demographics of Age ≥45 Healthy Controls

In [None]:
# Percentages in this cell are all relative to the size of the ≥45 cohort
n_45plus = len(age_45plus)

# Sex
print("Sex distribution (age ≥45):")
sex_counts_45 = age_45plus['sex_mapped'].value_counts()
for sex_label, n_sex in sex_counts_45.items():
    print(f"  {sex_label}: {n_sex} ({n_sex/n_45plus*100:.1f}%)")

# Site (most frequent first)
print("\nSite distribution (age ≥45):")
site_counts_45 = age_45plus['site'].value_counts()
for site_code, n_site in site_counts_45.items():
    print(f"  {site_code}: {n_site:3d} ({n_site/n_45plus*100:.1f}%)")

# Handedness, plus the number of rows with no raw handedness value
print("\nHandedness distribution (age ≥45):")
hand_counts_45 = age_45plus['hand_mapped'].value_counts()
for hand_label, n_hand in hand_counts_45.items():
    print(f"  {hand_label}: {n_hand} ({n_hand/n_45plus*100:.1f}%)")
n_hand_missing = age_45plus['hand'].isna().sum()
print(f"  Missing: {n_hand_missing}")

## 6. Visualizations

In [None]:
# Two-panel figure: age histogram (left) and age-by-sex box plot (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_hist, ax_box = axes

ages_45 = age_45plus['age']
mean_age = ages_45.mean()
median_age = ages_45.median()

# Left panel: histogram with dashed reference lines at the mean and median
ax_hist.hist(ages_45, bins=20, edgecolor='black', alpha=0.7, color='steelblue')
ax_hist.axvline(mean_age, color='red', linestyle='--', linewidth=2,
                label=f'Mean: {mean_age:.1f}')
ax_hist.axvline(median_age, color='orange', linestyle='--', linewidth=2,
                label=f'Median: {median_age:.0f}')
ax_hist.set_xlabel('Age (years)', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)
ax_hist.set_title(f'Age Distribution (N={len(age_45plus)})', fontsize=14, fontweight='bold')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# Right panel: one box per sex, coloured through a fixed palette
sex_palette = {'M': '#3498db', 'F': '#e74c3c'}
sns.boxplot(data=age_45plus, x='sex_mapped', y='age', palette=sex_palette, ax=ax_box)
ax_box.set_xlabel('Sex', fontsize=12)
ax_box.set_ylabel('Age (years)', fontsize=12)
ax_box.set_title('Age Distribution by Sex', fontsize=14, fontweight='bold')
ax_box.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

In [None]:
# Bar chart of subject counts per age decade, with the count written above each bar
fig, ax = plt.subplots(figsize=(10, 6))
age_decade_counts.plot(kind='bar', color='steelblue', edgecolor='black', ax=ax)

ax.set_xlabel('Age Decade', fontsize=12)
ax.set_ylabel('Number of subjects', fontsize=12)
ax.set_title('Distribution by Age Decade', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
plt.xticks(rotation=0)

# Bars sit at integer x positions 0..k-1; place each label 2 units above its bar
decade_values = age_decade_counts.values
for bar_pos, n_subjects in zip(range(len(decade_values)), decade_values):
    ax.text(bar_pos, n_subjects + 2, str(n_subjects), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Two-panel figure: subjects per site (left, at most 10 sites) and per sex (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
ax_site, ax_sex = axes

# Left panel: site_counts_45 is already sorted by frequency, so head(10) = top 10
site_counts_top = site_counts_45.head(10)
site_positions = range(len(site_counts_top))
ax_site.bar(site_positions, site_counts_top.values, color='steelblue', edgecolor='black')
ax_site.set_xticks(site_positions)
ax_site.set_xticklabels(site_counts_top.index, rotation=45, ha='right')
ax_site.set_ylabel('Number of subjects', fontsize=12)
ax_site.set_title('Distribution by Site (Top 10)', fontsize=14, fontweight='bold')
ax_site.grid(True, alpha=0.3, axis='y')
for bar_pos, n_subjects in zip(site_positions, site_counts_top.values):
    ax_site.text(bar_pos, n_subjects + 2, str(n_subjects), ha='center', fontweight='bold', fontsize=10)

# Right panel: one bar per sex; labels missing from the palette fall back to gray
colors = [sex_palette.get(sex, 'gray') for sex in sex_counts_45.index]
sex_positions = range(len(sex_counts_45))
ax_sex.bar(sex_positions, sex_counts_45.values, color=colors, edgecolor='black')
ax_sex.set_xticks(sex_positions)
ax_sex.set_xticklabels(sex_counts_45.index, rotation=0)
ax_sex.set_ylabel('Number of subjects', fontsize=12)
ax_sex.set_title('Distribution by Sex', fontsize=14, fontweight='bold')
ax_sex.grid(True, alpha=0.3, axis='y')
for bar_pos, n_subjects in zip(sex_positions, sex_counts_45.values):
    ax_sex.text(bar_pos, n_subjects + 2, str(n_subjects), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Violin plot of age for the five sites with the most ≥45 healthy controls
top5_sites = site_counts_45.head(5).index
in_top5 = age_45plus['site'].isin(top5_sites)
age_45plus_top5 = age_45plus.loc[in_top5]

fig, ax = plt.subplots(figsize=(12, 6))
sns.violinplot(data=age_45plus_top5, x='site', y='age', ax=ax, color='lightblue')

ax.set_xlabel('Site', fontsize=12)
ax.set_ylabel('Age (years)', fontsize=12)
ax.set_title('Age Distribution by Site (Top 5 Sites)', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## 7. Quality Control Metrics

SRPBS provides MRIQC quality metrics for T1w images.

In [None]:
# Load MRIQC T1w quality metrics and attach them to the ≥45 healthy-control cohort.
# Only the file read sits inside `try`, so a missing file is reported cleanly
# while any other problem still surfaces as a normal error.
try:
    qc_t1w = pd.read_csv('data/SRPBS/group_T1w.tsv', sep='\t')
except FileNotFoundError:
    print("QC metrics file (group_T1w.tsv) not found in data/SRPBS/")
    print("This file contains MRIQC quality metrics for T1w images.")
else:
    print(f"T1w QC metrics available for {len(qc_t1w):,} subjects")
    print(f"\nQC metrics columns: {list(qc_t1w.columns)}")

    # NOTE(review): MRIQC group tables are normally keyed by `bids_name`
    # (e.g. 'sub-0001_T1w') rather than `participant_id`. If that is the case
    # here, derive the join key from the leading 'sub-<label>' entity -- confirm
    # against the actual file.
    if 'participant_id' not in qc_t1w.columns and 'bids_name' in qc_t1w.columns:
        qc_t1w['participant_id'] = (
            qc_t1w['bids_name'].astype(str).str.extract(r'^(sub-[^_]+)', expand=False)
        )

    if 'participant_id' not in qc_t1w.columns:
        # Without a key the merge would fail with an opaque KeyError
        print("\nCannot link QC metrics: group_T1w.tsv has no 'participant_id' or 'bids_name' column")
    else:
        # Merge with our filtered dataset (left join keeps every cohort subject)
        age_45plus_qc = age_45plus.merge(qc_t1w, on='participant_id', how='left')

        # Count cohort subjects that appear in the QC table. Key membership does
        # not depend on any particular metric column being present and is not
        # inflated when the QC table has several rows per subject.
        has_qc = age_45plus['participant_id'].isin(qc_t1w['participant_id'])
        print(f"\nHealthy controls (age ≥45) with QC metrics: {has_qc.sum():,}")

        # Display key QC metrics; metrics absent from the table are skipped
        print("\nKey QC metrics (for subjects with data):")
        qc_metrics = ['cnr', 'snr_total', 'efc', 'fber', 'fwhm_avg']
        for metric in qc_metrics:
            if metric in age_45plus_qc.columns:
                vals = age_45plus_qc[metric].dropna()
                if len(vals) > 0:
                    print(f"  {metric}: {vals.mean():.2f} ± {vals.std():.2f} (range: {vals.min():.2f} - {vals.max():.2f})")

        # Brain tissue fractions, printed as percentages (values are scaled by 100)
        print("\nBrain tissue fractions:")
        tissue_metrics = ['icvs_gm', 'icvs_wm', 'icvs_csf']
        for metric in tissue_metrics:
            if metric in age_45plus_qc.columns:
                vals = age_45plus_qc[metric].dropna()
                if len(vals) > 0:
                    # 'icvs_gm' -> 'GM'
                    tissue = metric.split('_')[1].upper()
                    print(f"  {tissue}: {vals.mean()*100:.1f}% ± {vals.std()*100:.1f}%")

## 8. Export Healthy Controls Dataset

In [None]:
# Build the export table: pick the needed columns, give them their export
# names, tag every row with the dataset name, and order rows by subject_id
export_columns = {
    'participant_id': 'subject_id',
    'age': 'age',
    'sex_mapped': 'sex',
    'site': 'site',
    'protocol': 'protocol',
    'hand_mapped': 'handedness',
}
output_df = (
    age_45plus[list(export_columns)]
    .rename(columns=export_columns)
    .assign(dataset='SRPBS')
    .sort_values('subject_id')
    .reset_index(drop=True)
)

# Write the table next to the source data (no index column)
output_path = 'data/SRPBS/srpbs_healthy_controls_age45plus.csv'
output_df.to_csv(output_path, index=False)

n_exported = len(output_df)
sex_breakdown = output_df['sex'].value_counts().to_dict()

print(f"✓ Exported {n_exported} healthy controls to: {output_path}")
print("\nDataset summary:")
print(f"  Total subjects: {n_exported}")
print(f"  Age range: {output_df['age'].min():.0f} - {output_df['age'].max():.0f} years")
print(f"  Sex distribution: {sex_breakdown}")
print(f"  Number of sites: {output_df['site'].nunique()}")
print("\nFirst 10 rows:")
print(output_df.head(10))

## 9. Summary Statistics

In [None]:
# Final text summary of the exported cohort (`output_df`).
banner = "=" * 80
n_total = len(output_df)
n_female = (output_df['sex'] == 'F').sum()
n_male = (output_df['sex'] == 'M').sum()
ages_out = output_df['age']

print(banner)
print("SRPBS DATASET SUMMARY FOR NORMATIVE MODELING")
print(banner)
print()
print(f"Total healthy controls (age ≥45, diag=0, with T1w MRI): {n_total}")
print()
print("Demographics:")
print(f"  Age: {ages_out.mean():.1f} ± {ages_out.std():.1f} years (range: {ages_out.min():.0f}-{ages_out.max():.0f})")
print(f"  Sex: {n_female} F / {n_male} M ({n_female/n_total*100:.1f}% female)")
print()
print("Age decade breakdown:")
# Left-closed bins: [45, 50), [50, 60), [60, 70), [70, 80), [80, inf).
# The last edge is infinite so that '80+' covers every age from 80 upward; a
# finite upper edge (previously 100) would leave anyone aged >= 100 unbinned
# and missing from this breakdown.
age_bins = [45, 50, 60, 70, 80, np.inf]
age_labels = ['45-49', '50-59', '60-69', '70-79', '80+']
output_df['age_decade'] = pd.cut(output_df['age'], bins=age_bins, labels=age_labels, right=False)
for decade in age_labels:
    count = (output_df['age_decade'] == decade).sum()
    print(f"  {decade}: {count:3d} ({count/n_total*100:.1f}%)")
print()
print("Multi-site distribution:")
print(f"  Total sites: {output_df['site'].nunique()}")
print("  Top 5 sites:")
# value_counts is sorted by frequency, so head(5) gives the five largest sites
for site, count in output_df['site'].value_counts().head(5).items():
    print(f"    {site}: {count} ({count/n_total*100:.1f}%)")
print()
print("Geographic location:")
print("  All subjects from Japan (multiple sites across Japan)")
print()
# NOTE(review): the statements below are fixed text; nothing in this notebook
# checks per-subject T1w availability -- confirm before relying on them.
print("Data quality:")
print("  All subjects have T1w MRI imaging")
print("  MRIQC quality metrics available for quality control")
print("  Cross-sectional study design")
print()
print("Notes:")
print("  - SRPBS is a multi-site, multi-disorder Japanese dataset")
print("  - Dataset includes resting-state fMRI and defaced T1w images")
print("  - Part of AMED DecNef Project with unified imaging protocol")
print("  - For normative modeling, consider site harmonization (e.g., ComBat)")
print("  - Important source of Asian population diversity for normative modeling")
print()
print(banner)