02 - Preprocessing Pipeline

1. Load raw GDF data
2. Bandpass filter (8-30 Hz) - isolate the mu (8-12 Hz) and beta (13-30 Hz) rhythms, the main neural signatures of motor imagery
3. Epoch extraction around cues - capture the motor imagery period after the cue, skipping the initial reaction time
4. Artifact rejection - remove trials contaminated by muscle/eye movements
5. Quality assessment
6. Save processed data

In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import mne
from tqdm.notebook import tqdm
import pickle

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

from preprocessing import (
    load_raw_gdf, apply_bandpass_filter, extract_events, create_epochs,
    preprocess_subject, CHANNEL_NAMES, CLASS_LABELS
)
from visualization import set_style, CLASS_NAMES

# Notebook-wide plotting style and MNE log verbosity.
set_style()
mne.set_log_level('WARNING')

# Data locations, relative to the current working directory.
DATA_DIR = Path('../data/raw')
PROCESSED_DIR = Path('../data/processed')
# parents=True so a fresh checkout without ../data does not fail here.
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

print(f"MNE version: {mne.__version__}")

In [None]:
# Load subject 1's training-session recording for the exploratory walkthrough.
SUBJECT = 1
# Zero-padded subject id: 1 -> 'A01T.gdf' (also correct for ids >= 10).
gdf_path = DATA_DIR / f'A{SUBJECT:02d}T.gdf'

if not gdf_path.exists():
    # Report the exact path that was checked; makes working-directory mistakes obvious.
    raise FileNotFoundError(
        f"Data not found at {gdf_path}. Run: python scripts/download_data.py"
    )

raw = load_raw_gdf(gdf_path)
print(f"Loaded: {gdf_path.name}")
print(f"Duration: {raw.times[-1]:.1f}s | Channels: {len(raw.ch_names)} | Sfreq: {raw.info['sfreq']} Hz")

- Mu rhythm (8-12 Hz): suppressed over the contralateral motor cortex during motor imagery
- Beta rhythm (13-30 Hz): associated with motor planning and execution

We exclude:
- < 8 Hz: movement artifacts, eye blinks, slow drift
- > 30 Hz: muscle artifacts (EMG), line noise

In [None]:
# Compare power spectra before and after band-pass filtering.
L_FREQ = 8.0   # lower edge: keeps the mu band (8-12 Hz)
H_FREQ = 30.0  # upper edge: keeps the beta band (13-30 Hz)

raw_filt = apply_bandpass_filter(raw, l_freq=L_FREQ, h_freq=H_FREQ)

# Plot PSD comparison: one panel per recording, same layout for both.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
panels = [
    (raw, axes[0], 'Before Filtering'),
    (raw_filt, axes[1], f'After Bandpass ({L_FREQ:g}-{H_FREQ:g} Hz)'),
]
for recording, ax, title in panels:
    recording.compute_psd(fmax=60).plot(axes=ax, show=False, average=True)
    ax.set_title(title)
    # Shade the passband from the actual filter settings rather than fixed numbers.
    ax.axvspan(L_FREQ, H_FREQ, alpha=0.2, color='green', label='Passband')

plt.tight_layout();

In [None]:
# Time-domain comparison at C3 (motor cortex) over a 5-second window.
fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)

# Sample indices for t = 10 s .. 15 s.
start, stop = int(10 * raw.info['sfreq']), int(15 * raw.info['sfreq'])
times = raw.times[start:stop]

c3_idx = raw.ch_names.index('C3')
# Read only the requested window instead of the whole channel, then slice the
# single picked channel. MNE returns volts; scale to µV for plotting.
data_raw = raw.get_data(picks=[c3_idx], start=start, stop=stop)[0] * 1e6
data_filt = raw_filt.get_data(picks=[c3_idx], start=start, stop=stop)[0] * 1e6

axes[0].plot(times, data_raw, 'b-', linewidth=0.5)
axes[0].set_ylabel('Amplitude (µV)')
axes[0].set_title('Raw EEG at C3')
axes[0].set_ylim(-100, 100)

