# 06 - Temporal Decoding (EEG/MEG)

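Time-resolved decoding of EEG/MEG epochs to identify when class-specific information emerges in the signal, how long it persists, and whether its neural code is stable over time.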

**Contents:**
1. Time-resolved (sliding-window) decoding
2. Finding significant time windows
3. Temporal generalization
4. Using MNE for temporal decoding
5. Comparing conditions

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt

from core.dataset import DecodingDataset
from models.temporal import TemporalDecoder, TemporalGeneralizationDecoder
from models.classifiers import SVMDecoder
from visualization.temporal_plots import plot_temporal_decoding, plot_temporal_generalization

## Create Synthetic EEG Data

In [None]:
# Simulate EEG epochs
n_epochs = 200
n_channels = 64
sfreq = 200  # sampling rate (Hz)

# Epoch window in seconds; both endpoints are sampled.
tmin, tmax = -0.2, 0.8
# Derive the number of samples from the window and the sampling rate so the
# time axis has exactly 1 / sfreq spacing. np.linspace(tmin, tmax, 200) would
# give a 1/199 s spacing that contradicts sfreq and the tmin/tmax metadata.
n_times = int(round((tmax - tmin) * sfreq)) + 1  # 201 samples

# Time vector (seconds), one entry per sample
times = tmin + np.arange(n_times) / sfreq

# Generate random EEG-like data with class-specific signal
np.random.seed(42)

X = np.random.randn(n_epochs, n_channels, n_times) * 10  # Baseline noise: (epochs, channels, times)
y = np.random.randint(0, 2, n_epochs)  # Binary labels

# Add a class-specific offset between 150 and 350 ms (N170-like component)
# to the first `n_signal_channels` channels of every class-1 epoch.
n_signal_channels = 10
signal_amplitude = 5.0
signal_window = (times > 0.15) & (times < 0.35)
# amplitude * boolean mask is a per-sample offset (amplitude inside the
# window, 0 outside) that broadcasts over the selected epochs and channels.
X[y == 1, :n_signal_channels, :] += signal_amplitude * signal_window

# Flatten for DecodingDataset (channel-major: feature = channel * n_times + time)
X_flat = X.reshape(n_epochs, n_channels * n_times)

# Wrap the flattened epochs in a DecodingDataset. The metadata records the
# epoch geometry (channels, samples, sampling rate, time window) alongside
# the flat feature matrix.
eeg_dataset = DecodingDataset(
    X=X_flat,
    y=y,
    class_names=["condition_A", "condition_B"],
    metadata=dict(
        n_channels=n_channels,
        n_times=n_times,
        sfreq=sfreq,
        tmin=tmin,
        tmax=tmax,
    ),
    modality="eeg",
)

print(
    f"EEG dataset: {eeg_dataset.n_samples} epochs, "
    f"{eeg_dataset.n_features} features"
)
print(f"Time window: {tmin}s to {tmax}s")

## 1. Time-Resolved Decoding

In [None]:
# Sliding-window decoding: a linear SVM applied to 50 ms windows that
# advance in 20 ms steps. n_jobs=-1 presumably means "use all available
# workers" (scikit-learn convention) — confirm in TemporalDecoder.
temporal_decoder = TemporalDecoder(
    decoder=SVMDecoder(kernel="linear"),
    time_window=0.05,  # window length (s)
    step=0.02,         # hop between successive windows (s)
    n_jobs=-1,
    verbose=1,
)
temporal_decoder.fit(eeg_dataset, times=times)

# time_points: times at which a score was computed; scores: score per time point
time_points, scores = temporal_decoder.get_temporal_scores()

# Locate the peak once and reuse the index for both the score and its time.
peak_idx = np.argmax(scores)
print(f"\nDecoded {len(time_points)} time points")
print(f"Peak accuracy: {np.ravel(scores)[peak_idx]:.1%} at {time_points[peak_idx]:.3f}s")

In [None]:
# Plot the time-resolved scores against the 0.5 chance line of this binary
# problem; show_std=True requests the decoder's score spread as well.
temporal_decoder.plot(title="Temporal Decoding", show_std=True, chance_level=0.5)

## 2. Finding Significant Time Windows

In [None]:
# Find periods where decoding is significantly above chance.
# NOTE(review): whether get_significant_periods corrects for multiple
# comparisons across time points is not visible here — confirm before
# interpreting onset latencies.
significant_periods = temporal_decoder.get_significant_periods(
    chance_level=0.5,
    alpha=0.05,
)

print("Significant time periods:")
# len() rather than truthiness, so this also works if an ndarray is returned.
if len(significant_periods) == 0:
    # Make an empty result explicit instead of printing only the header.
    print("  none found")
for start, end in significant_periods:
    print(f"  {start:.3f}s to {end:.3f}s")

In [None]:
# Re-plot the time-resolved scores, this time highlighting the significant
# periods found in the previous cell.
plot_temporal_decoding(
    title="Temporal Decoding with Significant Periods",
    times=time_points,
    scores=scores,
    scores_std=temporal_decoder.scores_std_,
    significant_times=significant_periods,
    chance_level=0.5,
)

## 3. Temporal Generalization

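Train a classifier at each time point and test it at every other time point. This shows whether the pattern learned at one time generalizes to other times. Note that the matrix below is simulated for illustration, not fitted to the data above.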

In [None]:
# NOTE: A real temporal-generalization matrix needs a decoder fitted on the
# epoch data (e.g. MNE's GeneralizingEstimator; TemporalGeneralizationDecoder
# is imported above but not exercised in this notebook). For now, simulate
# a plausible matrix instead.

# Simulated temporal generalization matrix, indexed [train time, test time]
n_timepoints = 40
gen_times = np.linspace(tmin, tmax, n_timepoints)

# Chance baseline plus an above-chance ridge on the diagonal that decays
# exponentially with the distance |i - j| between time indices. Broadcasting
# replaces the original nested Python loop with identical values.
gen_chance = 0.5
gen_peak_gain = 0.3
gen_decay = 5.0  # decay length, in time-index units
idx = np.arange(n_timepoints)
distance = np.abs(idx[:, None] - idx[None, :])
gen_matrix = gen_chance + gen_peak_gain * np.exp(-distance / gen_decay)

# Add noise, then keep scores inside the valid [0, 1] range
gen_matrix += np.random.randn(n_timepoints, n_timepoints) * 0.05
gen_matrix = np.clip(gen_matrix, 0, 1)

print(f"Generalization matrix shape: {gen_matrix.shape}")

In [None]:
# Show the (simulated) train-time x test-time score matrix, with 0.5 as
# the chance reference for this binary problem.
plot_temporal_generalization(
    title="Temporal Generalization Matrix",
    chance_level=0.5,
    times=gen_times,
    scores=gen_matrix,
)

## 4. Using MNE for Temporal Decoding

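With real MNE data, call `fit_mne()` with an `mne.Epochs` object directly instead of building a flattened `DecodingDataset` by hand (see the commented example below).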

In [None]:
# Example with real MNE data — uncomment and point at your own epochs file.
# Here the decoder is fitted from the Epochs object via fit_mne().
#
# import mne
#
# epochs = mne.read_epochs("sub-01-epo.fif", preload=True)  # load epochs
#
# temporal = TemporalDecoder(
#     decoder=SVMDecoder(kernel="linear"),
#     time_window=0.05,
# )
# temporal.fit_mne(epochs)         # fit directly from the epochs
# temporal.plot(chance_level=0.5)  # plot against chance

print("MNE integration provides optimal temporal decoding.")

In [None]:
# MNE's built-in decoding tools (example; requires mne and scikit-learn).
# X here is the 3-D (n_epochs, n_channels, n_times) array built above. The
# MNE estimators fit one classifier per sample along the last (time) axis.

# from mne.decoding import SlidingEstimator, GeneralizingEstimator, cross_val_multiscore
# from sklearn.pipeline import make_pipeline
# from sklearn.preprocessing import StandardScaler
# from sklearn.svm import SVC

# # Create classifier pipeline
# clf = make_pipeline(StandardScaler(), SVC(kernel='linear'))

# # Sliding estimator (time-point decoding)
# time_decod = SlidingEstimator(clf, scoring='accuracy', n_jobs=-1)

# # Cross-validate. Use a new name so the `scores` from section 1 (reused in
# # section 5) is not overwritten when this cell is uncommented.
# mne_scores = cross_val_multiscore(time_decod, X, y, cv=5)  # (n_folds, n_times)
# mean_scores = mne_scores.mean(axis=0)

# # Generalizing estimator (train at each time, test at every time)
# time_gen = GeneralizingEstimator(clf, scoring='accuracy', n_jobs=-1)
# gen_scores = cross_val_multiscore(time_gen, X, y, cv=5).mean(axis=0)  # (n_times, n_times)

print("MNE provides SlidingEstimator and GeneralizingEstimator.")

## 5. Comparing Conditions

In [None]:
# Compare temporal decoding across different analyses
from visualization.temporal_plots import plot_temporal_comparison

chance_level = 0.5

# Simulated second analysis: shrink the above-chance effect by 10%.
# Scaling raw accuracies (scores * 0.9) would also push chance-level periods
# down to 0.45, i.e. spuriously *below* chance, so scale the deviation from
# chance instead.
results_list = [
    {
        'times': time_points,
        'scores': scores,
        'scores_std': temporal_decoder.scores_std_,
    },
    {
        'times': time_points,
        'scores': chance_level + 0.9 * (np.asarray(scores) - chance_level),
        'scores_std': temporal_decoder.scores_std_,
    },
]

plot_temporal_comparison(
    results_list,
    labels=['Condition A', 'Condition B'],
    chance_level=chance_level,
    title='Temporal Decoding Comparison',
)

## Key Insights from Temporal Decoding

1. **Onset latency**: When does classification become significant?
2. **Peak accuracy**: When is information maximal?
3. **Duration**: How long does information persist?
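4. **Temporal generalization**: Is the same neural code sustained over time (broad off-diagonal generalization), or does it evolve (above-chance scores confined to the diagonal)?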

## Next Steps

- **07_group_analysis.ipynb**: Multi-subject group analysis