# Basic Atlas-Based Classification

This notebook demonstrates the standard TESLearn workflow for binary classification using atlas-based features.

## Approach Overview

**Feature Extraction**: Atlas-based ROIs (Regions of Interest)  
**Feature Selection**: T-test (univariate statistical test)  
**Model**: Logistic Regression with L2 regularization  
**Validation**: Nested Cross-Validation

## Why This Works

Atlas-based features aggregate electric field intensities within anatomically defined brain regions. This:
- Reduces dimensionality (hundreds of ROIs instead of hundreds of thousands to millions of voxels)
- Provides anatomically interpretable features
- Keeps the feature count manageable for the small sample sizes common in TES studies

T-test feature selection keeps the ROI features whose values differ significantly between responders and non-responders.
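Univariate t-test selection scores each feature column independently of the others. The sketch below illustrates the idea with SciPy's `ttest_ind` on synthetic data. `ttest_select` is a hypothetical helper, not the `TTestSelector` implementation, and the choice of Welch's unequal-variance test is an assumption.

```python
import numpy as np
from scipy.stats import ttest_ind

def ttest_select(X, y, p_threshold=0.001):
    """Return (selected column indices, per-feature p-values).

    Hypothetical helper for illustration; not TESLearn's TTestSelector.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    # Welch's two-sample t-test (unequal variances), one test per feature column
    _, p_values = ttest_ind(X[y == 1], X[y == 0], axis=0, equal_var=False)
    return np.flatnonzero(p_values < p_threshold), p_values

# Synthetic example: 40 subjects, 6 features; only feature 0 differs between groups
rng = np.random.default_rng(0)
y = np.repeat([0, 1], 20)
X = rng.normal(size=(40, 6))
X[y == 1, 0] += 3.0
selected, p = ttest_select(X, y)
print(selected)
```

The p-values must be computed on training subjects only. Otherwise the selected ROIs leak information from the test subjects into the model.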

In [None]:
import teslearn as tl
from teslearn.data import load_dataset_from_csv, NiftiLoader
from teslearn.features import AtlasFeatureExtractor
from teslearn.models import LogisticRegressionModel
from teslearn.selection import TTestSelector

print(f"TESLearn version: {tl.__version__}")

## 1. Data Loading

**Best Practice**: Always validate your CSV format before loading. Required columns:
- `subject_id`: Unique identifier
- `simulation_name`: Stimulation configuration
- `response`: Binary target (0/1)

The `NiftiLoader` handles BIDS-style organization automatically.
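You can validate the CSV before loading with a standard-library-only pre-flight check for the three required columns. `validate_subjects_csv` is a hypothetical helper that does not call TESLearn. The duplicate check follows the "Unique identifier" description above, so relax it if your CSV lists several simulations per subject.

```python
import csv
import io

REQUIRED_COLUMNS = {"subject_id", "simulation_name", "response"}

def validate_subjects_csv(text):
    """Return a list of problems found in the CSV text (empty list if none).

    Hypothetical pre-flight check; it does not call TESLearn.
    """
    reader = csv.DictReader(io.StringIO(text))
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        return [f"missing columns: {sorted(missing)}"]

    problems, seen = [], set()
    for row_no, row in enumerate(reader, start=1):  # data rows, header excluded
        sid = (row["subject_id"] or "").strip()
        if not sid:
            problems.append(f"row {row_no}: empty subject_id")
        elif sid in seen:
            problems.append(f"row {row_no}: duplicate subject_id {sid!r}")
        seen.add(sid)
        response = (row["response"] or "").strip()
        if response not in {"0", "1"}:
            problems.append(f"row {row_no}: response must be 0 or 1, got {response!r}")
    return problems

sample = (
    "subject_id,simulation_name,response\n"
    "sub-01,montage_A,1\n"
    "sub-02,montage_A,0\n"
    "sub-02,montage_A,2\n"
)
for problem in validate_subjects_csv(sample):
    print(problem)
```

For a real file, pass the file contents: `with open('data/subjects.csv', newline='') as f: print(validate_subjects_csv(f.read()))`.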

In [None]:
# Load dataset metadata
dataset = load_dataset_from_csv(
    csv_path='data/subjects.csv',
    efield_base_dir='data/derivatives/efields/',
    target_col='response',
    task='classification'
)

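print(f"Dataset: {dataset.n_subjects} subjects")
targets = dataset.get_targets()
n_responders = int(sum(targets))
print(f"Responders: {n_responders}")
print(f"Non-responders: {len(targets) - n_responders}")

# Load actual NIfTI images
loader = NiftiLoader()
images, indices = loader.load_dataset_images(dataset)
y = targets

# Images and targets must stay aligned one-to-one before feature extraction
assert len(images) == len(y), f"{len(images)} images but {len(y)} targets; some subjects may have failed to load"

print(f"\nLoaded {len(images)} images")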

## 2. Feature Extraction

**Atlas Selection**: The Glasser HCP atlas provides 360 cortical parcels. For subcortical structures, consider Harvard-Oxford.

**Statistics**: Multiple statistics capture different aspects of the E-field distribution:
- `mean`: Average intensity (most stable)
- `max`: Peak intensity (catches focal hotspots)
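- `top10mean`: Mean of the top 10% of voxel values, i.e. those above the 90th percentile set by `top_percentile=90.0` (less sensitive to single-voxel outliers than `max`)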
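The three statistics can be illustrated in plain NumPy, assuming the E-field image and the atlas have already been resampled to the same voxel grid and that label 0 marks background. `roi_statistics` is a hypothetical name. `AtlasFeatureExtractor` may differ in details such as masking or percentile interpolation.

```python
import numpy as np

def roi_statistics(efield, labels, top_percentile=90.0):
    """Compute mean, max, and top-percentile mean of `efield` within each nonzero atlas label.

    Hypothetical sketch of statistics=['mean', 'max', 'top10mean']; not TESLearn's code.
    """
    efield = np.asarray(efield, dtype=float).ravel()
    labels = np.asarray(labels).ravel()
    roi_ids = np.unique(labels[labels != 0])  # label 0 is treated as background
    features = {}
    for roi in roi_ids:
        values = efield[labels == roi]
        cutoff = np.percentile(values, top_percentile)
        features[int(roi)] = {
            "mean": values.mean(),
            "max": values.max(),
            # mean of voxels at or above the 90th percentile, i.e. roughly the top 10%
            "top10mean": values[values >= cutoff].mean(),
        }
    return features

# Toy volume: one background voxel and two ROIs
feats = roi_statistics([9.0, 0.1, 0.2, 0.3, 1.0, 0.5, 0.5], [0, 1, 1, 1, 1, 2, 2])
print(feats[1])
```

With the 360 HCP-MMP1 parcels and three statistics, this layout gives 360 × 3 = 1080 features per subject, assuming every parcel overlaps the E-field volume.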

In [None]:
# Configure atlas-based feature extraction
extractor = AtlasFeatureExtractor(
    atlas_path='data/atlas/HCP-MMP1.nii.gz',
    statistics=['mean', 'max', 'top10mean'],
    top_percentile=90.0
)

# Extract features (this may take a moment)
X = extractor.fit_transform(images)

print(f"Feature matrix shape: {X.shape}")
print(f"Features per ROI: {len(extractor.statistics)}")
print(f"Total features: {X.shape[1]}")

## 3. Model Configuration

**Logistic Regression** is the default choice because it offers:
- Interpretable coefficients (feature importance)
- Reasonably calibrated probabilities (although `class_weight='balanced'` shifts them toward equal class priors)
- Fast training and prediction
- Natural support for L2 regularization

**Regularization**: `C=1.0` is a good starting point. Lower values of `C` mean stronger regularization, which helps with small samples.
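The effect of `C` can be checked directly by fitting plain scikit-learn `LogisticRegression` on synthetic data and watching the coefficient norm shrink as `C` decreases. Whether `LogisticRegressionModel` wraps scikit-learn is not confirmed here, so treat this as an illustration of L2 regularization in general.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic data standing in for an ROI feature matrix (not real E-field data)
rng = np.random.default_rng(42)
X = rng.normal(size=(60, 20))
y = (X[:, 0] + 0.5 * rng.normal(size=60) > 0).astype(int)

norms = {}
for C in [0.01, 0.1, 1.0, 10.0]:
    clf = LogisticRegression(C=C, solver="lbfgs", max_iter=1000)  # default penalty is L2
    clf.fit(X, y)
    # Smaller C -> stronger L2 penalty -> smaller coefficient vector
    norms[C] = np.linalg.norm(clf.coef_)
    print(f"C={C:<5} ||w||={norms[C]:.3f}")
```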

In [None]:
# Configure feature selection
selector = TTestSelector(
    p_threshold=0.001,  # Conservative threshold
    correction=None     # Can use 'bonferroni' or 'fdr' for multiple comparisons
)

# Configure logistic regression
model = LogisticRegressionModel(
    C=1.0,              # Inverse regularization strength
    penalty='l2',       # Ridge regularization
    solver='lbfgs',     # Efficient for small-medium datasets
    class_weight='balanced',  # Handle class imbalance
    max_iter=1000,
    random_state=42
)
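The `correction` options mentioned in the selector comment adjust for testing many ROI features at once. The sketch below implements the two standard procedures in NumPy: Bonferroni, which controls the family-wise error rate, and Benjamini-Hochberg, which controls the false discovery rate. Whether TESLearn's `'fdr'` option uses exactly Benjamini-Hochberg is an assumption.

```python
import numpy as np

def bonferroni_mask(p_values, alpha=0.05):
    """Reject H0 where p <= alpha / m (controls the family-wise error rate)."""
    p = np.asarray(p_values, dtype=float)
    return p <= alpha / p.size

def benjamini_hochberg_mask(p_values, alpha=0.05):
    """Benjamini-Hochberg step-up procedure (controls the false discovery rate)."""
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)
    # Compare the k-th smallest p-value with alpha * k / m
    below = p[order] <= alpha * np.arange(1, m + 1) / m
    mask = np.zeros(m, dtype=bool)
    if below.any():
        k = np.flatnonzero(below).max()  # largest rank satisfying the bound
        mask[order[: k + 1]] = True      # reject every hypothesis up to that rank
    return mask

p = np.array([0.0001, 0.004, 0.019, 0.03, 0.2, 0.6])
print(bonferroni_mask(p))
print(benjamini_hochberg_mask(p))
```

For example, with 1080 features (360 parcels × 3 statistics), Bonferroni at α = 0.05 corresponds to a per-feature threshold of about 4.6 × 10⁻⁵, which is much stricter than the uncorrected `p_threshold=0.001`.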

## 4. Training with Nested Cross-Validation

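**Why Nested CV?**
- Outer loop: Estimates performance on held-out subjects that never influence fitting or tuning
- Inner loop: Model selection (feature selection and hyperparameter choices) using only the outer-training data
- Because feature selection is fit only on training splits, no information leaks from the test subjects into the selected features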

**Best Practice**: Use stratified CV to maintain class proportions in each fold.
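`tl.train_model` handles nested CV internally. For readers who want to see the structure, here is a generic scikit-learn analogue on synthetic data; it is not TESLearn's implementation. `SelectKBest(f_classif)` stands in for the t-test selector, because for two classes the ANOVA F-statistic equals the squared pooled-variance t-statistic. It tunes a number of features `k` instead of a p-value threshold, which is an assumption made for simplicity.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for an ROI feature matrix (not real E-field data)
X, y = make_classification(n_samples=80, n_features=200, n_informative=5, random_state=0)

# Scaling and selection live inside the pipeline, so they are refit on each training split only
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("select", SelectKBest(f_classif)),
    ("clf", LogisticRegression(class_weight="balanced", max_iter=1000)),
])

# Inner loop: tune the number of selected features and C on the outer-training data
inner = GridSearchCV(
    pipe,
    param_grid={"select__k": [5, 20, 50], "clf__C": [0.1, 1.0]},
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=0),
    scoring="roc_auc",
)

# Outer loop: score the whole tuning procedure on held-out stratified folds
outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(inner, X, y, cv=outer, scoring="roc_auc")
print(f"Nested CV AUC: {scores.mean():.3f} ± {scores.std():.3f}")
```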

In [None]:
# Train with nested cross-validation
result = tl.train_model(
    images=images,
    y=y,
    feature_extractor=extractor,
    model=model,
    feature_selector=selector,
    outer_folds=5,      # 5-fold outer CV
    inner_folds=3,      # 3-fold inner CV for feature selection
    use_scaling=True    # Standardize features
)

# Display results
print(result.get_summary())

## 5. Model Interpretation

Understanding which brain regions drive predictions is crucial for TES studies. We use the `explain_model` function to:
- Extract feature importance from model coefficients
- Map back to atlas regions
- Generate weight maps for visualization
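A weight map amounts to painting each ROI's weight onto that ROI's voxels. The NumPy sketch below does this with a lookup table indexed by the integer atlas labels. `roi_weights_to_voxels` is hypothetical and assumes one weight per ROI, for example the coefficient of a single statistic such as `mean`. How `explain_model` actually builds its maps is not shown in this notebook.

```python
import numpy as np

def roi_weights_to_voxels(labels, roi_ids, weights):
    """Paint each ROI's weight onto its voxels; background and unlisted labels get 0.

    Hypothetical helper: assumes integer atlas labels and one weight per ROI.
    """
    labels = np.asarray(labels)
    size = int(max(labels.max(), max(roi_ids))) + 1
    lookup = np.zeros(size)
    lookup[np.asarray(roi_ids)] = weights  # label value -> weight
    return lookup[labels]                  # same shape as the atlas

# Toy 2x3 "atlas" with labels 1 and 2 and background 0
atlas = np.array([[0, 1, 1],
                  [2, 2, 0]])
print(roi_weights_to_voxels(atlas, roi_ids=[1, 2], weights=[0.8, -0.3]))
```

For a real atlas, the resulting array could be saved with nibabel as `nib.Nifti1Image(weight_map, atlas_img.affine)`, where `atlas_img = nib.load(...)`. With labels coded as 1 = responder, positive coefficients push predictions toward the responder class.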

In [None]:
from teslearn.explain import explain_model

# Generate model explanation
explanation = explain_model(
    pipeline=result.pipeline,
    atlas_path='data/atlas/HCP-MMP1.nii.gz',
    create_weight_maps=True,
    output_dir='./output/atlas_explanation'
)

print("\nTop 10 regions with the largest positive weights:")
for i, (region, importance) in enumerate(explanation.top_positive[:10], 1):
    print(f"{i:2d}. {region}: {importance:.4f}")

## 6. Visualization

Plot evaluation metrics to assess model performance.
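Beyond the ROC curve, a confusion matrix at a fixed threshold shows how errors are split between responders and non-responders. The sketch below uses toy out-of-fold labels and scores in place of pooled results such as the `all_y_true` / `all_y_score` attributes the plotting cell checks for. The 0.5 threshold assumes the scores are probabilities rather than decision-function values.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy out-of-fold labels and predicted probabilities (illustrative values only)
y_true = np.array([0, 0, 0, 1, 1, 1, 1, 0])
y_score = np.array([0.2, 0.6, 0.1, 0.8, 0.4, 0.9, 0.7, 0.3])

y_pred = (y_score >= 0.5).astype(int)  # 0.5 threshold is an assumption
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

sensitivity = tp / (tp + fn)  # fraction of responders correctly identified
specificity = tn / (tn + fp)  # fraction of non-responders correctly identified
print(f"TN={tn} FP={fp} FN={fn} TP={tp}")
print(f"Sensitivity={sensitivity:.2f}  Specificity={specificity:.2f}")
```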

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: ROC curve from pooled cross-validated scores, if the result exposes them
if hasattr(result, 'all_y_true') and hasattr(result, 'all_y_score'):
    fpr, tpr, _ = roc_curve(result.all_y_true, result.all_y_score)
    roc_auc = auc(fpr, tpr)

    axes[0].plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    axes[0].plot([0, 1], [0, 1], 'k--', lw=1, label='Chance')
    axes[0].set_xlabel('False Positive Rate')
    axes[0].set_ylabel('True Positive Rate')
    axes[0].set_title('ROC Curve')
    axes[0].legend(loc='lower right')
else:
    # Avoid leaving an unexplained empty panel
    axes[0].set_title('ROC Curve (cross-validated scores not available)')
    axes[0].axis('off')

# Right panel: regions with the largest positive weights
top_features = explanation.top_positive[:15]
axes[1].barh(range(len(top_features)), [imp for _, imp in top_features])
axes[1].set_yticks(range(len(top_features)))
axes[1].set_yticklabels([name for name, _ in top_features], fontsize=8)
axes[1].set_xlabel('Feature Importance')
axes[1].set_title('Top Predictive Regions')
axes[1].invert_yaxis()  # largest importance at the top

plt.tight_layout()
plt.show()

## Key Takeaways

1. **Atlas features** provide anatomically interpretable, low-dimensional representations
2. **T-test selection** keeps only ROI features that differ between groups, shrinking the feature set before modeling
3. **Nested CV** keeps feature selection and tuning out of the test folds, giving a nearly unbiased performance estimate
4. **Model interpretation** reveals which brain regions drive predictions

## Next Steps

- Try different atlases (AAL, Harvard-Oxford)
- Experiment with voxel-based features for finer spatial resolution
- Use permutation tests to check whether cross-validated performance is significantly above chance
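A label-permutation test reruns cross-validation many times with shuffled targets and compares the observed score against that null distribution. The sketch below uses scikit-learn's `permutation_test_score` on synthetic data with a simple scaler plus logistic regression pipeline, chosen for speed. For full rigor, the permuted object should be the entire pipeline, including feature selection.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, permutation_test_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (not real E-field features)
X, y = make_classification(n_samples=60, n_features=30, n_informative=3, random_state=0)
clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))

score, perm_scores, p_value = permutation_test_score(
    clf, X, y,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
    scoring="roc_auc",
    n_permutations=100,
    random_state=0,
)
# p-value = (number of permuted scores >= observed score + 1) / (n_permutations + 1)
print(f"AUC={score:.3f}, p={p_value:.4f}")
```

With 100 permutations, the smallest attainable p-value is 1/101 ≈ 0.0099, so use more permutations if you need finer resolution.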