axes[1].plot(times, data_filt, 'g-', linewidth=0.5)
axes[1].set_ylabel('Amplitude (µV)')
axes[1].set_xlabel('Time (s)')
axes[1].set_title(f'Filtered EEG at C3 ({L_FREQ:g}-{H_FREQ:g} Hz)')
axes[1].set_ylim(-50, 50)

plt.tight_layout();

Epoch timing (relative to the cue):
- t=0: Cue onset (visual instruction appears)
- t=0 to 0.5s: Reaction/preparation time
- t=0.5 to 4s: Motor imagery period (the window we keep)

We skip the first 0.5s to avoid the visual evoked potential triggered by the cue.

In [None]:
# Pull the motor-imagery cue events (and their class mapping) from the filtered recording.
events, event_id = extract_events(raw_filt)

print(f"Found {len(events)} motor imagery trials")
print("\nClass distribution:")
event_codes = events[:, 2]  # third column holds the event code
for label, class_code in event_id.items():
    n_trials = np.sum(event_codes == class_code)
    print(f"  {label}: {n_trials} trials")

In [None]:
# Epoch window, in seconds relative to cue onset (t = 0).
TMIN = 0.5   # Start 0.5s after cue
TMAX = 4.0   # End at 4s
BASELINE = None  # No baseline correction (important for CSP)

# Create epochs without rejection first, so the rejection step can be inspected.
epochs_all = create_epochs(
    raw_filt, events, event_id,
    tmin=TMIN, tmax=TMAX,
    baseline=BASELINE,
    reject=None
)

# get_data() returns a copy of the full array; fetch it once for the shape.
epoch_shape = epochs_all.get_data().shape
print(f"\nEpoch shape: {epoch_shape}")
print("  (n_epochs, n_channels, n_times)")
# Both window endpoints are sampled, so n_times is duration * sfreq + 1.
print(f"  Duration: {TMAX - TMIN}s = {epoch_shape[2]} samples")

In [None]:
# Visualize one example trial per class at C3.
fig, ax = plt.subplots(figsize=(12, 4))

# Fetch C3 for every epoch once (in µV) instead of copying the full array per class.
c3_epochs = epochs_all.get_data(picks='C3')[:, 0, :] * 1e6
epoch_codes = epochs_all.events[:, 2]

for i, (name, code) in enumerate(event_id.items()):
    matches = np.flatnonzero(epoch_codes == code)
    if matches.size == 0:
        continue  # no epoch of this class; nothing to show
    # Offset each class by 30 µV so the traces do not overlap.
    ax.plot(epochs_all.times, c3_epochs[matches[0]] + i * 30,
            label=name.replace('_', ' ').title())

# Epoch times are relative to cue onset: the cue (t=0) lies before this window,
# which starts TMIN seconds after it.
ax.axvline(TMIN, color='k', linestyle='--', alpha=0.5,
           label=f'Window start ({TMIN:g} s after cue)')
ax.set_xlabel('Time relative to cue onset (s)')
ax.set_ylabel('Amplitude (µV) + offset')
ax.set_title('Single Trial Examples at C3')
ax.legend(loc='upper right')
plt.tight_layout();

Remove epochs with abnormally high amplitude, which most likely indicates artifacts from:
- Eye blinks/movements
- Muscle tension
- Electrode pops

Threshold: reject trials whose peak-to-peak amplitude exceeds 100 µV on any EEG channel.

In [None]:
# Compute the peak-to-peak (PTP) amplitude of each epoch to preview rejection.
PTP_THRESHOLD_UV = 100.0  # preview threshold in µV; same value as REJECT_THRESHOLD (in V)

# MNE's reject={'eeg': ...} only checks EEG-typed channels, so the preview uses
# the same channels; any non-EEG channels would otherwise skew the preview.
data = epochs_all.get_data(picks='eeg') * 1e6  # V -> µV
# PTP over time for each channel, then the worst channel of each epoch.
ptp = np.ptp(data, axis=2).max(axis=1)

