# 🧠 EEG Motor Imagery Analysis

This notebook explores the PhysioNet EEG Motor Movement/Imagery dataset (subject 1, channels C3, Cz and C4) and visualizes the EEG signals in the time and frequency domains.

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy import signal

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / "src"))

from neuralflight.eeg.dataset import PhysioNetDataset, preprocess_eeg
from neuralflight.utils.config_loader import load_config

plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")

%matplotlib inline

## 1. Load Data

In [None]:
# Load configuration
# NOTE(review): `config` is not referenced by any later cell; the sampling rate and
# band edges below are hard-coded instead — consider reading them from here.
config = load_config("eeg_config")

# Initialize dataset (path is relative to the notebook's working directory)
dataset = PhysioNetDataset("../data/raw/physionet")

# Load a single subject, restricted to the sensorimotor channels used throughout
subject_id = 1
channels = ["C3", "Cz", "C4"]
X, y = dataset.load_subject(subject_id, channels=channels)

# Later cells index X as (epoch, channel, sample) and pair epochs with y by position;
# fail here with a clear message rather than deep inside a plotting cell.
assert X.ndim == 3, f"expected (epochs, channels, samples), got shape {X.shape}"
assert X.shape[1] == len(channels), f"expected {len(channels)} channels, got {X.shape[1]}"
assert len(X) == len(y), f"{len(X)} epochs but {len(y)} labels"

print(f"Data shape: {X.shape}")
print(f"Labels: {np.unique(y, return_counts=True)}")

## 2. Visualize Raw EEG

In [None]:
# Plot a single epoch: one panel per channel on a shared time axis
epoch_idx = 10
epoch = X[epoch_idx]
label = y[epoch_idx]

class_names = {0: "Rest", 1: "Left Hand", 2: "Right Hand"}
channel_names = ["C3", "Cz", "C4"]

# Build the time axis from the sampling rate instead of assuming every epoch spans
# exactly 3 s, so the sample spacing is 1/fs whatever the epoch length.
# Assumes epochs start at t = 0 — TODO confirm against load_subject's epoch window.
fs = 160  # Hz; same value used by the PSD and filtering cells below
time = np.arange(epoch.shape[-1]) / fs

fig, axes = plt.subplots(len(channel_names), 1, figsize=(12, 8), sharex=True)

for i, (ax, ch_name) in enumerate(zip(axes, channel_names)):
    ax.plot(time, epoch[i])
    ax.set_ylabel(f"{ch_name} (µV)")
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel("Time (s)")
# int() so the integer-keyed lookup also works for numpy/tensor scalar labels
fig.suptitle(
    f"EEG Epoch - Class: {class_names.get(int(label), 'Unknown')}", fontsize=14
)
plt.tight_layout()
plt.show()

## 3. Frequency Analysis

In [None]:
# Welch power spectral density for each channel of the epoch selected above,
# with the motor-imagery frequency bands shaded for reference.
fs = 160  # Sampling rate

# (lower edge Hz, upper edge Hz, shading colour, legend entry)
highlighted_bands = (
    (8, 13, "blue", "Alpha (8-13 Hz)"),
    (13, 30, "red", "Beta (13-30 Hz)"),
)

fig, axes = plt.subplots(3, 1, figsize=(12, 8))

for row, ax in enumerate(axes):
    ch_name = channel_names[row]
    freqs, psd = signal.welch(epoch[row], fs, nperseg=256)

    ax.semilogy(freqs, psd)
    ax.set_ylabel(f"{ch_name} PSD")
    ax.set_xlim(0, 50)
    ax.grid(True, alpha=0.3)

    for lo_hz, hi_hz, shade, legend_text in highlighted_bands:
        ax.axvspan(lo_hz, hi_hz, alpha=0.2, color=shade, label=legend_text)

axes[-1].set_xlabel("Frequency (Hz)")
# Every panel carries the same band labels, so one legend on the top panel suffices.
axes[0].legend()
fig.suptitle("Power Spectral Density", fontsize=14)
plt.tight_layout()
plt.show()

## 4. Compare Classes

In [None]:
# Average signals per class (trial-averaged time course of each channel)
# NOTE(review): plain averaging mainly keeps activity that is phase-locked to the cue;
# motor-imagery effects usually appear as band-power changes instead (see the
# feature-visualization section), so small differences here are not surprising.
left_hand_epochs = X[y == 1]
right_hand_epochs = X[y == 2]

# .mean(axis=0) over zero epochs returns all-NaN (plus a RuntimeWarning) and the
# figure would silently lack that class, so stop with a clear message instead.
assert len(left_hand_epochs) > 0, "no Left Hand (label 1) epochs for this subject"
assert len(right_hand_epochs) > 0, "no Right Hand (label 2) epochs for this subject"

