# Assignment: Training EEGNet on P300 EEG Data

In this assignment, you will work with real EEG data from a P300 speller experiment and implement the EEGNet architecture to detect P300 responses. The emphasis is on understanding and implementing the EEGNet model rather than on extensive signal preprocessing.

**Instructions:**
- Complete the provided code scaffolding
- Fill in missing logic where indicated
- Focus especially on the EEGNet architecture and training


## Part 1: Loading and Inspecting the Dataset

In this section, you will load the EEG dataset and inspect its basic structure. The dataset contains continuous EEG recordings along with stimulus and label information.

In [11]:
import scipy.io as sio
import numpy as np

from google.colab import drive
import os
drive.mount('/content/drive')

DATA_PATH = '/content/bci2004/BCI_Comp_III_Wads_2004/'
if not os.path.exists(DATA_PATH):
  !unzip /content/drive/MyDrive/BCI_Comp_III_Wads_2004.zip -d /content/bci2004

data = sio.loadmat(DATA_PATH + 'Subject_A_Train.mat')

# Inspect available keys
print(data.keys())


dict_keys(['__header__', '__version__', '__globals__', 'Signal', 'TargetChar', 'Flashing', 'StimulusCode', 'StimulusType'])
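
Beyond the key names, the shape and dtype of each array are usually the first things worth checking. The following is a minimal sketch of such an inspection helper; `summarize_mat` and the `toy` dictionary are hypothetical names introduced here for illustration, and the toy arrays are made-up stand-ins for the dictionary returned by `sio.loadmat`.

```python
import numpy as np

def summarize_mat(mat_dict):
    """Return {key: (shape, dtype)} for every non-metadata entry."""
    summary = {}
    for key, value in mat_dict.items():
        # Skip MATLAB metadata such as '__header__', '__version__', '__globals__'
        if key.startswith("__"):
            continue
        arr = np.asarray(value)
        summary[key] = (arr.shape, str(arr.dtype))
    return summary

# Hypothetical stand-in for the loaded dictionary (tiny shapes, made-up values)
toy = {
    "__header__": b"toy",
    "Signal": np.zeros((2, 10, 3), dtype=np.float32),
    "Flashing": np.zeros((2, 10), dtype=np.float32),
}

mat_summary = summarize_mat(toy)
for key, (shape, dtype) in mat_summary.items():
    print(f"{key:10s} shape={shape} dtype={dtype}")
```

On the real data, calling `summarize_mat(data)` should list the five data keys printed above together with their array shapes.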


## Part 2: Understanding the Experimental Design

The P300 speller paradigm is based on detecting brain responses to rare target stimuli. In this section, you will identify how stimulus timing and labels are encoded in the data.

In [12]:
signal = data["Signal"]
flashing = data["Flashing"].squeeze()
stim_type = data["StimulusType"].squeeze()

print(signal.shape)
print(flashing.shape)
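print(stim_type.shape)


def get_flash_onsets(flashing_1d):
    """
    Find rising edges (0 -> 1 transitions) in a flashing indicator.

    flashing_1d: array of shape (time,), 1 while a row/column is intensified
    returns: sample indices at which a flash starts

    Note: a flash that is already active at index 0 has no preceding 0
    and is therefore not reported.
    """
    return np.where(
        (flashing_1d[1:] == 1) & (flashing_1d[:-1] == 0)
    )[0] + 1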

labels = []

for run in range(flashing.shape[0]):
    onsets = get_flash_onsets(flashing[run])

    for onset in onsets:
        labels.append(stim_type[run, onset])

labels = np.array(labels)

print("Total labels:", labels.shape[0])
print("Targets:", np.sum(labels == 1))
print("Non-targets:", np.sum(labels == 0))

(85, 7794, 64)
(85, 7794)
(85, 7794)
Total labels: 15215
Targets: 2537
Non-targets: 12678
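
The total of 15215 labels is worth a second look. Assuming each of the 85 runs contains 180 intensifications (12 rows and columns, 15 repetitions each), one would expect 85 × 180 = 15300 onsets, which is exactly 85 more than were found. One possible explanation, not verified here, is that `Flashing` is already 1 at sample 0 of every run, so the first flash has no preceding 0 and the rising-edge test misses it. The sketch below shows the effect on a toy array; `get_flash_onsets_with_first` is a hypothetical variant, not part of the assignment scaffold.

```python
import numpy as np

def get_flash_onsets_with_first(flashing_1d):
    """Rising-edge detection that also reports a flash active at sample 0."""
    flashing_1d = np.asarray(flashing_1d)
    # Prepend a 0 so a flash starting at index 0 still produces a 0 -> 1 step
    padded = np.concatenate(([0], flashing_1d))
    return np.where((padded[1:] == 1) & (padded[:-1] == 0))[0]

# Toy indicator whose first flash starts at index 0 (made-up values)
toy_flashing = np.array([1, 1, 0, 0, 1, 1, 0, 1, 0])

# Same logic as get_flash_onsets in the cell above
original_style = np.where(
    (toy_flashing[1:] == 1) & (toy_flashing[:-1] == 0)
)[0] + 1
padded_style = get_flash_onsets_with_first(toy_flashing)

print(original_style)  # [4 7]
print(padded_style)    # [0 4 7]
```

Checking `flashing[:, 0]` on the real data would confirm or rule out this explanation before changing the extraction code.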


## Part 3: EEG Epoch Extraction

EEGNet does not operate on continuous EEG. Instead, the signal must be segmented into short, fixed-length epochs time-locked to each stimulus onset. This step converts raw EEG into trials suitable for supervised learning.

In [13]:
import numpy as np

def extract_epochs(signal, stimulus_onsets, labels, fs, t_start=0.0, t_end=0.8):
    """
    Extract EEG epochs around each stimulus onset.

    Parameters:
    - signal: array of shape (time, channels)
    - stimulus_onsets: array of stimulus onset indices (in samples)
    - labels: array of labels, one per stimulus onset
    - fs: sampling frequency in Hz
    - t_start: start time (seconds) relative to stimulus
    - t_end: end time (seconds) relative to stimulus

    Returns:
    - epochs: array of shape (num_trials, channels, time)
    - y: corresponding labels
    """

    # Convert time window (seconds) to samples
    start_offset = int(t_start * fs)
    end_offset   = int(t_end * fs)

    epochs = []
    y = []

    for onset, label in zip(stimulus_onsets, labels):
        start = onset + start_offset
        end   = onset + end_offset

        # Make sure we do not go out of bounds
        if start >= 0 and end <= signal.shape[0]:
            # Extract epoch and transpose to (channels, time)
            epoch = signal[start:end, :].T
            epochs.append(epoch)
            y.append(label)

    return np.array(epochs), np.array(y)


fs = 240  # sampling frequency in Hz
all_epochs = []
all_labels = []

for run in range(signal.shape[0]):

    run_signal = signal[run]        # (time, channels)
    run_flashing = flashing[run]    # (time,)
    run_stimtype = stim_type[run]   # (time,)

    # 1. Get stimulus onsets for this run
    stimulus_onsets = get_flash_onsets(run_flashing)

    # 2. Get labels for those onsets
    labels = run_stimtype[stimulus_onsets]

    # 3. Extract epochs for this run
    epochs, y = extract_epochs(
        signal=run_signal,
        stimulus_onsets=stimulus_onsets,
        labels=labels,
        fs=fs,
        t_start=0.2,
        t_end=0.6
    )

    all_epochs.append(epochs)
    all_labels.append(y)

# Combine all runs
epochs = np.concatenate(all_epochs, axis=0)
y = np.concatenate(all_labels, axis=0)

print("Epochs shape:", epochs.shape)
print("Labels shape:", y.shape)


Epochs shape: (15215, 64, 96)
Labels shape: (15215,)
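
The 96 samples per epoch follow directly from the window passed to `extract_epochs`: 0.2 s to 0.6 s after onset at 240 Hz. The sketch below reproduces that arithmetic for a single onset. `epoch_window` is a hypothetical helper, the onset value 1000 is arbitrary, and the use of `round` before `int` is an assumption that differs slightly from the cell above, chosen to guard against floating-point truncation for other window values.

```python
def epoch_window(onset, fs, t_start, t_end):
    """Return (start, end, n_samples) of the half-open sample window [start, end)."""
    start = onset + int(round(t_start * fs))
    end = onset + int(round(t_end * fs))
    return start, end, end - start

start, end, n_samples = epoch_window(onset=1000, fs=240, t_start=0.2, t_end=0.6)
print(start, end, n_samples)  # 1048 1144 96
```

Note that the call in the cell above overrides the function defaults (`t_start=0.0`, `t_end=0.8`), which would instead give 192 samples per epoch.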


## Part 4: Preparing Data for EEGNet

In this section, you will perform minimal preprocessing to make the data compatible with EEGNet. Extensive signal processing is not required.

In [14]:
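def prepare_for_eegnet(epochs):
    """
    Add a singleton axis so each epoch matches the model input.

    epochs: array of shape (num_trials, channels, time)
    returns: array of shape (num_trials, 1, channels, time)
    """
    return np.expand_dims(epochs, axis=1)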
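
As a quick sanity check, the effect of the added axis can be seen on a small dummy array. This is only an illustration: `toy_epochs` and `X_toy` are made-up names and the array contains zeros rather than EEG.

```python
import numpy as np

# Dummy batch: 5 trials, 64 channels, 96 time samples
toy_epochs = np.zeros((5, 64, 96), dtype=np.float32)

# Insert a singleton axis at position 1, as prepare_for_eegnet does
X_toy = np.expand_dims(toy_epochs, axis=1)

print(X_toy.shape)  # (5, 1, 64, 96)
```

The resulting per-trial shape `(1, 64, 96)` is what a `channels_first` model with `Input(shape=(1, Chans, Samples))` expects.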


## Part 5: Implementing EEGNet

This is the core part of the assignment. You will implement the EEGNet architecture as discussed in class. Focus on matching the block structure and understanding the role of each layer.

In [15]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D,
    BatchNormalization, AveragePooling2D,
    Dropout, Dense, Activation
)
from tensorflow.keras.constraints import max_norm

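def EEGNet(nb_classes, Chans, Samples, F1=8, D=2, F2=16, dropoutRate=0.5):
    """
    Build an EEGNet-style model for inputs of shape (1, Chans, Samples).

    - nb_classes: number of output classes
    - Chans, Samples: EEG channels and time samples per epoch
    - F1: number of temporal filters
    - D: depth multiplier (spatial filters per temporal filter)
    - F2: number of pointwise filters
    - dropoutRate: dropout fraction applied after each block
    """
    inputs = Input(shape=(1, Chans, Samples))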

    # Block 1: Temporal Convolution
    x = Conv2D(
        F1,
        (1, 64),
        padding='same',
        use_bias=False,
        data_format='channels_first'
    )(inputs)
    x = BatchNormalization(axis=1)(x)

    # Block 1: Spatial Convolution (across EEG channels)
    x = DepthwiseConv2D(
        (Chans, 1),
        depth_multiplier=D,
        use_bias=False,
        depthwise_constraint=max_norm(1.),
        data_format='channels_first'
    )(x)
    x = BatchNormalization(axis=1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4), data_format='channels_first')(x)
    x = Dropout(dropoutRate)(x)

    # Block 2: Separable Convolution
    # Depthwise temporal convolution
    x = DepthwiseConv2D(
        (1, 16),
        padding='same',
        use_bias=False,
        data_format='channels_first'
    )(x)

    # Pointwise convolution
    x = Conv2D(
        F2,
        (1, 1),
        use_bias=False,
        data_format='channels_first'
    )(x)

    x = BatchNormalization(axis=1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 8), data_format='channels_first')(x)
    x = Dropout(dropoutRate)(x)

    # Classification
    x = tf.keras.layers.GlobalAveragePooling2D(
        data_format='channels_first'
    )(x)
    outputs = Dense(nb_classes, activation='softmax')(x)

    return Model(inputs, outputs)

model = EEGNet(
    nb_classes=2,
    Chans=64,
    Samples=96
)

model.summary()
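
To read the summary more easily, it can help to trace the tensor shapes by hand. The sketch below does this in plain Python, without TensorFlow, for the layer sequence defined above (batch axis omitted, `channels_first` layout). `eegnet_shape_trace` is a hypothetical helper, and it assumes the Keras defaults used in the cell: pooling with `'valid'` padding and stride equal to the pool size, and a depth multiplier of 1 in the second depthwise convolution.

```python
def eegnet_shape_trace(chans, samples, F1=8, D=2, F2=16, pool1=4, pool2=8):
    """Return a list of (stage, shape) pairs for the model defined above."""
    trace = [("input", (1, chans, samples))]
    # 'same' padding keeps the time axis; F1 temporal filters become feature maps
    trace.append(("temporal conv", (F1, chans, samples)))
    # Kernel (chans, 1) with 'valid' padding collapses the channel axis to 1
    trace.append(("depthwise spatial conv", (F1 * D, 1, samples)))
    t1 = samples // pool1  # non-overlapping average pooling
    trace.append(("average pool (1, 4)", (F1 * D, 1, t1)))
    trace.append(("depthwise temporal conv", (F1 * D, 1, t1)))
    trace.append(("pointwise conv", (F2, 1, t1)))
    t2 = t1 // pool2
    trace.append(("average pool (1, 8)", (F2, 1, t2)))
    trace.append(("global average pool", (F2,)))
    return trace

trace = eegnet_shape_trace(chans=64, samples=96)
for stage, shape in trace:
    print(f"{stage:26s} {shape}")
```

For `Chans=64` and `Samples=96`, the time axis shrinks from 96 to 24 to 3 before global average pooling leaves 16 features for the final `Dense` layer.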

## Part 6: Training the Model

In this section, you will train EEGNet to distinguish between P300 and non-P300 EEG epochs.
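
Keep the class imbalance from Part 2 in mind when reading the accuracy: with 12678 non-targets out of 15215 epochs, a model that always predicts "non-target" already reaches about 83% accuracy. One common way to counter this is to weight the classes inversely to their frequency. The sketch below computes such weights with NumPy; `balanced_class_weights` is a hypothetical helper, and the label vector is rebuilt from the Part 2 counts rather than taken from the real data.

```python
import numpy as np

def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * count), so rarer classes weigh more."""
    y = np.asarray(y).astype(int)
    classes, counts = np.unique(y, return_counts=True)
    weights = y.size / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}

# Label vector rebuilt from the counts printed in Part 2
y_demo = np.concatenate([np.zeros(12678), np.ones(2537)])

class_weight = balanced_class_weights(y_demo)
print(class_weight)

# Accuracy of a classifier that always predicts the majority class
majority_baseline = 12678 / 15215
print(round(majority_baseline, 4))  # 0.8333
```

The resulting dictionary could be passed as `class_weight=` to `model.fit`. Whether it improves target detection here is untested, and it is worth checking that your Keras version accepts class weights together with one-hot targets; if not, integer labels with a sparse categorical loss are an alternative.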

In [None]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical

X = prepare_for_eegnet(epochs)
y_cat = to_categorical(y, 2)
X_train, X_val, y_train, y_val = train_test_split(
    X, y_cat,
    test_size=0.2,
    random_state=42,
    stratify=y
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=4,
    verbose=1
)


Epoch 1/30
3043/3043 ━━━━━━━━━━━━━━━━━━━━ 22s 5ms/step - accuracy: 0.8122 - loss: 0.4847 - val_accuracy: 0.8334 - val_loss: 0.4445
Epoch 2/30
3043/3043 ━━━━━━━━━━━━━━━━━━━━ 11s 4ms/step - accuracy: 0.8302 - loss: 0.4519 - val_accuracy: 0.8334 - val_loss: 0.4430
Epoch 3/30
3043/3043 ━━━━━━━━━━━━━━━━━━━━ 11s 4ms/step - accuracy: 0.8347 - loss: 0.4395 - val_accuracy: 0.8334 - val_loss: 0.4404
Epoch 4/30
3043/3043 ━━━━━━━━━━━━━━━━━━━━ 12s 4ms/step - accuracy: 0.8327 - loss: 0.4420 - val_accuracy: 0.8331 - val_loss: 0.4360
Epoch 5/30
3043/3043 ━━━━━━━━━━━━━━━━━━━━ 11s 4ms/step - accuracy: 0.8330 - loss: 0.4346 - val_accuracy: 0.8301 - val_loss: 0.4373
Epoch 6/30
3043/3043 ━━━━━━━━━━━━━━━━━━━━ 11s 4ms/step - accuracy: 0.8273 - loss: 0.4421 - val_accuracy: 0.8311 - val_loss: 0.4320
Epoch 7/30