# 02 - fMRI ROI Classification

Per-region classification using HCP atlas parcellation.

**Based on**: B_TMAZE_FMRI_STANDARD_CLASS notebooks

**Key Analysis:**
1. Load GLM beta values for each ROI
2. Classify reward vs. no-reward per ROI
3. Identify top discriminative regions
4. Network-level analysis

In [None]:
import sys
# Put the parent directory first on the module search path so the local
# analysis packages (core, classification, visualization, ...) are importable.
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt

# T-Maze analysis modules
from core.containers import TMAzefMRIData
# NOTE(review): `io` is also a standard-library module that CPython imports at
# interpreter startup, so `io.loaders` will most likely resolve against the
# stdlib `io` and fail ("'io' is not a package") despite the sys.path change
# above — confirm, and consider renaming the local package. TMazefMRILoader is
# only referenced in commented-out code below.
from io.loaders import TMazefMRILoader
from classification.classifiers import (
    TMazeClassifier,
    classify_roi,
    classify_all_rois,
    get_top_rois,
    permutation_test
)
from visualization.plots import plot_roi_accuracies, plot_confusion_matrix

## 1. Load fMRI Data

Load GLM beta values using HCP atlas parcellation.

In [None]:
# Example loader configuration — adapt the atlas and data paths to your setup.

# Collapse the four task conditions into a binary label:
# reward trials -> 1, no-reward trials -> 0.
condition_mapping = dict(
    MazeReward=1,
    NoMazeReward=1,
    MazeNoReward=0,
    NoMazeNoReward=0,
)

# Real-data path (disabled in this demo):
# loader = TMazefMRILoader(
#     atlas_path="/path/to/HCP_MMP1_atlas.nii.gz",
#     standardize=True,
#     condition_mapping=condition_mapping
# )
# fmri_data = loader.load_afni_betas("/path/to/glm_results/", subject_id="sub-01")

print("Loader configured. Ready to load fMRI data.")

## 2. Create Simulated Data (Demo)

In [None]:
# Simulate fMRI data for demo
np.random.seed(42)

n_trials = 120  # 30 trials x 4 conditions -> 60 trials per binary class
n_rois = 426    # HCP atlas

# Generate random data: one row per trial, one column per ROI
X = np.random.randn(n_trials, n_rois)

# Alternating no-reward (0) / reward (1) labels. Derived from n_trials so the
# label vector always has exactly one entry per row of X.
y = np.arange(n_trials) % 2

# ROIs that carry reward information, keyed by column index and mapped to the
# display name they get below — one table, so the signal indices and the
# renamed ROIs cannot drift apart.
signal_roi_names = {
    10: 'DMN_mPFC',
    25: 'SAL_ACC',
    50: 'FPN_dlPFC_L',
    100: 'VIS_V1',
    150: 'MOT_M1_L',
}
discriminative_rois = list(signal_roi_names)
reward_effect = 0.8  # mean shift added to reward trials in each signal ROI

# Add signal to the discriminative ROIs for the reward condition
for roi_idx in discriminative_rois:
    X[y == 1, roi_idx] += reward_effect

# Generic ROI names; the signal ROIs are renamed with a network prefix
roi_names = [f'ROI_{i:03d}' for i in range(n_rois)]
for roi_idx, region_name in signal_roi_names.items():
    roi_names[roi_idx] = region_name

# Create data container
fmri_data = TMAzefMRIData(
    data=X,
    labels=y,
    condition_names=['NoReward', 'Reward'],
    roi_names=roi_names,
    subject_id='demo',
    atlas_name='HCP_426'
)

print(fmri_data)

## 3. Single ROI Classification

In [None]:
# Classify using a single named ROI (one of the ROIs given injected reward
# signal in the simulated demo data)
roi_name = 'DMN_mPFC'
# NOTE(review): passed as returned by get_roi(), whereas the permutation-test
# cell reshapes get_roi() output to (-1, 1) — confirm classify_roi accepts it.
roi_data = fmri_data.get_roi(roi_name)

result = classify_roi(
    roi_data,
    fmri_data.labels,
    classifier_type='lda',
    cv=5
)

print(f"{roi_name}:")
print(f"  Accuracy: {result.accuracy:.1%} (+/- {result.accuracy_std:.1%})")
# result.auc may be missing (None). Test for None explicitly so a genuine AUC
# of 0.0 is still reported, and skip the line rather than print a blank one.
if result.auc is not None:
    print(f"  AUC: {result.auc:.3f}")

## 4. Classify All ROIs

In [None]:
# Run classification for each ROI
print("Classifying all ROIs...")
roi_results = classify_all_rois(
    fmri_data,
    classifier_type='lda',
    cv=5,
    verbose=True
)

print(f"\nCompleted classification for {len(roi_results)} ROIs")

In [None]:
# Get top performing ROIs: at most n_top=20, and min_accuracy=0.55 may
# filter out ROIs below that accuracy, so the list can be shorter than 20.
top_rois = get_top_rois(roi_results, n_top=20, min_accuracy=0.55)

# Report the actual number of ROIs returned rather than a fixed "20".
print(f"Top {len(top_rois)} Discriminative ROIs:")
print("=" * 50)
if not top_rois:
    print("No ROI reached the minimum accuracy; cells that use top_rois[0] will fail.")
# Distinct loop names so the earlier roi_name/result globals are not clobbered.
for rank, (name, roi_result) in enumerate(top_rois, start=1):
    print(f"{rank:2d}. {name:20s}: {roi_result.accuracy:.1%} (+/- {roi_result.accuracy_std:.1%})")

In [None]:
# Bar plot of the 20 best-scoring ROIs from roi_results
n_top_plot = 20
plot_roi_accuracies(roi_results, n_top=n_top_plot)
plt.show()

## 5. Compare Classifiers

In [None]:
# Compare four classifier types (LDA, SVM, logistic regression, random forest)
# on the single best-ranked ROI, using the same 5-fold CV for each.
top_roi_name = top_rois[0][0]
top_roi_data = fmri_data.get_roi(top_roi_name)

classifiers = ['lda', 'svm', 'logistic', 'rf']
clf_results = {}

for clf_type in classifiers:
    label = clf_type.upper()
    result = classify_roi(top_roi_data, fmri_data.labels, classifier_type=clf_type, cv=5)
    clf_results[label] = result
    print(f"{label:10s}: {result.accuracy:.1%} (+/- {result.accuracy_std:.1%})")

## 6. Permutation Testing

In [None]:
# Permutation test on the best-ranked ROI
top_roi_name = top_rois[0][0]
# reshape(-1, 1) turns the ROI values into a single-column 2-D feature matrix
top_roi_data = fmri_data.get_roi(top_roi_name).reshape(-1, 1)

# 100 permutations keeps the demo fast; use 1000 for publication
perm_kwargs = dict(classifier_type='lda', n_permutations=100, cv=5)
observed, p_value, null_dist = permutation_test(top_roi_data, fmri_data.labels, **perm_kwargs)

is_significant = p_value < 0.05
print(f"Permutation Test for {top_roi_name}:")
print(f"  Observed accuracy: {observed:.1%}")
print(f"  P-value: {p_value:.4f}")
print(f"  Significant (p<0.05): {is_significant}")

In [None]:
# Histogram of the permutation null distribution with reference lines
fig, ax = plt.subplots(figsize=(8, 5))

ax.hist(null_dist, bins=30, color='gray', alpha=0.7, label='Null distribution')

# Observed accuracy (red dashed) and the 0.5 chance line (black dotted)
reference_lines = [
    (observed, dict(color='red', linewidth=2, linestyle='--',
                    label=f'Observed: {observed:.1%}')),
    (0.5, dict(color='black', linestyle=':', label='Chance')),
]
for x_pos, line_style in reference_lines:
    ax.axvline(x_pos, **line_style)

ax.set(
    xlabel='Accuracy',
    ylabel='Count',
    title=f'Permutation Test: {top_roi_name} (p={p_value:.4f})',
)
ax.legend()

fig.tight_layout()
plt.show()

## 7. Network Analysis

In [None]:
# Summarize per-ROI accuracies by network. ROIs are assigned to a network by a
# '<NETWORK>_' name prefix (e.g. 'DMN_mPFC'); a prefix match avoids the false
# hits a plain substring test gives for names that merely contain the letters.
networks = ['DMN', 'SAL', 'FPN', 'VIS', 'MOT']

network_summary = {}
for network in networks:
    prefix = f'{network}_'
    # Iterate the ROIs that were actually classified, so this also works when
    # the demo-only `roi_names` list does not exist (e.g. real loaded data).
    network_accs = [roi_results[name].accuracy for name in roi_results
                    if name.startswith(prefix)]
    if not network_accs:
        continue
    network_summary[network] = {
        'mean': np.mean(network_accs),
        'max': np.max(network_accs),
        'n_rois': len(network_accs)
    }

print("Network Summary:")
print("=" * 50)
for network, stats in network_summary.items():
    print(f"{network:5s}: Mean={stats['mean']:.1%}, Max={stats['max']:.1%} (n={stats['n_rois']})")

## 8. Multi-ROI Classification

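Combine the top-ranked ROIs into a single multi-feature classifier. Because these ROIs were selected using the same trials and labels, the resulting accuracy is optimistically biased unless ROI selection is nested within cross-validation.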

In [None]:
# Use (up to) the 10 best-ranked ROIs together as one multi-feature input
top_10_names = [name for name, _ in top_rois[:10]]
X_top10 = fmri_data.get_rois(top_10_names)

# NOTE(review): these ROIs were ranked with all trials and the same labels
# (classify_all_rois on fmri_data), so this cross-validated accuracy is
# optimistically biased (circular ROI selection). For an unbiased estimate,
# do the ROI selection inside each training fold (nested cross-validation).
combined_result = classify_roi(
    X_top10,
    fmri_data.labels,
    classifier_type='svm',
    cv=5
)

# top_rois can hold fewer than 10 entries, so label with the real count.
print(f"Combined Top {len(top_10_names)} ROIs:")
print(f"  Accuracy: {combined_result.accuracy:.1%} (+/- {combined_result.accuracy_std:.1%})")

In [None]:
# Confusion matrix of the combined multi-ROI classifier. The title uses the
# actual number of ROIs combined, which can be fewer than 10.
plot_confusion_matrix(
    combined_result,
    class_names=['No Reward', 'Reward'],
    title=f'Top {len(top_10_names)} ROIs Combined'
)
plt.show()

## Summary

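This notebook demonstrated:
1. Loading fMRI ROI data with the HCP atlas (simulated data in this demo)
2. Per-ROI classification (LDA) and a classifier comparison (LDA, SVM, logistic regression, random forest)
3. Identifying top discriminative regions
4. Permutation testing for significance
5. Network-level summary of ROI accuracies
6. Multi-ROI classification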

### Key Findings Template:
- Top ROI: XXX with XX% accuracy
- Networks with highest decoding: DMN, SAL, FPN
- Combined top-10 accuracy: XX%

### Next Steps:
- **03_multimodal_fusion.ipynb**: Combine EEG and fMRI