# Assignment: Training EEGNet on P300 EEG Data

In this assignment, you will work with real EEG data from a P300 speller experiment and implement the EEGNet architecture to detect P300 responses. The emphasis of this assignment is on understanding and implementing the EEGNet model rather than extensive signal preprocessing.

**Instructions:**
- Complete the provided code scaffolding
- Fill in missing logic where indicated
- Focus especially on the EEGNet architecture and training


## Part 1: Loading and Inspecting the Dataset

In this section, you will load the EEG dataset and inspect its basic structure. The dataset contains continuous EEG recordings along with stimulus and label information.

In [None]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import scipy.io as sio
import numpy as np

# Load the dataset
# TODO: Update the path if needed
from google.colab import drive
drive.mount('/content/drive')
data = sio.loadmat('/content/drive/MyDrive/Subject_A_Train.mat')

# Inspect available keys
print(data.keys())


Mounted at /content/drive
dict_keys(['__header__', '__version__', '__globals__', 'Signal', 'TargetChar', 'Flashing', 'StimulusCode', 'StimulusType'])
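Every key except the `__header__`, `__version__` and `__globals__` metadata holds a data array. Printing the shape and dtype of each array is a quick way to see how those arrays are laid out. The helper below, `describe_arrays`, is a hypothetical name. It should work on any dict returned by `scipy.io.loadmat`, and here it runs on a small synthetic dict so it does not need the `.mat` file:

```python
import numpy as np

def describe_arrays(mat_dict):
    """Return {key: (shape, dtype)} for every non-metadata entry of a loadmat-style dict."""
    summary = {}
    for key, value in mat_dict.items():
        if key.startswith('__'):  # skip '__header__', '__version__', '__globals__'
            continue
        arr = np.asarray(value)
        summary[key] = (arr.shape, arr.dtype)
    return summary

# Synthetic stand-in for the loaded file: 2 characters, 10 samples, 3 channels
fake = {
    '__header__': b'MATLAB 5.0 MAT-file',
    'Signal': np.zeros((2, 10, 3)),
    'Flashing': np.zeros((2, 10)),
    'TargetChar': np.array(['AB']),
}
for key, (shape, dtype) in describe_arrays(fake).items():
    print(f'{key:12s} shape={shape} dtype={dtype}')
```

On the real file, `describe_arrays(data)` gives the same overview for `Signal`, `Flashing`, `StimulusCode`, `StimulusType` and `TargetChar`.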


## Part 2: Understanding the Experimental Design

The P300 speller paradigm is based on detecting brain responses to rare target stimuli. In this section, you will identify how stimulus timing and labels are encoded in the data.

In [None]:
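# TODO: Identify which variables correspond to
# 1. Continuous EEG signal
# 2. Stimulus onset information
# 3. Target vs non-target labels
# Hint: Look for variables related to stimulus codes and stimulus types
TargetChar = np.array(list(data['TargetChar'][0]))  # the character spelled in each row of Signal
signal = data['Signal']               # (characters, samples, channels); each row is the whole recording for one spelled character
flashing = data['Flashing']           # (characters, samples); 1 while a row/column is intensified
stimulus_code = data['StimulusCode']  # (characters, samples); index of the intensified row/column
stimulus_type = data['StimulusType']  # (characters, samples); 1 if the intensified row/column contains the target

n_chars = signal.shape[0]
labels_list = []
onsets_per_char = []

for char_idx in range(n_chars):
    char_type = stimulus_type[char_idx]
    char_flashing = flashing[char_idx]
    # An onset is a 0 -> 1 transition of Flashing; prepend=0 also catches a flash that starts at sample 0
    stimulus_onsets = np.where(np.diff(char_flashing, prepend=0) == 1)[0]
    onsets_per_char.append(stimulus_onsets)
    for onset in stimulus_onsets:
        labels_list.append(int(char_type[onset]))

labels = np.array(labels_list)  # built once, after every character has been processed
# Part 6 reuses the last character's onsets for every character, which is only valid
# if all characters share the same flash schedule
print('Same flash schedule for every character:',
      all(np.array_equal(o, onsets_per_char[0]) for o in onsets_per_char))
print(signal.shape)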

(85, 7794, 64)
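Stimulus onsets are the samples where `Flashing` switches from 0 to 1. One subtlety is that `np.diff` on its own cannot report an onset at sample 0, because there is no earlier sample to compare against. Passing `prepend=0` treats the recording as if it were preceded by a non-flashing sample. The toy vectors below are invented for illustration:

```python
import numpy as np

# Toy Flashing trace: one flash starts at sample 0, another at sample 4
flashing_toy = np.array([1, 1, 0, 0, 1, 1, 0, 0])
type_toy = np.array([1, 1, 0, 0, 0, 0, 0, 0])  # the first flash is a target

# Without prepend, the flash that starts at sample 0 is missed
onsets_no_prepend = np.where(np.diff(flashing_toy) == 1)[0] + 1
# With prepend=0, diff[i] == flashing[i] - flashing[i-1] for every i, including i = 0
onsets = np.where(np.diff(flashing_toy, prepend=0) == 1)[0]
labels_toy = type_toy[onsets]

print(onsets_no_prepend)  # [4]
print(onsets)             # [0 4]
print(labels_toy)         # [1 0]
```

The epoch count printed in Part 6, 15215, equals 85 × 179. That is one fewer per character than the 180 flashes (12 rows and columns × 15 repetitions) in each character's recording, which is consistent with the first flash being lost in exactly this way.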


## Part 3: EEG Epoch Extraction

EEGNet does not operate on continuous EEG. Instead, the signal must be segmented into short epochs following each stimulus. This step converts raw EEG into trials suitable for supervised learning.
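To convert the window from seconds to samples, multiply by the sampling rate. An epoch from `t_start` to `t_end` spans `int((t_end - t_start) * fs)` samples and begins `t_start * fs` samples after the onset. The sketch below slices a synthetic recording of shape `(time, channels)` at a few hypothetical onsets. It assumes 240 Hz, the rate documented for BCI Competition III Dataset II, the competition set that `Subject_A_Train.mat` appears to come from:

```python
import numpy as np

fs = 240                      # assumed sampling rate (Hz)
t_start, t_end = 0.0, 0.8     # window relative to stimulus onset (s)
epoch_samples = int((t_end - t_start) * fs)  # 0.8 s * 240 Hz = 192 samples
start_offset = int(round(t_start * fs))      # window start, in samples after the onset

rng = np.random.default_rng(0)
recording = rng.standard_normal((1000, 4))   # synthetic (time, channels) signal
onsets = np.array([0, 100, 700, 900])        # hypothetical onset samples

epochs = []
for onset in onsets:
    start = onset + start_offset
    if start + epoch_samples <= recording.shape[0]:  # drop windows that run past the end
        epochs.append(recording[start:start + epoch_samples])
epochs = np.transpose(np.array(epochs), (0, 2, 1))    # (epochs, channels, time)
print(epoch_samples, epochs.shape)  # 192 (3, 4, 192)
```

The window at onset 900 is dropped because it would end at sample 1092. Any label that belongs to a dropped window must be dropped with it, or epochs and labels fall out of step.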

In [None]:
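def extract_epochs(signal, stimulus_onsets, labels, fs, t_start=0.0, t_end=0.8):
    """
    Extract EEG epochs around each stimulus onset.

    Parameters:
    - signal: EEG array of shape (characters, time, channels)
    - stimulus_onsets: sample indices where stimuli occur; either one 1-D array
      shared by all characters, or a list with one array per character
    - labels: target/non-target labels per stimulus, ordered character by character
    - fs: sampling frequency in Hz
    - t_start: start time (seconds) relative to stimulus
    - t_end: end time (seconds) relative to stimulus

    Returns:
    - epochs: array of shape (num_epochs, channels, time)
    - y: labels of the epochs that were kept
    """
    # TODO: Implement epoch extraction logic
    # Hint: Convert time window to samples using fs
    n_chars = signal.shape[0]
    start_offset = int(round(t_start * fs))      # window start in samples, relative to the onset
    epoch_samples = int((t_end - t_start) * fs)  # window length in samples
    # A single 1-D onset array is reused for every character
    if np.ndim(stimulus_onsets[0]) == 0:
        stimulus_onsets = [stimulus_onsets] * n_chars
    labels = np.asarray(labels)
    epochs_list, y_list = [], []
    label_idx = 0
    for char_idx in range(n_chars):
        char_signal = signal[char_idx]  # (time, channels)
        for onset in stimulus_onsets[char_idx]:
            start = onset + start_offset
            # Keep the epoch and its label only if the whole window fits in the recording
            if start >= 0 and start + epoch_samples <= len(char_signal):
                epochs_list.append(char_signal[start:start + epoch_samples])
                y_list.append(labels[label_idx])
            label_idx += 1
    epochs = np.array(epochs_list)                 # (epochs, time, channels)
    epochs = np.transpose(epochs, axes=(0, 2, 1))  # (epochs, channels, time)
    return epochs, np.array(y_list)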


## Part 4: Preparing Data for EEGNet

In this section, you will perform minimal preprocessing to make the data compatible with EEGNet. Extensive signal processing is not required.

In [None]:
def prepare_for_eegnet(epochs):
    """
    Prepare EEG epochs for input into EEGNet.

    Expected input shape: (trials, channels, time)
    Expected output shape: (trials, 1, channels, time)
    """
    # TODO: Add singleton dimension required by Conv2D
    # Hint: Use numpy.expand_dims
    epochs=np.expand_dims(epochs,axis=1)
    print(epochs.shape)
    return epochs
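Keras `Conv2D` layers default to `data_format='channels_last'`, so the model in Part 5 takes inputs of shape `(Chans, Timepts, 1)`. The singleton axis of the `(trials, 1, channels, time)` array returned above therefore has to move to the end before training. The check below uses a random array. It shows that transposing with `axes=(0, 2, 3, 1)` gives the same result as adding the singleton axis last in the first place:

```python
import numpy as np

rng = np.random.default_rng(1)
epochs = rng.standard_normal((5, 64, 192))       # synthetic (trials, channels, time)

channels_first = np.expand_dims(epochs, axis=1)  # (trials, 1, channels, time)
channels_last = np.transpose(channels_first, axes=(0, 2, 3, 1))  # (trials, channels, time, 1)
direct = np.expand_dims(epochs, axis=-1)         # singleton axis added at the end directly

print(channels_first.shape, channels_last.shape)  # (5, 1, 64, 192) (5, 64, 192, 1)
print(np.array_equal(channels_last, direct))      # True
```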


## Part 5: Implementing EEGNet

This is the core part of the assignment. You will implement the EEGNet architecture as discussed in class. Focus on matching the block structure and understanding the role of each layer.
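Before writing the layers, it helps to trace how the tensor shape changes through the blocks. The pure-Python helper below, `eegnet_shapes`, is hypothetical. It mirrors the `EEGNet` function in this part under these assumptions:

- The temporal and separable convolutions use `'same'` padding.
- The `(Chans, 1)` depthwise convolution uses `'valid'` padding, which collapses the channel axis.
- Keras' default `'valid'` average pooling floors the time axis.
- `kernLength = Timepts // 2`.

It also counts parameters, assuming that BatchNormalization keeps 4 values per feature map (gamma, beta, moving mean, moving variance) and that only the Dense layer has a bias. If these assumptions hold, the numbers should line up with `model.summary()`.

```python
def eegnet_shapes(chans, timepts, nb_classes=2, F1=8, D=2, F2=16):
    """Trace output shapes (height, time, maps) and parameter counts through EEGNet."""
    kern_length = timepts // 2
    shapes, params = {}, {}

    shapes['temporal_conv'] = (chans, timepts, F1)        # 'same' padding keeps height and time
    params['temporal_conv'] = F1 * kern_length            # kernel (1, kernLength), no bias
    params['bn1'] = 4 * F1

    shapes['depthwise_conv'] = (1, timepts, F1 * D)       # (Chans, 1) kernel, 'valid' padding
    params['depthwise_conv'] = chans * F1 * D
    params['bn2'] = 4 * F1 * D

    t = timepts // 4                                      # AveragePooling2D((1, 4))
    shapes['pool1'] = (1, t, F1 * D)

    shapes['separable_conv'] = (1, t, F2)                 # 'same' padding
    params['separable_conv'] = 16 * F1 * D + F1 * D * F2  # depthwise (1, 16) + pointwise 1x1
    params['bn3'] = 4 * F2

    t = t // 8                                            # AveragePooling2D((1, 8))
    shapes['pool2'] = (1, t, F2)

    n_features = t * F2
    shapes['flatten'] = (n_features,)
    params['dense'] = n_features * nb_classes + nb_classes  # weights + bias
    return shapes, params

shapes, params = eegnet_shapes(chans=64, timepts=204)
for name, shape in shapes.items():
    print(f'{name:15s} {shape}')
print('total params:', sum(params.values()))  # total params: 2706
```

The two pooling layers shrink 204 time points to 204 // 4 = 51 and then 51 // 8 = 6. The classifier therefore sees only 6 × 16 = 96 features.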

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, DepthwiseConv2D,
                                     SeparableConv2D, BatchNormalization,Activation,
                                     AveragePooling2D, Dropout, Flatten, Dense)
from tensorflow.keras.constraints import max_norm

def EEGNet(nb_classes, Chans, Timepts, F1=8, D=2, F2=16, dropoutRate=0.5):
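    """
    EEGNet architecture.

    Parameters:
    - nb_classes: number of output classes
    - Chans: number of EEG channels
    - Timepts: number of time samples per epoch
    - F1: number of temporal filters
    - D: depth multiplier for spatial filters
    - F2: number of pointwise filters
    - dropoutRate: dropout fraction applied after each block
    """

    # Keras defaults to channels_last, so each input is (channels, time, 1)
    input1 = Input(shape=(Chans, Timepts, 1))
    # Temporal kernel spans half the epoch; the original EEGNet paper instead uses
    # half the sampling rate (64 samples at 128 Hz)
    kernLength = Timepts // 2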
    # Block 1: Temporal Convolution
    block1 = Conv2D(F1, (1, kernLength), padding = 'same',use_bias = False)(input1)
    block1 = BatchNormalization()(block1)
    # Block 1: Spatial Convolution
    block1 = DepthwiseConv2D((Chans, 1), use_bias = False,depth_multiplier = D,depthwise_constraint = max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = Dropout(dropoutRate)(block1)

    # Block 2: Separable Convolution
    block2 = SeparableConv2D(F2, (1, 16),use_bias = False, padding = 'same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = Dropout(dropoutRate)(block2)

    # Classification
    flatten = Flatten(name = 'flatten')(block2)
    dense = Dense(nb_classes, name = 'dense',kernel_constraint = max_norm(.25))(flatten)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)
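# TODO: Instantiate the EEGNet model and print the summary
# Timepts must equal the number of samples per epoch produced by extract_epochs
model = EEGNet(nb_classes=2, Chans=64, Timepts=204, F1=8, D=2, F2=16, dropoutRate=0.5)
model.summary()  # summary() prints the table itself and returns None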


## Part 6: Training the Model

In this section, you will train EEGNet to distinguish between P300 and non-P300 EEG epochs.
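Target epochs are rare. In each sequence of 12 row/column flashes only 2 contain the target character, so roughly 1 epoch in 6 is a P300 epoch. A classifier that always answers "non-target" is therefore already about 83% accurate, and the accuracies logged during training are best read against that baseline.

One common countermeasure is to weight the loss by inverse class frequency. The sketch below uses a hypothetical helper, `balanced_class_weights`, to compute `n_samples / (n_classes * n_c)` from integer labels, which is the same formula scikit-learn's `compute_class_weight(class_weight='balanced', ...)` uses. The resulting dict can be passed to Keras through the `class_weight` argument of `model.fit`:

```python
import numpy as np

def balanced_class_weights(y):
    """Map each class index to n_samples / (n_classes * count_of_class)."""
    y = np.asarray(y).astype(int)
    counts = np.bincount(y)
    n_classes = len(counts)
    return {c: len(y) / (n_classes * counts[c]) for c in range(n_classes)}

# Toy labels with the P300 speller ratio: 10 non-targets, 2 targets
y_toy = np.array([0] * 10 + [1] * 2)
weights = balanced_class_weights(y_toy)
majority_baseline = np.mean(y_toy == 0)

print(weights)                      # {0: 0.6, 1: 3.0}
print(round(majority_baseline, 3))  # 0.833
```

With one-hot training labels, you can recover the integer labels first, for example `class_weight=balanced_class_weights(np.argmax(y_train, axis=1))`.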

In [None]:
# TODO: Split the dataset into training and validation sets
fs = 240  # BCI Competition III Dataset II (Subject_A_Train.mat) is sampled at 240 Hz
epochs, y = extract_epochs(signal, stimulus_onsets, labels, fs=fs, t_start=0.0, t_end=0.8)
epochs = prepare_for_eegnet(epochs)               # (epochs, 1, channels, time)
epochs = np.transpose(epochs, axes=(0, 2, 3, 1))  # Keras channels_last: (epochs, channels, time, 1)
x_train, x_val, y_train, y_val = train_test_split(epochs, y, test_size=0.2, random_state=42, stratify=y)
y_train = to_categorical(y_train, 2)
y_val = to_categorical(y_val, 2)
# Rebuild EEGNet so that Chans and Timepts match the extracted epochs
model = EEGNet(nb_classes=2, Chans=x_train.shape[1], Timepts=x_train.shape[2])
# TODO: Compile the model with an appropriate loss and optimizer
# Hint: Use categorical cross-entropy and Adam optimizer
model.compile(optimizer='adam',                # Adam with the Keras default learning rate of 0.001
              loss='categorical_crossentropy',  # matches the one-hot labels from to_categorical
              metrics=['accuracy'])
# TODO: Train the model and store the training history
history = model.fit(x_train, y_train, batch_size=8, epochs=10, validation_data=(x_val, y_val), verbose=1)



(15215, 1, 64, 204)
Epoch 1/10
[1m1522/1522[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m376s[0m 245ms/step - accuracy: 0.8346 - loss: 0.4028 - val_accuracy: 0.8400 - val_loss: 0.3999
Epoch 2/10
[1m1522/1522[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m359s[0m 236ms/step - accuracy: 0.8364 - loss: 0.4028 - val_accuracy: 0.8357 - val_loss: 0.3952
Epoch 3/10
[1m1522/1522[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m390s[0m 241ms/step - accuracy: 0.8369 - loss: 0.4037 - val_accuracy: 0.8390 - val_loss: 0.3917
Epoch 4/10
[1m1522/1522[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m371s[0m 234ms/step - accuracy: 0.8414 - loss: 0.3860 - val_accuracy: 0.8363 - val_loss: 0.3917
Epoch 5/10
[1m1522/1522[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m392s[0m 241ms/step - accuracy: 0.8462 - loss: 0.3751 - val_accuracy: 0.8409 - val_loss: 0.3916
Epoch 6/10
[1m1522/1522[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m370s[0m 243ms/step - accuracy: 0.8433 - loss: 0.3762 - val_accuracy