left_avg = left_hand_epochs.mean(axis=0)
right_avg = right_hand_epochs.mean(axis=0)

fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

for i, (ax, ch_name) in enumerate(zip(axes, channel_names)):
    ax.plot(time, left_avg[i], label="Left Hand", linewidth=2)
    ax.plot(time, right_avg[i], label="Right Hand", linewidth=2)
    ax.set_ylabel(f"{ch_name} (µV)")
    ax.legend()
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel("Time (s)")
fig.suptitle(
    "Average EEG Signals: Left vs Right Hand Imagery", fontsize=14
)
plt.tight_layout()
plt.show()

## 5. Apply Preprocessing

In [None]:
# Band-pass filter all epochs to 8-30 Hz (the alpha + beta range shaded in the PSD plot)
lowcut, highcut = 8.0, 30.0  # Hz
X_filtered = preprocess_eeg(X, lowcut=lowcut, highcut=highcut, fs=160.0)

# Compare raw vs filtered for one epoch and channel
epoch_idx = 10
channel_idx = 1  # Cz
ch_name = channel_names[channel_idx]  # keeps the title in sync if channel_idx changes

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

ax1.plot(time, X[epoch_idx, channel_idx])
ax1.set_title(f"Raw Signal ({ch_name})")
ax1.set_ylabel("Amplitude (µV)")
ax1.grid(True, alpha=0.3)

ax2.plot(time, X_filtered[epoch_idx, channel_idx])
# :g renders 8.0/30.0 as "8"/"30", matching the original "8-30 Hz" title
ax2.set_title(f"Filtered Signal ({lowcut:g}-{highcut:g} Hz)")
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("Amplitude (µV)")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Feature Visualization

In [None]:
# Extract simple features: integrated band power
def band_power(data, fs, band, nperseg=256):
    """Integrated (absolute) power of ``data`` within a frequency band.

    The power spectral density is estimated with Welch's method and the in-band
    part is integrated with the trapezoidal rule, so the result is the total power
    in the band (input units squared, e.g. µV² for µV input), not a per-Hz average.

    Args:
        data: Signal samples. Time runs along the last axis; any leading axes
            (e.g. channels or epochs) are processed independently.
        fs: Sampling rate in Hz.
        band: ``(low, high)`` band edges in Hz, both inclusive.
        nperseg: Welch segment length. Clipped to the signal length, which is
            what SciPy would do anyway, minus its "nperseg is greater than
            input length" warning.

    Returns:
        Band power: a scalar for 1-D input, otherwise an array with the leading
        shape of ``data``. 0.0 when fewer than two frequency bins fall in ``band``.
    """
    data = np.asarray(data)
    low, high = band
    freqs, psd = signal.welch(data, fs, nperseg=min(nperseg, data.shape[-1]))
    in_band = (freqs >= low) & (freqs <= high)
    # np.trapz was deprecated in NumPy 2.0 in favour of np.trapezoid
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    return trapezoid(psd[..., in_band], freqs[in_band], axis=-1)


# Alpha and beta band power per epoch (first epochs of the filtered data, one channel),
# plotted as a 2-D feature space to see whether the two imagery classes separate.
n_feature_epochs = 100
feature_channel = "Cz"
ch_idx = channel_names.index(feature_channel)
feature_epochs = X_filtered[:n_feature_epochs]

# Use a distinct loop name: iterating with `epoch` would overwrite the global epoch
# the PSD cell reads, so re-running that cell afterwards would analyse a different
# (filtered) epoch without any visible sign.
alpha_power = np.array([band_power(ep[ch_idx], fs, [8, 13]) for ep in feature_epochs])
beta_power = np.array([band_power(ep[ch_idx], fs, [13, 30]) for ep in feature_epochs])
labels = y[:n_feature_epochs]

# Scatter plot, one colour per imagery class
fig, ax = plt.subplots(figsize=(10, 6))
for cls in (1, 2):
    mask = labels == cls
    ax.scatter(
        alpha_power[mask],
        beta_power[mask],
        label=class_names[cls],
        alpha=0.6,
        s=50,
    )

ax.set_xlabel("Alpha Power (8-13 Hz)")
ax.set_ylabel("Beta Power (13-30 Hz)")
ax.set_title("Feature Space: Motor Imagery Classification")
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## Summary

This notebook demonstrated:
1. Loading EEG data from PhysioNet dataset
2. Visualizing raw EEG signals
3. Frequency analysis (PSD)
4. Comparing different motor imagery classes
5. Preprocessing with an 8–30 Hz band-pass filter
6. Simple feature extraction (alpha and beta band power at Cz)

Next steps:
- Run `demos/train_model.py` to train the EEGNet classifier
- Try `demos/motor_imagery_demo.py` to see it in action!