# Assignment: Training EEGNet on P300 EEG Data

In this assignment, you will work with real EEG data from a P300 speller experiment and implement the EEGNet architecture to detect P300 responses. The emphasis of this assignment is on understanding and implementing the EEGNet model rather than extensive signal preprocessing.

**Instructions:**
- Complete the provided code scaffolding
- Fill in missing logic where indicated
- Focus especially on the EEGNet architecture and training


In [80]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Part 1: Loading and Inspecting the Dataset

In this section, you will load the EEG dataset and inspect its basic structure. The dataset contains continuous EEG recordings along with stimulus and label information.

In [147]:
import scipy.io as sio
import numpy as np

# Load the dataset
# TODO: Update the path if needed
data = sio.loadmat('/content/drive/MyDrive/EEG dataset/Subject_A_Train.mat')

# Inspect available keys
print(data.keys())


dict_keys(['__header__', '__version__', '__globals__', 'Signal', 'TargetChar', 'Flashing', 'StimulusCode', 'StimulusType'])


## Part 2: Understanding the Experimental Design

The P300 speller paradigm is based on detecting brain responses to rare target stimuli. In this section, you will identify how stimulus timing and labels are encoded in the data.

In [148]:
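# 1. Continuous EEG signal: (character epochs, samples, channels)
eeg_signal = data['Signal']

# 2. Stimulus timing: Flashing is 1 while a row/column is intensified, 0 otherwise;
#    StimulusCode says which one (1-6 = columns, 7-12 = rows, 0 = no flash)
flashing = data['Flashing']
stimulus_code = data['StimulusCode']

# 3. Target vs non-target labels: StimulusType is 1 while the intensified
#    row/column contains the target character, 0 otherwise
stimulus_type = data['StimulusType']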

# Print shapes to verify consistency
print("EEG signal shape:", eeg_signal.shape)
print("Flashing shape:", flashing.shape)
print("Stimulus code shape:", stimulus_code.shape)
print("Stimulus type shape:", stimulus_type.shape)


EEG signal shape: (85, 7794, 64)
Flashing shape: (85, 7794)
Stimulus code shape: (85, 7794)
Stimulus type shape: (85, 7794)
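
`Flashing` stays at 1 for the whole duration of each intensification. A flash onset is therefore the sample where the signal rises from 0 to 1, and `StimulusCode` and `StimulusType` can be read at that sample. The sketch below shows this on a tiny synthetic trace. The helper `find_flash_onsets` and all values are hypothetical and do not come from the dataset. Prepending a 0 before differencing also catches a flash that starts at the very first sample, which the `np.diff(...) + 1` form would miss.

```python
import numpy as np

def find_flash_onsets(flashing):
    """Return the sample indices where the Flashing indicator rises from 0 to 1.

    Prepending a 0 means a flash starting at sample 0 is still detected.
    """
    flashing = np.asarray(flashing).astype(int)
    return np.flatnonzero(np.diff(flashing, prepend=0) == 1)

# Hypothetical single-character trace: three 3-sample flashes separated by blanks
flashing = np.array([1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0])
stimulus_code = np.array([4, 4, 4, 0, 0, 9, 9, 9, 0, 0, 2, 2, 2, 0])
stimulus_type = np.array([0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0])

onsets = find_flash_onsets(flashing)
print("onsets:", onsets)                  # onsets: [ 0  5 10]
print("codes:", stimulus_code[onsets])    # codes: [4 9 2]
print("labels:", stimulus_type[onsets])   # labels: [0 1 0]
```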


## Part 3: EEG Epoch Extraction

EEGNet does not operate on continuous EEG. Instead, the signal must be segmented into short epochs following each stimulus. This step converts raw EEG into trials suitable for supervised learning.

In [149]:
def extract_epochs(signal, stimulus_onsets, labels, fs, t_start=0.0, t_end=0.8):
    # Convert the window [t_start, t_end) in seconds to sample offsets;
    # round() guards against float truncation (e.g. int(0.29 * 100) == 28)
    start_sample = int(round(t_start * fs))
    end_sample = int(round(t_end * fs))

    epochs = []
    y = []

    for idx, onset in enumerate(stimulus_onsets):
        start_idx = onset + start_sample
        end_idx = onset + end_sample

        # Keep only windows that lie entirely inside the recording
        if start_idx >= 0 and end_idx <= signal.shape[0]:
            epoch = signal[start_idx:end_idx, :].T  # (channels, time)
            epochs.append(epoch)
            y.append(labels[idx])

    epochs = np.array(epochs)  # (num_trials, channels, time)
    y = np.array(y)

    return epochs, y
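
Part 6 epochs only the first character epoch (`eeg_signal_full[0]`), which contains 180 flashes. The sketch below shows one way to epoch every character epoch. It assumes the `(characters, samples, channels)` layout printed in Part 2. `extract_all_characters` is a hypothetical helper that re-implements the same windowing so the block runs on its own, and the demo uses synthetic data.

```python
import numpy as np

def extract_all_characters(signal, flashing, stimulus_type, fs,
                           t_start=0.0, t_end=0.8):
    """Epoch every character epoch of a (chars, samples, channels) recording.

    Returns epochs (trials, channels, time), labels (trials,) and groups
    (trials,), where groups holds the character-epoch index of each trial.
    """
    start = int(round(t_start * fs))
    stop = int(round(t_end * fs))
    X, y, groups = [], [], []
    for char_idx in range(signal.shape[0]):
        # Flash onsets = rising edges of this character's Flashing trace
        f = flashing[char_idx].astype(int)
        onsets = np.flatnonzero(np.diff(f, prepend=0) == 1)
        for onset in onsets:
            lo, hi = onset + start, onset + stop
            if lo >= 0 and hi <= signal.shape[1]:  # drop truncated windows
                X.append(signal[char_idx, lo:hi, :].T)  # (channels, time)
                y.append(int(stimulus_type[char_idx, onset]))
                groups.append(char_idx)
    return np.array(X), np.array(y), np.array(groups)

# Synthetic demo: 3 character epochs, 100 samples, 4 channels, fs = 100 Hz
rng = np.random.default_rng(0)
signal = rng.standard_normal((3, 100, 4))
flashing = np.zeros((3, 100))
stimulus_type = np.zeros((3, 100))
for onset in (10, 40, 90):           # the flash at 90 has no full 0.2 s window
    flashing[:, onset:onset + 5] = 1
stimulus_type[1, 40:45] = 1          # one target flash in character epoch 1

X, y, groups = extract_all_characters(signal, flashing, stimulus_type,
                                      fs=100, t_start=0.0, t_end=0.2)
print(X.shape)   # (6, 4, 20)
print(y)         # [0 0 0 1 0 0]
print(groups)    # [0 0 1 1 2 2]
```

The `groups` array records which character epoch each trial came from. Successive flashes are much closer together than the 0.8 s window, so neighboring epochs overlap in time. A random per-trial split can therefore place overlapping data on both sides of the split. Splitting with scikit-learn's `GroupShuffleSplit` on `groups` keeps all epochs from one character on the same side of the train/validation boundary.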


## Part 4: Preparing Data for EEGNet

In this section, you will perform minimal preprocessing to make the data compatible with EEGNet. Extensive signal processing is not required.

In [160]:
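def prepare_for_eegnet(epochs):
    # Input shape: (num_trials, channels, time)
    # Output shape: (num_trials, channels, time, 1); the trailing singleton axis
    # is the "depth" dimension expected by channels_last Conv2D layers
    return np.expand_dims(epochs, axis=-1)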
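
The assignment deliberately keeps preprocessing minimal, so the following step is optional. Networks often train more stably when every input channel is on a comparable scale. One simple way to achieve that is to z-score each channel of each epoch over time. This is a sketch under that assumption, not a step the assignment prescribes. `standardize_epochs` is a hypothetical helper, demonstrated on synthetic data.

```python
import numpy as np

def standardize_epochs(X, eps=1e-8):
    """Z-score each channel of each epoch over its time axis.

    X has shape (num_trials, channels, time, 1), as produced by
    prepare_for_eegnet. Statistics are computed per epoch, so nothing
    is shared between training and validation trials.
    """
    mean = X.mean(axis=2, keepdims=True)  # (num_trials, channels, 1, 1)
    std = X.std(axis=2, keepdims=True)
    return (X - mean) / (std + eps)

# Synthetic example: 5 trials, 3 channels, 50 samples, large DC offset
rng = np.random.default_rng(0)
X = 100.0 + 20.0 * rng.standard_normal((5, 3, 50, 1))
X_std = standardize_epochs(X)
print(X_std.shape)                                      # (5, 3, 50, 1)
print(np.allclose(X_std.mean(axis=2), 0.0, atol=1e-6))  # True
```

The statistics come from each epoch itself, so the transform can be applied before `train_test_split` without leaking validation information. Per-epoch scaling also removes absolute amplitude differences between epochs, which may or may not help P300 detection.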


## Part 5: Implementing EEGNet

This is the core part of the assignment. You will implement the EEGNet architecture as discussed in class. Focus on matching the block structure and understanding the role of each layer.

In [159]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, DepthwiseConv2D,
                                     SeparableConv2D, BatchNormalization,
                                     Activation, AveragePooling2D, Dropout,
                                     Flatten, Dense)
from tensorflow.keras.constraints import MaxNorm

def EEGNet(nb_classes, Chans, Samples, F1=8, D=2, F2=16, dropoutRate=0.5,
           kernLength=64, norm_rate=0.25):

    # Input in channels_last (NHWC) layout: (Chans, Samples, 1)
    inputs = Input(shape=(Chans, Samples, 1))

    # Block 1: F1 temporal filters of length kernLength samples
    x = Conv2D(F1, (1, kernLength), padding='same', use_bias=False)(inputs)
    x = BatchNormalization()(x)
    # Depthwise spatial filtering: D spatial filters per temporal filter, each
    # spanning all channels, with a max-norm constraint as in reference EEGNet
    x = DepthwiseConv2D((Chans, 1), use_bias=False, depth_multiplier=D,
                        depthwise_constraint=MaxNorm(1.0),
                        padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4))(x)
    x = Dropout(dropoutRate)(x)

    # Block 2: separable convolution (depthwise temporal + pointwise mixing)
    x = SeparableConv2D(F2, (1, 16), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 8))(x)
    x = Dropout(dropoutRate)(x)

    # Classification: flatten, then a max-norm-constrained softmax layer
    x = Flatten()(x)
    outputs = Dense(nb_classes, activation='softmax',
                    kernel_constraint=MaxNorm(norm_rate))(x)

    return Model(inputs=inputs, outputs=outputs)
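
It can help to trace the tensor shapes through the two blocks by hand. The pure-Python sketch below mirrors the layer sequence in the cell above. The helper `eegnet_shapes` is hypothetical. Three rules determine the shapes:

- The `'same'`-padded convolutions keep the time axis.
- The `(Chans, 1)` depthwise convolution with `'valid'` padding collapses the electrode axis to 1.
- Each average-pooling layer floor-divides the time axis by its pool size.

Comparing the printed shapes with `model.summary()` is a quick consistency check.

```python
def eegnet_shapes(Chans, Samples, F1=8, D=2, F2=16, pool1=4, pool2=8):
    """Trace (height, width, depth) through the EEGNet blocks (hypothetical helper)."""
    shapes = [("input", (Chans, Samples, 1)),
              ("temporal Conv2D", (Chans, Samples, F1)),     # 'same' padding
              ("DepthwiseConv2D", (1, Samples, F1 * D))]     # (Chans, 1) kernel
    t = Samples // pool1                                     # 'valid' pooling floors
    shapes.append((f"AveragePooling2D (1, {pool1})", (1, t, F1 * D)))
    shapes.append(("SeparableConv2D", (1, t, F2)))           # 'same' padding
    t //= pool2
    shapes.append((f"AveragePooling2D (1, {pool2})", (1, t, F2)))
    shapes.append(("Flatten", (t * F2,)))
    return shapes

for name, shape in eegnet_shapes(Chans=64, Samples=192):
    print(f"{name:<26} {shape}")
```

With 64 channels and 192 samples (0.8 s at 240 Hz), the flattened feature vector has 6 × 16 = 96 entries. With 200 samples, 200 // 4 = 50 and 50 // 8 = 6, so the feature vector is again 96 entries. The time axis needs at least 4 × 8 = 32 samples, or the second pooling layer has nothing to pool.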

## Part 6: Training the Model

In this section, you will train EEGNet to distinguish between P300 and non-P300 EEG epochs.
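
P300 targets are rare. In each repetition, only the row and the column containing the target character (2 of the 12 flashes) are labeled 1. A classifier that always predicts "non-target" therefore reaches 150/180 ≈ 0.833 accuracy on the first character epoch. That is also the validation accuracy reached at epoch 6 in the log below, so accuracy alone says little here. One common mitigation is to weight the loss by inverse class frequency. The sketch below assumes that approach. `balanced_class_weights` is a hypothetical helper implementing the same "balanced" formula as scikit-learn's `compute_class_weight`.

```python
import numpy as np

def balanced_class_weights(y):
    """Weights n_samples / (n_classes * count_c), the 'balanced' heuristic."""
    y = np.asarray(y).astype(int)
    counts = np.bincount(y)
    return {c: float(len(y) / (len(counts) * n)) for c, n in enumerate(counts)}

# One character epoch: 12 flashes x 15 repetitions, 2 of every 12 are targets
y = np.array([1] * 30 + [0] * 150)
print("majority-class accuracy:", np.mean(y == 0))   # about 0.833
print("class weights:", balanced_class_weights(y))   # {0: 0.6, 1: 3.0}
```

You can pass the dictionary directly as `model.fit(..., class_weight=balanced_class_weights(y_labels_extracted))`. Report `sklearn.metrics.balanced_accuracy_score` alongside plain accuracy to see whether the model does better than the majority baseline.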

In [161]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import numpy as np
import tensorflow as tf


tf.keras.backend.set_image_data_format('channels_last')

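# BCI Competition III dataset II (Subject_A_Train.mat) is sampled at 240 Hz
fs = 240
eeg_signal_full = eeg_signal
flashing_full = flashing
stimulus_type_full = stimulus_type

# Use only the first character epoch (12 rows/columns x 15 repetitions = 180 flashes)
eeg_data_single_session = eeg_signal_full[0]
flashing_single_session = flashing_full[0]
stimulus_type_single_session = stimulus_type_full[0]

# Flash onsets are the rising edges (0 -> 1) of the Flashing indicator
onset_indices = np.where(np.diff(flashing_single_session) == 1)[0] + 1

# Target (1) / non-target (0) label of each flash, read at its onset sample
labels_at_onsets = stimulus_type_single_session[onset_indices]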

t_start = 0.0
t_end = 0.8
epochs_raw, y_labels_extracted = extract_epochs(
    signal=eeg_data_single_session,
    stimulus_onsets=onset_indices,
    labels=labels_at_onsets,
    fs=fs,
    t_start=t_start,
    t_end=t_end
)

X = prepare_for_eegnet(epochs_raw)

nb_trials, Chans, Samples, _ = X.shape

nb_classes = 2

model = EEGNet(nb_classes=nb_classes, Chans=Chans, Samples=Samples)

y_cat = to_categorical(y_labels_extracted, num_classes=nb_classes)

X_train, X_val, y_train, y_val = train_test_split(
    X, y_cat, test_size=0.2, random_state=42, stratify=y_labels_extracted
)

# Compile the model
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

# Train the model
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=50,
    batch_size=16,
    verbose=1
)


Epoch 1/50
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 307ms/step - accuracy: 0.6520 - loss: 0.6527 - val_accuracy: 0.5833 - val_loss: 0.7082
Epoch 2/50
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 370ms/step - accuracy: 0.7301 - loss: 0.5509 - val_accuracy: 0.5833 - val_loss: 0.6710
Epoch 3/50
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 254ms/step - accuracy: 0.7086 - loss: 0.6165 - val_accuracy: 0.7500 - val_loss: 0.6394
Epoch 4/50
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 265ms/step - accuracy: 0.8018 - loss: 0.4325 - val_accuracy: 0.7778 - val_loss: 0.6152
Epoch 5/50
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 263ms/step - accuracy: 0.8332 - loss: 0.4559 - val_accuracy: 0.8056 - val_loss: 0.5887
Epoch 6/50
[1m9/9[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 441ms/step - accuracy: 0.8577 - loss: 0.4533 - val_accuracy: 0.8333 - val_loss: 0.5706
Epoch 7/50
[1m9/9[0m [32m━━━━━━━━━━━━