# 03 - Training Classifiers

This notebook demonstrates training different classifiers for neural decoding.

**Contents:**
1. SVM Decoder
2. Random Forest Decoder
3. Logistic Regression
4. LDA Decoder
5. Ensemble Methods
6. Comparing Classifiers
7. Plotting Results

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification

from core.dataset import DecodingDataset
from models.classifiers import (
    SVMDecoder, 
    RandomForestDecoder, 
    LogisticDecoder, 
    LDADecoder,
    EnsembleDecoder
)

## Create Synthetic Data

In [None]:
# Create realistic neuroimaging-like data: many features, few of them informative.
N_SAMPLES = 200
N_RUNS = 5   # acquisition runs, used as CV groups
SEED = 42

# Seed NumPy's global RNG so estimators created later without an explicit
# random_state (e.g. the Random Forest) give reproducible results — assuming the
# decoder wrappers fall back to NumPy's global RNG (TODO confirm).
np.random.seed(SEED)

X, y = make_classification(
    n_samples=N_SAMPLES,
    n_features=1000,
    n_informative=50,
    n_redundant=50,
    n_classes=2,
    class_sep=0.8,  # Some overlap between classes
    random_state=SEED
)

# Assign samples to runs in contiguous, equal-sized blocks labelled 1..N_RUNS.
# Deriving the block size from N_SAMPLES keeps `groups` aligned with X and y
# if either constant is changed.
assert N_SAMPLES % N_RUNS == 0, "N_SAMPLES must split evenly into N_RUNS runs"
groups = np.repeat(np.arange(1, N_RUNS + 1), N_SAMPLES // N_RUNS)

dataset = DecodingDataset(
    X=X,
    y=y,
    groups=groups,
    class_names=["class_A", "class_B"],
    modality="fmri"
)

print(f"Dataset: {dataset.n_samples} samples, {dataset.n_features} features")

## 1. SVM Decoder

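A linear SVM is a common default for high-dimensional neuroimaging data. With far more features (voxels) than samples, a linear boundary is usually sufficient and is less prone to overfitting than non-linear kernels.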

In [None]:
# Linear-kernel SVM on standardized features — the usual first choice for fMRI.
svm_linear = SVMDecoder(kernel="linear", C=1.0, standardize=True)

# Cross-validated evaluation on the whole dataset; the result object exposes the
# mean accuracy, its CV spread, and per-class accuracy.
results_svm = svm_linear.cross_validate(dataset)

mean_acc, spread = results_svm.accuracy, results_svm.cv_std
print(f"SVM Linear: {mean_acc:.1%} (+/- {spread:.1%})")
print(f"Per-class: {results_svm.accuracy_per_class}")
print(f"Per-class: {results_svm.accuracy_per_class}")

In [None]:
# RBF-kernel SVM (can capture non-linear patterns).
# standardize=True matches the linear SVM above, so the comparison isolates the
# kernel choice instead of mixing kernel and preprocessing differences. The RBF
# kernel is distance-based, so per-feature scaling directly changes its behavior.
svm_rbf = SVMDecoder(
    kernel="rbf",
    C=1.0,
    gamma="scale",
    standardize=True
)

results_rbf = svm_rbf.cross_validate(dataset)
print(f"SVM RBF: {results_rbf.accuracy:.1%}")

In [None]:
# Feature importances (weights) for the linear SVM, refit on the full dataset.
# Cross-validation above estimates performance; this fit is only for inspection.
svm_linear.fit(dataset.X, dataset.y)
weights = svm_linear.get_feature_importances()

# The x-axis is labelled "magnitude", so plot |w|. np.ravel guards against a
# (1, n_features) coefficient array, which hist() would treat as n_features
# separate single-value datasets. (Whether get_feature_importances already
# returns absolute, 1-D values is not visible here — TODO confirm.)
weight_magnitudes = np.abs(np.ravel(weights))

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(weight_magnitudes, bins=50)
ax.set_xlabel('Weight Magnitude')
ax.set_ylabel('Count')
ax.set_title('SVM Feature Weights Distribution')
plt.show()

## 2. Random Forest Decoder

Random forests can capture non-linear interactions between features and provide impurity-based (Gini) feature importances that are easy to inspect.

In [None]:
# Random Forest of 100 trees; max_depth=None leaves tree depth unrestricted,
# and n_jobs=-1 presumably parallelizes over all available cores.
rf_decoder = RandomForestDecoder(n_estimators=100, max_depth=None, n_jobs=-1)

results_rf = rf_decoder.cross_validate(dataset)
print(f"Random Forest: {results_rf.accuracy:.1%}")

In [None]:
# Feature importances (Gini) from a Random Forest refit on the full dataset.
rf_decoder.fit(dataset.X, dataset.y)
importances = np.asarray(rf_decoder.get_feature_importances())

# Number of top features to display; clipped so the cell also works when the
# dataset has fewer features than N_TOP.
N_TOP = 20
n_top = min(N_TOP, importances.size)
top_idx = np.argsort(importances)[::-1][:n_top]

fig, ax = plt.subplots(figsize=(10, 6))
# Reverse the order so the most important feature is drawn at the top.
ax.barh(range(n_top), importances[top_idx][::-1])
ax.set_yticks(range(n_top))
ax.set_yticklabels([f"Feature {i}" for i in top_idx[::-1]])
ax.set_xlabel('Gini Importance')
ax.set_title(f'Top {n_top} Features (Random Forest)')
fig.tight_layout()
plt.show()

## 3. Logistic Regression

In [None]:
# L2-penalized (ridge) logistic regression. Following the scikit-learn
# convention, C is the inverse regularization strength; max_iter is raised to
# give the optimizer room to converge on 1000 features.
lr_decoder = LogisticDecoder(C=1.0, penalty="l2", max_iter=1000)

results_lr = lr_decoder.cross_validate(dataset)
print(f"Logistic Regression: {results_lr.accuracy:.1%}")

In [None]:
# L1-penalized logistic regression: the L1 penalty favors sparse weight vectors,
# and the small C (0.1) strengthens the regularization.
# NOTE(review): in scikit-learn, penalty="l1" needs a compatible solver
# ("liblinear" or "saga"); the default "lbfgs" rejects it. Confirm that
# LogisticDecoder chooses one. max_iter is also left at its default here,
# unlike the L2 model above.
lr_sparse = LogisticDecoder(C=0.1, penalty="l1")

results_sparse = lr_sparse.cross_validate(dataset)
print(f"Logistic (L1 sparse): {results_sparse.accuracy:.1%}")

## 4. LDA Decoder

In [None]:
# Linear Discriminant Analysis with the SVD solver and no covariance shrinkage.
# NOTE(review): this dataset has more features (1000) than samples (200), so the
# empirical covariance is singular. In scikit-learn the "svd" solver does not
# support shrinkage; solver="lsqr" with shrinkage="auto" (Ledoit-Wolf) is the
# usual choice in this regime. Consider it if LDADecoder forwards these arguments.
lda_decoder = LDADecoder(solver="svd", shrinkage=None)

results_lda = lda_decoder.cross_validate(dataset)
print(f"LDA: {results_lda.accuracy:.1%}")

## 5. Ensemble Methods

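Combining multiple classifiers can improve performance by averaging out their individual errors. The gain is not guaranteed, so always compare against the best single model.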

In [None]:
# Soft-voting ensemble: the members' predicted class probabilities are averaged
# with equal weights (weights=None).
# NOTE(review): soft voting needs probability estimates from every member.
# Confirm that SVMDecoder provides them (scikit-learn's SVC needs probability=True).
soft_members = [
    SVMDecoder(kernel="linear"),
    RandomForestDecoder(n_estimators=100),
    LogisticDecoder(),
]
ensemble = EnsembleDecoder(decoders=soft_members, voting="soft", weights=None)

results_ensemble = ensemble.cross_validate(dataset)
print(f"Ensemble: {results_ensemble.accuracy:.1%}")

In [None]:
# Weighted soft-voting ensemble: same three members, but the SVM's vote counts
# double. Weights presumably pair positionally with the decoder list.
weighted_members = [
    SVMDecoder(kernel="linear"),            # weight 2
    RandomForestDecoder(n_estimators=100),  # weight 1
    LogisticDecoder(),                      # weight 1
]
ensemble_weighted = EnsembleDecoder(
    decoders=weighted_members,
    voting="soft",
    weights=[2, 1, 1],
)

results_weighted = ensemble_weighted.cross_validate(dataset)
print(f"Weighted Ensemble: {results_weighted.accuracy:.1%}")

## 6. Comparing Classifiers

In [None]:
# Collect every cross-validated result computed above, in notebook order.
results = {
    'SVM Linear': results_svm,
    'SVM RBF': results_rbf,
    'Random Forest': results_rf,
    'Logistic L2': results_lr,
    'Logistic L1': results_sparse,
    'LDA': results_lda,
    'Ensemble': results_ensemble,
    'Weighted Ensemble': results_weighted,
}

names = list(results.keys())
accuracies = [r.accuracy for r in results.values()]
stds = [r.cv_std for r in results.values()]

# Chance level for (roughly) balanced classes is 1 / n_classes. It is derived
# from the labels rather than hard-coded to 0.5, so it stays correct for
# multi-class data.
chance = 1.0 / len(np.unique(dataset.y))

# Plot comparison
fig, ax = plt.subplots(figsize=(10, 5))
x = np.arange(len(names))
bars = ax.bar(x, accuracies, yerr=stds, capsize=5, color='steelblue', edgecolor='black')

ax.axhline(y=chance, color='red', linestyle='--', label='Chance')
ax.set_xticks(x)
ax.set_xticklabels(names, rotation=45, ha='right')
ax.set_ylabel('Accuracy')
ax.set_title('Classifier Comparison')
ax.set_ylim(0, 1.1)  # headroom so labels above high bars stay inside the axes
ax.legend()

# Value labels sit above the error bar so they do not overlap it.
for bar, acc, std in zip(bars, accuracies, stds):
    ax.text(bar.get_x() + bar.get_width() / 2, acc + std + 0.02,
            f'{acc:.1%}', ha='center', va='bottom')

fig.tight_layout()
plt.show()

In [None]:
# Plain-text summary: mean CV accuracy and its spread for each classifier.
print("Classifier Comparison")
print("=" * 50)
for label, res in results.items():
    acc, spread = res.accuracy, res.cv_std
    print(f"{label:20s}: {acc:.1%} (+/- {spread:.1%})")

## 7. Plotting Results

In [None]:
# Confusion matrix for the best classifier, selected by mean CV accuracy.
# Caveat: picking the maximum over several models makes its CV accuracy an
# optimistic estimate; confirm it on held-out data or with nested CV.
best_name = max(results, key=lambda name: results[name].accuracy)
best_result = results[best_name]
print(f"Best classifier: {best_name} ({best_result.accuracy:.1%})")
best_result.plot_confusion_matrix(normalize=True)

In [None]:
# CV scores of the classifier chosen as `best_result` in the previous cell
# (presumably one score per cross-validation fold — TODO confirm against the
# result class).
best_result.plot_cv_scores()

## Next Steps

- **04_cross_validation.ipynb**: Advanced CV strategies
- **05_searchlight.ipynb**: Whole-brain searchlight analysis