# Regression: Predicting Continuous Outcomes

This notebook demonstrates how to use TESLearn for **continuous outcomes** (e.g., symptom improvement scores, cognitive enhancement) rather than binary classification.

## Approach Overview

**Feature Extraction**: Atlas-based ROIs  
**Feature Selection**: F-regression (for continuous targets)  
**Model**: Elastic Net (combines Ridge and Lasso)  
**Validation**: K-Fold Cross-Validation

## When to Use Regression

- **Symptom severity**: Predicting degree of improvement (not just responder vs. non-responder)
- **Cognitive enhancement**: Memory, attention, executive function scores
- **Dose-response**: Relationship between E-field intensity and clinical outcome
- **Personalized dosing**: Predicting optimal stimulation parameters

## Why Elastic Net?

Elastic Net combines the benefits of Ridge (L2) and Lasso (L1) regularization:
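- **Feature selection**: The L1 component sets the coefficients of uninformative features to exactly zero
- **Stability**: Within a group of correlated features, pure Lasso tends to keep one and drop the rest somewhat arbitrarily; the L2 component spreads weight across the group instead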
- **Interpretability**: Sparse solutions with selected features

In [None]:
import teslearn as tl
from teslearn.data import load_dataset_from_csv, NiftiLoader
from teslearn.features import AtlasFeatureExtractor
from teslearn.models import ElasticNetModel
from teslearn.selection import FRegressionSelector
from teslearn.cv import KFoldValidator

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

## 1. Data Loading

**Key difference**: Set `task='regression'` to specify continuous targets. The CSV should have a continuous column for the target variable.
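
Before loading images, it can help to confirm that the target column really is numeric and continuous. The sketch below uses only the standard library. It assumes a hypothetical CSV layout: `improvement_score` comes from this notebook, while `subject_id` and `efield_path` are illustrative column names, not necessarily what `load_dataset_from_csv` expects.

```python
import csv
import io

# Hypothetical CSV layout: only `improvement_score` comes from this notebook;
# `subject_id` and `efield_path` are illustrative column names.
CSV_TEXT = """subject_id,efield_path,improvement_score
sub-01,data/efields/sub-01.nii.gz,12.5
sub-02,data/efields/sub-02.nii.gz,-3.0
sub-03,data/efields/sub-03.nii.gz,7.25
"""

def read_continuous_target(csv_file, target_col):
    """Return the target column as floats, failing loudly on non-numeric cells."""
    values = []
    # start=2 because row 1 of the file is the header
    for row_num, row in enumerate(csv.DictReader(csv_file), start=2):
        raw = row[target_col].strip()
        try:
            values.append(float(raw))
        except ValueError:
            raise ValueError(f"Row {row_num}: {target_col}={raw!r} is not numeric")
    return values

targets = read_continuous_target(io.StringIO(CSV_TEXT), "improvement_score")
print(targets)  # [12.5, -3.0, 7.25]
if len(set(targets)) <= 2:
    # Two distinct values means the target is effectively binary
    print("Warning: target looks binary; consider classification instead")
```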

In [None]:
# Load dataset with continuous targets
dataset = load_dataset_from_csv(
    csv_path='data/subjects_regression.csv',
    target_col='improvement_score',  # Continuous variable
    task='regression'
)

loader = NiftiLoader()
images, indices = loader.load_dataset_images(dataset)
y = dataset.get_targets()

print(f"Dataset: {len(images)} subjects")
print(f"\nTarget distribution:")
print(f"  Mean: {y.mean():.2f}")
print(f"  Std: {y.std():.2f}")
print(f"  Range: [{y.min():.2f}, {y.max():.2f}]")

## 2. Exploratory Data Analysis

Always visualize the target distribution before modeling. Check for:
- Skewness and normality (a strongly skewed target may warrant a transformation; normality matters mainly for parametric tests)
- Outliers (may require robust methods)
- Range (consider standardization)
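
The plots below show the distribution visually. Two of these checks can also be made numerically. The sketch below computes the moment-based sample skewness and flags outliers with Tukey's 1.5 × IQR rule. `describe_target` is a hypothetical helper, not part of TESLearn, and the 1.5 factor is a conventional choice rather than a fixed rule.

```python
import numpy as np

def describe_target(y, iqr_factor=1.5):
    """Summarize skewness and flag outliers with Tukey's IQR rule."""
    y = np.asarray(y, dtype=float)
    centered = y - y.mean()
    # Moment-based sample skewness: 0 for a symmetric sample, > 0 for a right tail
    skew = np.mean(centered**3) / np.mean(centered**2) ** 1.5
    q1, q3 = np.percentile(y, [25, 75])
    iqr = q3 - q1
    lo, hi = q1 - iqr_factor * iqr, q3 + iqr_factor * iqr
    outliers = np.flatnonzero((y < lo) | (y > hi))
    return {"skewness": skew, "bounds": (lo, hi), "outlier_idx": outliers}

# Example with one obvious outlier (index 6)
scores = np.array([4.0, 5.0, 5.5, 6.0, 6.5, 7.0, 30.0])
summary = describe_target(scores)
print(summary["outlier_idx"])   # [6]
print(f"skewness = {summary['skewness']:.2f}")
```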

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Histogram
axes[0].hist(y, bins=15, edgecolor='black', alpha=0.7)
axes[0].axvline(y.mean(), color='red', linestyle='--', label=f'Mean: {y.mean():.2f}')
axes[0].set_xlabel('Improvement Score')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Target Distribution')
axes[0].legend()

# Q-Q plot for normality
stats.probplot(y, dist="norm", plot=axes[1])
axes[1].set_title('Q-Q Plot (Normality Check)')

plt.tight_layout()
plt.show()

# Shapiro-Wilk test for normality
shapiro_stat, shapiro_p = stats.shapiro(y)
print(f"\nShapiro-Wilk test: W={shapiro_stat:.4f}, p={shapiro_p:.4f}")
# p > 0.05 means normality was not rejected; it does not prove the data are normal
if shapiro_p > 0.05:
    print("No evidence against normality (α=0.05)")
else:
    print("Normality rejected (α=0.05)")

## 3. Feature Extraction

Same as classification, but we'll use different selection and modeling approaches appropriate for continuous targets.

In [None]:
# Extract atlas-based features
extractor = AtlasFeatureExtractor(
    atlas_path='data/atlas/HCP-MMP1.nii.gz',
    statistics=['mean', 'std'],  # Standard deviation captures variability
    top_percentile=90.0
)

X = extractor.fit_transform(images)
print(f"Feature matrix: {X.shape}")
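
Conceptually, each atlas feature is a summary statistic of the E-field inside one labeled region. The sketch below computes per-ROI mean and standard deviation with numpy. It assumes the E-field and atlas are already on the same voxel grid, as integer-labeled arrays with 0 as background. It does **not** reproduce `top_percentile`, whose exact behavior in `AtlasFeatureExtractor` is not documented here. `roi_statistics` is a hypothetical helper.

```python
import numpy as np

def roi_statistics(efield, atlas, labels):
    """Mean and standard deviation of `efield` within each atlas label.

    efield: 3D array of field magnitudes; atlas: 3D integer label array on the
    same grid (0 = background). Returns an array of shape (len(labels), 2).
    """
    efield = np.asarray(efield, dtype=float)
    atlas = np.asarray(atlas)
    if efield.shape != atlas.shape:
        raise ValueError("E-field and atlas must be on the same voxel grid")
    feats = np.full((len(labels), 2), np.nan)
    for i, label in enumerate(labels):
        voxels = efield[atlas == label]
        if voxels.size:  # ROIs with no voxels stay NaN
            feats[i] = voxels.mean(), voxels.std()
    return feats

# Toy volume: label 1 has a constant field, label 2 a graded one
atlas = np.zeros((4, 4, 4), dtype=int)
atlas[:2] = 1
atlas[2:] = 2
efield = np.zeros((4, 4, 4))
efield[:2] = 0.1
efield[2:] = np.linspace(0.2, 0.4, 32).reshape(2, 4, 4)
feats = roi_statistics(efield, atlas, labels=[1, 2, 3])
print(feats)
```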

## 4. Feature Selection for Regression

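**F-regression** regresses the target on each feature separately and scores every feature by the F-statistic of that univariate fit. With the settings below, features are kept when their FDR-corrected p-values fall below `p_threshold`.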

**Why not T-test?** T-tests compare group means (binary), while F-regression tests linear relationships (continuous).
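
For one feature plus an intercept, the F-statistic depends only on the Pearson correlation r between that feature and the target: F = r² / (1 − r²) · (n − 2), with (1, n − 2) degrees of freedom. scikit-learn's `f_regression` computes the same statistic. The sketch below is a minimal numpy/scipy version. `f_regression_scores` is a hypothetical name and may differ in details from `FRegressionSelector`. With a 0/1 target this F equals the square of the equal-variance two-sample t statistic, so the t-test is the binary special case.

```python
import numpy as np
from scipy import stats

def f_regression_scores(X, y):
    """Univariate F-test of each column of X against a continuous target y.

    For a single feature plus intercept, F = r**2 / (1 - r**2) * (n - 2),
    where r is the Pearson correlation between the feature and y.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n = y.shape[0]
    Xc = X - X.mean(axis=0)
    yc = y - y.mean()
    r = (Xc.T @ yc) / (np.linalg.norm(Xc, axis=0) * np.linalg.norm(yc))
    F = r**2 / (1.0 - r**2) * (n - 2)
    p = stats.f.sf(F, 1, n - 2)  # upper tail of F(1, n - 2)
    return F, p

rng = np.random.default_rng(0)
n = 40
X = rng.normal(size=(n, 3))
y = 2.0 * X[:, 0] + rng.normal(size=n)  # only feature 0 is informative
F, p = f_regression_scores(X, y)
print(np.round(F, 2), np.round(p, 4))
```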

In [None]:
# F-regression for continuous targets
selector = FRegressionSelector(
    p_threshold=0.01,      # Less conservative than classification
    correction='fdr'       # False Discovery Rate control
)

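# Inspect selection on the full dataset (exploration only). Features chosen
# this way must not be used to score the model; passing the selector to
# train_model (Section 6) lets selection be repeated inside each training fold.
X_selected = selector.fit_transform(X, y)
print(f"Selected {X_selected.shape[1]} / {X.shape[1]} features")
print(f"Reduction: {100 * (1 - X_selected.shape[1] / X.shape[1]):.1f}%")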
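
`correction='fdr'` most likely refers to the Benjamini-Hochberg procedure, though that is an assumption about `FRegressionSelector`. It is also unclear exactly how the selector combines `p_threshold` with the correction. The sketch below treats `p_threshold` as the target false discovery rate q. It sorts the m p-values, finds the largest rank k with p₍ₖ₎ ≤ k·q/m, and keeps the k smallest. `benjamini_hochberg` is a hypothetical helper.

```python
import numpy as np

def benjamini_hochberg(pvals, q=0.01):
    """Boolean mask of hypotheses rejected by the Benjamini-Hochberg procedure."""
    p = np.asarray(pvals, dtype=float)
    m = p.size
    order = np.argsort(p)
    # Compare the k-th smallest p-value with its BH threshold k * q / m
    below = p[order] <= q * np.arange(1, m + 1) / m
    mask = np.zeros(m, dtype=bool)
    if below.any():
        k = np.flatnonzero(below).max()  # largest rank meeting its threshold
        mask[order[: k + 1]] = True      # reject it and every smaller p-value
    return mask

pvals = [0.001, 0.008, 0.012, 0.035, 0.3]
print(benjamini_hochberg(pvals, q=0.05))  # [ True  True  True  True False]
print(benjamini_hochberg(pvals, q=0.01))  # [ True False False False False]
```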

## 5. Elastic Net Configuration

**Elastic Net mixing**: `l1_ratio` controls the balance between the two penalties:
- `l1_ratio=1.0`: Pure Lasso (sparse solutions)
- `l1_ratio=0.0`: Pure Ridge (all features, small coefficients)
- `l1_ratio=0.5`: Balanced mix (recommended starting point)

**Alpha**: Overall regularization strength (higher = more regularization)
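
As a reference point, the objective below is scikit-learn's Elastic Net objective. That `ElasticNetModel` uses the same scaling of α and ρ = `l1_ratio` is an assumption. Under it, the fitted intercept b and coefficients w minimize:

$$
\min_{w,\,b}\;\frac{1}{2n}\,\lVert y - Xw - b\mathbf{1}\rVert_2^2 \;+\; \alpha\rho\,\lVert w\rVert_1 \;+\; \frac{\alpha(1-\rho)}{2}\,\lVert w\rVert_2^2
$$

With ρ = 1 only the L1 (Lasso) term remains, and with ρ = 0 only the L2 (Ridge) term. The penalty acts on raw coefficient sizes, so α treats features evenly only when they share a scale. This is one reason to standardize features before fitting.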

In [None]:
# Configure Elastic Net
model = ElasticNetModel(
    alpha=0.5,             # Regularization strength
    l1_ratio=0.5,          # Balance between L1 and L2
    max_iter=2000,         # Increase if the solver warns that it did not converge
    random_state=42
)

print("Elastic Net Configuration:")
print(f"  Alpha: {model.alpha}")
print(f"  L1 ratio: {model.l1_ratio}")
print(f"  Regularization: {'Lasso-like' if model.l1_ratio > 0.7 else 'Ridge-like' if model.l1_ratio < 0.3 else 'Balanced'}")
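
To show what `alpha`, `l1_ratio`, and `max_iter` control, here is a minimal coordinate-descent solver for the Elastic Net objective. It is an illustration, not TESLearn's solver; `ElasticNetModel` presumably delegates to an optimized library implementation. Each sweep updates one coefficient at a time. The L1 term enters as soft-thresholding, which is what produces exact zeros, and the L2 term inflates the denominator. `elastic_net_cd` and `soft_threshold` are hypothetical names.

```python
import numpy as np

def soft_threshold(z, t):
    """Shrink z toward 0 by t, returning exactly 0 when |z| <= t."""
    return np.sign(z) * max(abs(z) - t, 0.0)

def elastic_net_cd(X, y, alpha=0.5, l1_ratio=0.5, max_iter=2000, tol=1e-8):
    """Coordinate descent for the scikit-learn-style Elastic Net objective.

    Minimizes (1/2n)||y - Xw - b||^2 + alpha*l1_ratio*||w||_1
              + 0.5*alpha*(1 - l1_ratio)*||w||^2. Returns (w, b, converged).
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n, p = X.shape
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean        # centering absorbs the intercept
    col_sq = (Xc**2).sum(axis=0) / n        # (1/n) x_j^T x_j
    l1, l2 = alpha * l1_ratio, alpha * (1.0 - l1_ratio)
    w = np.zeros(p)
    resid = yc.copy()                       # residual yc - Xc @ w for w = 0
    for _ in range(max_iter):
        max_change = 0.0
        for j in range(p):
            if col_sq[j] == 0:              # constant column: keep w_j = 0
                continue
            w_old = w[j]
            # (1/n) x_j^T (residual with feature j's contribution added back)
            rho = Xc[:, j] @ resid / n + col_sq[j] * w_old
            w[j] = soft_threshold(rho, l1) / (col_sq[j] + l2)
            if w[j] != w_old:
                resid -= Xc[:, j] * (w[j] - w_old)
                max_change = max(max_change, abs(w[j] - w_old))
        if max_change < tol:
            return w, y_mean - x_mean @ w, True
    return w, y_mean - x_mean @ w, False    # hit max_iter without converging

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 10))
y_demo = 3.0 * X_demo[:, 0] - 2.0 * X_demo[:, 1] + rng.normal(size=100)
w, b, converged = elastic_net_cd(X_demo, y_demo, alpha=0.5, l1_ratio=0.5)
print(np.round(w, 3), round(b, 3), converged)
```

The returned `converged` flag corresponds to the situation where a library solver would warn about non-convergence and a larger `max_iter` would help.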

## 6. Training with Cross-Validation

For regression, we use:
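- **R² score**: Proportion of variance explained (1 = perfect prediction, 0 = no better than predicting the mean, negative = worse than predicting the mean, which can happen on held-out data)
- **MAE**: Mean Absolute Error (in the original units of the outcome)
- **RMSE**: Root Mean Squared Error (penalizes large errors more heavily than MAE)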
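
All three metrics can be computed directly from true and predicted values. The sketch below uses the standard definitions, R² = 1 − SS_res / SS_tot. It shows that a constant mean prediction scores 0 and that a poor predictor scores below 0. `regression_metrics` is a hypothetical helper; TESLearn's own reporting may aggregate per fold differently.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """R², MAE and RMSE using their standard definitions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_true - y_pred
    ss_res = np.sum(err**2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # undefined if y_true is constant
    return {
        "r2": 1.0 - ss_res / ss_tot,                # can be negative out of sample
        "mae": np.mean(np.abs(err)),
        "rmse": np.sqrt(np.mean(err**2)),           # always >= MAE
    }

y_true = np.array([2.0, 4.0, 6.0, 8.0])
print(regression_metrics(y_true, y_true))                     # r2 = 1
print(regression_metrics(y_true, np.full(4, y_true.mean())))  # r2 = 0
print(regression_metrics(y_true, y_true[::-1]))               # r2 = -3
```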

In [None]:
# K-Fold CV for regression
outer_cv = KFoldValidator(n_splits=5, shuffle=True, random_state=42)
inner_cv = KFoldValidator(n_splits=3, shuffle=True, random_state=42)

# Train model
result = tl.train_model(
    images=images,
    y=y,
    feature_extractor=extractor,
    model=model,
    feature_selector=selector,
    outer_validator=outer_cv,
    inner_validator=inner_cv,
    use_scaling=True
)

print(result.get_summary())
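
`train_model` hides the cross-validation loop. The sketch below shows the general pattern for pooling out-of-fold predictions without leakage: standardization statistics are computed on each training fold only and then applied to the test fold. Several simplifications are assumptions made for brevity:
- ordinary least squares stands in for Elastic Net;
- the inner loop used for hyperparameter tuning is omitted;
- the shuffling scheme is not guaranteed to match `KFoldValidator`'s.

`kfold_indices` and `out_of_fold_predictions` are hypothetical helpers.

```python
import numpy as np

def kfold_indices(n, n_splits=5, seed=42):
    """Shuffled K-fold test-index arrays that together partition range(n)."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n), n_splits)

def out_of_fold_predictions(X, y, n_splits=5, seed=42):
    """Pool test-fold predictions of a least-squares model across K folds.

    Scaling statistics come from each training fold only, so the test
    fold never influences preprocessing.
    """
    X, y = np.asarray(X, dtype=float), np.asarray(y, dtype=float)
    y_pred = np.empty_like(y)
    for test_idx in kfold_indices(len(y), n_splits, seed):
        train = np.ones(len(y), dtype=bool)
        train[test_idx] = False
        mu, sd = X[train].mean(axis=0), X[train].std(axis=0)
        sd[sd == 0] = 1.0                      # guard against constant features
        Z_train = (X[train] - mu) / sd
        Z_test = (X[test_idx] - mu) / sd
        A = np.column_stack([np.ones(len(Z_train)), Z_train])
        coef, *_ = np.linalg.lstsq(A, y[train], rcond=None)
        y_pred[test_idx] = coef[0] + Z_test @ coef[1:]
    return y_pred

rng = np.random.default_rng(0)
X_cv = rng.normal(size=(50, 3))
y_cv = 1.0 + 2.0 * X_cv[:, 0] - X_cv[:, 1]
print(np.round(out_of_fold_predictions(X_cv, y_cv)[:5], 3))
```

Each subject receives exactly one out-of-fold prediction. A predicted-versus-actual plot over pooled CV predictions relies on this property.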

## 7. Visualizing Predictions

For regression, visualize:
1. Predicted vs Actual scatter plot (should follow diagonal)
2. Residual distribution (should be centered at 0)
3. Feature importance (coefficients)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Get predictions from CV
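if hasattr(result, 'all_y_true') and hasattr(result, 'all_y_pred'):
    # Convert to arrays so .min()/.max() and arithmetic work even if lists are stored
    y_true = np.asarray(result.all_y_true)
    y_pred = np.asarray(result.all_y_pred)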
    
    # 1. Predicted vs Actual
    axes[0, 0].scatter(y_true, y_pred, alpha=0.6, edgecolors='black', linewidth=0.5)
    min_val, max_val = min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())
    axes[0, 0].plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect prediction')
    axes[0, 0].set_xlabel('Actual Improvement Score')
    axes[0, 0].set_ylabel('Predicted Improvement Score')
    axes[0, 0].set_title(f'Predicted vs Actual (R² = {result.mean_r2:.3f})')
    axes[0, 0].legend()
    
    # Add correlation
    corr, p_val = stats.pearsonr(y_true, y_pred)
    axes[0, 0].text(0.05, 0.95, f'r = {corr:.3f}, p = {p_val:.4f}', 
                    transform=axes[0, 0].transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # 2. Residuals
    residuals = y_true - y_pred
    axes[0, 1].scatter(y_pred, residuals, alpha=0.6, edgecolors='black', linewidth=0.5)
    axes[0, 1].axhline(y=0, color='r', linestyle='--')
    axes[0, 1].set_xlabel('Predicted Values')
    axes[0, 1].set_ylabel('Residuals (Actual - Predicted)')
    axes[0, 1].set_title('Residual Plot')
    
    # 3. Residual distribution
    axes[1, 0].hist(residuals, bins=15, edgecolor='black', alpha=0.7)
    axes[1, 0].axvline(x=0, color='red', linestyle='--')
    axes[1, 0].set_xlabel('Residual Value')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title(f'Residual Distribution (Mean: {residuals.mean():.3f})')

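# 4. Feature importance: refit the pipeline on all subjects to obtain final
#    coefficients for interpretation (performance estimates come from CV above)
result.pipeline.fit(images, y)
importance = result.pipeline.get_feature_importance()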

if importance:
    # Sort by absolute importance
    sorted_imp = sorted(importance.items(), key=lambda x: abs(x[1]), reverse=True)[:15]
    names, values = zip(*sorted_imp)
    
    colors = ['green' if v > 0 else 'red' for v in values]
    axes[1, 1].barh(range(len(values)), values, color=colors, alpha=0.7)
    axes[1, 1].set_yticks(range(len(values)))
    axes[1, 1].set_yticklabels(names, fontsize=8)
    axes[1, 1].set_xlabel('Coefficient Value')
    axes[1, 1].set_title('Top 15 Features (Elastic Net Coefficients)')
    axes[1, 1].invert_yaxis()

plt.tight_layout()
plt.show()
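
The scatter panel reports Pearson r, while the title reports R². The two can disagree sharply. Pearson r ignores systematic shrinkage and offset, both of which are common with regularized models, whereas R² penalizes them. The toy example below (made-up numbers) has perfectly correlated predictions that are still worse than predicting the mean. Note also that the p-value from `stats.pearsonr` assumes independent observations. Pooled cross-validated predictions do not strictly satisfy this, and a permutation test is a common alternative.

```python
import numpy as np

y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = 0.5 * y_true + 4.0  # perfectly correlated, but shrunk and shifted

r = np.corrcoef(y_true, y_pred)[0, 1]
r2 = 1 - np.sum((y_true - y_pred) ** 2) / np.sum((y_true - y_true.mean()) ** 2)
print(f"r = {r:.3f}, R² = {r2:.3f}")  # r = 1.000, R² = -2.375
```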

## 8. Interpreting Coefficients

Elastic Net coefficients tell us:
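- **Magnitude**: How strongly the ROI feature contributes to the prediction, given the other features in the model
- **Sign**: Positive = a higher feature value (e.g., mean E-field in the ROI) is associated with a higher predicted improvement score
- **Zero**: The L1 penalty removed the feature from the model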

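**Important**: Assuming `use_scaling=True` standardizes the features, each coefficient is the change in predicted outcome per one-standard-deviation increase in that feature. Coefficients describe predictive associations, not causal effects.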
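
Sometimes coefficients are needed in raw feature units, for example points of improvement per V/m. They can be converted back using the feature means and standard deviations that were used for scaling. The sketch below assumes z-scoring of the features only, with the target left unscaled; it does not show how to retrieve those statistics from the TESLearn pipeline. `to_raw_units` is a hypothetical helper.

```python
import numpy as np

def to_raw_units(coef_std, intercept_std, x_mean, x_std):
    """Convert coefficients fit on z-scored features back to raw feature units.

    b + sum_j w_j * (x_j - mu_j) / s_j  ==  (b - sum_j w_j * mu_j / s_j) + sum_j (w_j / s_j) * x_j
    """
    coef_std = np.asarray(coef_std, dtype=float)
    coef_raw = coef_std / x_std
    intercept_raw = intercept_std - np.sum(coef_std * x_mean / x_std)
    return coef_raw, intercept_raw

# Illustrative values only
rng = np.random.default_rng(0)
X_raw = rng.normal(loc=0.3, scale=0.05, size=(30, 4))
mu, sd = X_raw.mean(axis=0), X_raw.std(axis=0)
w_std, b_std = np.array([1.2, -0.4, 0.0, 0.7]), 10.0
w_raw, b_raw = to_raw_units(w_std, b_std, mu, sd)
print(np.round(w_raw, 2), round(b_raw, 2))
```

Raw-unit coefficients are easier to relate to physical field strengths. Standardized coefficients are the ones to compare across features.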

In [None]:
# Get non-zero coefficients (`importance` comes from the refit in Section 7)
nonzero = {k: v for k, v in importance.items() if abs(v) > 1e-6}
print(f"\nModel used {len(nonzero)} / {len(importance)} features")
print(f"Sparsity: {100 * (1 - len(nonzero) / max(len(importance), 1)):.1f}%\n")

# Rank non-zero coefficients by magnitude (self-contained, not reusing plot variables)
ranked = sorted(nonzero.items(), key=lambda x: abs(x[1]), reverse=True)
print("Top 10 predictive regions:")
for i, (region, coef) in enumerate(ranked[:10], 1):
    direction = "↑" if coef > 0 else "↓"
    print(f"{i:2d}. {region}: {coef:+.4f} {direction}")

print("\nInterpretation:")
print("  ↑ Positive coefficients: higher feature value in this region → higher predicted improvement")
print("  ↓ Negative coefficients: higher feature value in this region → lower predicted improvement")

## Key Takeaways

1. **Use F-regression** for continuous targets (not T-test)
2. **Elastic Net** provides both prediction and feature selection
3. **Check residuals** to validate model assumptions
4. **R² interpretation**: Cross-validated R² values below 0.1 are common in neuroimaging; report their variability across folds, not just the mean

## Comparison: Classification vs Regression

| Aspect | Classification | Regression |
|--------|---------------|------------|
| Target | Binary (0/1) | Continuous |
| Selection | T-test | F-regression |
| Metric | Accuracy, AUC | R², MAE, RMSE |
| Interpretation | Feature importance | Coefficients (direction matters) |
| Use case | Responder prediction | Symptom severity, dose optimization |

## Best Practices for Regression

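- Always check the target distribution (transform it if highly skewed)
- Report R² together with MAE or RMSE
- Inspect residuals for skew, outliers, and non-constant spread (normal residuals matter mainly for parametric inference, not for prediction)
- Consider robust regression when outliers are present
- Tune `l1_ratio` when many features are correlated (lower values give more weight to the stabilizing L2 penalty)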