# Plot distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(ptp, bins=50, edgecolor='black', alpha=0.7)
axes[0].axvline(PTP_THRESHOLD_UV, color='r', linestyle='--', linewidth=2,
                label=f'Threshold ({PTP_THRESHOLD_UV:g} µV)')
axes[0].set_xlabel('Peak-to-peak amplitude (µV)')
axes[0].set_ylabel('Number of epochs')
axes[0].set_title('Amplitude Distribution')
axes[0].legend()

# Show which epochs would be rejected (strictly above threshold, as MNE does).
reject_mask = ptp > PTP_THRESHOLD_UV
axes[1].scatter(range(len(ptp)), ptp, c=['red' if r else 'blue' for r in reject_mask],
                alpha=0.5, s=20)
axes[1].axhline(PTP_THRESHOLD_UV, color='r', linestyle='--', linewidth=2)
axes[1].set_xlabel('Epoch index')
axes[1].set_ylabel('Peak-to-peak amplitude (µV)')
axes[1].set_title(f'Epochs to reject: {reject_mask.sum()}/{len(ptp)} ({100*reject_mask.mean():.1f}%)')

plt.tight_layout();

In [None]:
# Create epochs with amplitude-based rejection.
REJECT_THRESHOLD = 100e-6  # 100 µV in Volts

epochs = create_epochs(
    raw_filt, events, event_id,
    tmin=TMIN, tmax=TMAX,
    baseline=BASELINE,
    reject={'eeg': REJECT_THRESHOLD}
)
# Unless the epochs are preloaded, MNE applies `reject` lazily: len() raises and
# .events still lists rejected trials. drop_bad() returns immediately once
# rejection has already been applied.
epochs.drop_bad()
epochs_all.drop_bad()

n_original = len(epochs_all)
n_kept = len(epochs)
n_rejected = n_original - n_kept
pct_rejected = 100 * n_rejected / n_original if n_original else 0.0

print("Artifact rejection results:")
print(f"  Original epochs: {n_original}")
print(f"  Kept epochs: {n_kept}")
print(f"  Rejected: {n_rejected} ({pct_rejected:.1f}%)")
print("\nFinal class distribution:")
for name, code in event_id.items():
    count = np.sum(epochs.events[:, 2] == code)
    print(f"  {name}: {count} trials")

In [None]:
# Plot the average (evoked) response of each class at the central electrodes.
fig, axes = plt.subplots(2, 2, figsize=(14, 8))

# One panel per class on the 2x2 grid; zip stops at the grid size instead of
# raising IndexError if there were ever more than four classes.
for ax, name in zip(axes.flat, event_id):
    evoked = epochs[name].average()
    evoked.plot(picks=['C3', 'C4', 'Cz'], axes=ax, show=False,
                spatial_colors=True, titles='')
    ax.set_title(name.replace('_', ' ').title())
    # Limit the view to the epoch window actually extracted.
    ax.set_xlim([TMIN, TMAX])

fig.suptitle('Average Response at Motor Cortex (C3, C4, Cz)', y=1.02, fontsize=14)
plt.tight_layout();

In [None]:
# Topographic maps at several latencies during the motor imagery period.
fig, axes = plt.subplots(2, 4, figsize=(14, 6))

# Seconds after cue, one column per time point. A dedicated name avoids
# overwriting the `times` array defined for the time-domain plot.
topo_times = [1.0, 2.0, 3.0, 3.5]

# Select classes by name only; event codes come from `event_id`, not literals.
for row, name in enumerate(['left_hand', 'right_hand']):
    evoked = epochs[name].average()
    evoked.plot_topomap(times=topo_times, axes=axes[row], show=False,
                        colorbar=False, time_unit='s')
    axes[row, 0].set_ylabel(name.replace('_', ' ').title(), fontsize=12)

fig.suptitle('Topographic Maps: Left Hand vs Right Hand', y=1.02, fontsize=14)
plt.tight_layout();

In [None]:
# Morlet time-frequency power, left vs right hand imagery, at C3 and C4.
freqs = np.arange(8, 31, 1)  # 8..30 Hz in 1 Hz steps
# n_cycles / freq is a constant 0.5 s, so every frequency gets the same time window.
n_cycles = freqs / 2

fig, axes = plt.subplots(2, 2, figsize=(14, 8))

# Rows are classes, columns are channels; built row-major like nested loops.
panels = [
    (r, c, cls, chan)
    for r, cls in enumerate(('left_hand', 'right_hand'))
    for c, chan in enumerate(('C3', 'C4'))
]
for r, c, cls, chan in panels:
    tfr = mne.time_frequency.tfr_morlet(
        epochs[cls], freqs=freqs, n_cycles=n_cycles,
        picks=chan, return_itc=False, average=True, verbose=False
    )
    panel_title = f'{cls.replace("_", " ").title()} at {chan}'
    tfr.plot([0], axes=axes[r, c], show=False, colorbar=True, title=panel_title)

plt.tight_layout();

In [None]:
# Preprocessing parameters for the batch run (should match config/default.yaml).
# Reuse the constants explored above, so the saved dataset is processed exactly
# as inspected in this notebook (8-30 Hz, 0.5-4.0 s, 100 µV).
PREPROCESS_PARAMS = {
    'l_freq': L_FREQ,
    'h_freq': H_FREQ,
    'tmin': TMIN,
    'tmax': TMAX,
    'reject_threshold': REJECT_THRESHOLD,
}

print("Preprocessing parameters:")
for k, v in PREPROCESS_PARAMS.items():
    print(f"  {k}: {v}")

In [None]:
def process_and_save_subject(subject_id, session='T'):
    """Preprocess one subject/session recording and return epochs + labels.

    Despite the name, nothing is written to disk here; the combined dataset
    is saved separately.

    Args:
        subject_id: Subject number; zero-padded to two digits in the filename
            (e.g. 1 -> 'A01T.gdf').
        session: File suffix selecting the session ('T' or 'E').

    Returns:
        ``(epochs, labels)`` as returned by ``preprocess_subject`` using
        ``PREPROCESS_PARAMS``, or ``(None, None)`` if the GDF file is missing.
    """
    # int() keeps string ids working; :02d keeps ids >= 10 correctly formatted.
    gdf_path = DATA_DIR / f'A{int(subject_id):02d}{session}.gdf'

    if not gdf_path.exists():
        print(f"  Skipping subject {subject_id}: file not found")
        return None, None

    # NOTE(review): for the 'E' session, confirm that preprocess_subject finds
    # class labels in the GDF; some releases ship evaluation labels separately.
    epochs, labels = preprocess_subject(
        gdf_path,
        **PREPROCESS_PARAMS
    )

    return epochs, labels

In [None]:
# Process all subjects, both sessions.
SUBJECTS = list(range(1, 10))  # 1-9

# Output split name -> session suffix in the GDF filename.
SPLIT_SESSIONS = {'train': 'T', 'test': 'E'}

all_data = {
    split: {'X': [], 'y': [], 'subjects': []} for split in SPLIT_SESSIONS
}

subject_stats = []

for subj in tqdm(SUBJECTS, desc="Processing subjects"):
    stats = {'subject': subj}

    for split, session in SPLIT_SESSIONS.items():
        split_epochs, split_labels = process_and_save_subject(subj, session)
        if split_epochs is None:
            # Record 0 for a missing session so per-subject totals stay numeric.
            stats[f'{split}_epochs'] = 0
            continue
        all_data[split]['X'].append(split_epochs.get_data())
        all_data[split]['y'].append(split_labels)
        # Tag every epoch with its subject id for later per-subject analysis.
        all_data[split]['subjects'].append(np.full(len(split_labels), subj))
        stats[f'{split}_epochs'] = len(split_labels)

    subject_stats.append(stats)

print("\nDone")

In [None]:
# Concatenate the per-subject lists into one array per split.
for split in ['train', 'test']:
    parts = all_data[split]
    # Only concatenate non-empty lists: safe to re-run (already arrays) and
    # safe when no subject had this session.
    if isinstance(parts['X'], list) and parts['X']:
        for key in ('X', 'y', 'subjects'):
            parts[key] = np.concatenate(parts[key], axis=0)

# len()/np.shape() work on both arrays and (empty) lists.
print("Dataset summary:")
print(f"  Training: {len(all_data['train']['X'])} epochs")
print(f"  Test: {len(all_data['test']['X'])} epochs")
print(f"  Shape: {np.shape(all_data['train']['X'])} (epochs, channels, times)")

In [None]:
# Per-subject statistics
import pandas as pd

# Missing sessions would otherwise give NaN totals (or KeyError if a whole
# session column is absent), so fix the columns and fill gaps with 0.
df_stats = (
    pd.DataFrame(subject_stats)
    .reindex(columns=['subject', 'train_epochs', 'test_epochs'])
    .fillna(0)
    .astype(int)
)
df_stats['total'] = df_stats['train_epochs'] + df_stats['test_epochs']
print("Epochs per subject:")
print(df_stats.to_string(index=False))
print(f"\nTotal: {df_stats['total'].sum()} epochs")

In [None]:
# Class balance check: bar chart of epochs per class for each split.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

split_titles = [('train', 'Training'), ('test', 'Test')]
for ax, (split, title) in zip(axes, split_titles):
    classes, class_counts = np.unique(all_data[split]['y'], return_counts=True)
    bar_labels = [CLASS_NAMES[cls] for cls in classes]
    ax.bar(bar_labels, class_counts, color='steelblue', edgecolor='black')
    ax.set_ylabel('Number of epochs')
    ax.set_title(f'{title} Set Class Distribution')
    # Categorical bars sit at x = 0, 1, ...; write each count just above its bar.
    for x_pos, n in enumerate(class_counts):
        ax.text(x_pos, n + 5, str(n), ha='center')

plt.tight_layout();

In [None]:
# Save the combined dataset as a single compressed .npz archive.
out_path = PROCESSED_DIR / 'preprocessed_data.npz'
np.savez_compressed(
    out_path,
    X_train=all_data['train']['X'],
    y_train=all_data['train']['y'],
    subjects_train=all_data['train']['subjects'],
    X_test=all_data['test']['X'],
    y_test=all_data['test']['y'],
    subjects_test=all_data['test']['subjects'],
    # Sampling rate taken from the loaded recording instead of hard-coded.
    # NOTE(review): assumes preprocess_subject does not resample; confirm.
    sfreq=float(raw_filt.info['sfreq']),
    ch_names=CHANNEL_NAMES,
    class_names=CLASS_NAMES,
    # Saved as a 0-d object array, so loading requires allow_pickle=True.
    preprocess_params=PREPROCESS_PARAMS
)

print(f"Saved to: {out_path}")
print(f"File size: {out_path.stat().st_size / 1e6:.1f} MB")

In [None]:
# Verify the archive reloads and list its contents.
# allow_pickle is required for object arrays such as preprocess_params; only
# use it on archives you created yourself.
with np.load(PROCESSED_DIR / 'preprocessed_data.npz', allow_pickle=True) as data:
    print("Saved arrays:")
    for key in data.files:
        arr = data[key]
        if arr.ndim == 0:
            # Scalars and pickled objects load as 0-d arrays; show the value,
            # not the uninformative shape ().
            print(f"  {key}: {arr.item()}")
        else:
            print(f"  {key}: {arr.shape}")

**Output:**
- `data/processed/preprocessed_data.npz` containing:
  - Training data: X_train, y_train, subjects_train
  - Test data: X_test, y_test, subjects_test
  - Metadata: sampling frequency, channel names, class names, preprocessing parameters