In [29]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from scipy.io import loadmat
from sklearn.metrics import confusion_matrix
from torchinfo import summary

# Kaggle input locations: the main dataset directory plus the labeled / unlabeled test folders.
DATASET_DIR = "/kaggle/input/bci-term-project/BCICIV_2a_gdf"
DATASET_DIR_TEST = "/kaggle/input/bci-homework-3-kaggle-judge/BCI_hw3_dataset/labeled_test"
DATASET_DIR_EXAM = "/kaggle/input/bci-homework-3-kaggle-judge/BCI_hw3_dataset/unlabeled_test"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## EEG Model

### EEGNet

In [30]:
class EEGNet(nn.Module):
    """EEGNet model from Lawhern et al 2018.
    ... Parameters ............
    C: int
        Number of EEG input channels.
    N: int
        Number of EEG input time samples.
    nb_classes: int
        Number of classes to predict.
    kernLength: int
        Length of temporal convolution in first layer.
    F1, F2: int
        Number of temporal filters (F1) and number of pointwise filters (F2) to learn.
    D: int
        Number of spatial filters to learn within each temporal convolution.
    dropoutRate: float
        Dropout ratio.
    ... Notes ............
    The network consumes tensors shaped (batch, 1, C, N) and returns raw class scores
    shaped (batch, nb_classes); no softmax is applied.
    BatchNorm uses eps=1e-3 and momentum=0.01. PyTorch updates running statistics as
    `running = (1 - momentum) * running + momentum * batch`, which is the complement of the
    Keras convention, so the Keras value 0.99 corresponds to 0.01 here.
    ... References ............
    https://arxiv.org/abs/1611.08024
    """

    def __init__(self, C, N, nb_classes, kernLength=64, F1=8, F2=16, D=2, dropoutRate=0.5):
        super(EEGNet, self).__init__()

        # Shared BatchNorm settings (see the Notes section of the class docstring).
        bn_kwargs = dict(eps=1e-3, momentum=0.01)

        # Block 1a: temporal convolution along the time axis.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, F1, (1, kernLength), padding="valid", bias=False),
            nn.BatchNorm2d(F1, **bn_kwargs)
        )
        # Block 1b: depthwise spatial convolution spanning all C channels (D filters per temporal filter).
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                F1, D * F1, (C, 1), groups=F1, bias=False
            ),
            nn.BatchNorm2d(D * F1, **bn_kwargs),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(dropoutRate)
        )

        # Block 2: separable convolution = depthwise temporal conv followed by a 1x1 pointwise conv.
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                D * F1, D * F1, (1, 16),
                padding=(0, 8), groups=D * F1, bias=False
            ),
            nn.Conv2d(D * F1, F2, (1, 1), bias=False),
            nn.BatchNorm2d(F2, **bn_kwargs),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(dropoutRate)
        )

        # The classifier width depends on C, N and the kernel sizes, so it is measured with a probe.
        fc_inSize = self.get_size(C, N)[1]
        self.classifier = nn.Linear(fc_inSize, nb_classes, bias=True)

    def forward(self, x):
        """Map a (batch, 1, C, N) tensor to (batch, nb_classes) class scores."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def get_size(self, C, N):
        """Return the `torch.Size` (1, n_features) of the flattened conv output for a (1, 1, C, N) input.

        The probe runs in eval mode and without autograd so that it neither updates the
        BatchNorm running statistics nor samples dropout; the previous train/eval mode is
        restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Build the probe on the same device as the weights so this also works after .to(device).
                data = torch.ones((1, 1, C, N), device=self.conv1[0].weight.device)
                x = self.conv1(data)
                x = self.conv2(x)
                x = self.conv3(x)
        finally:
            self.train(was_training)
        x = x.view(x.size()[0], -1)
        return x.size()


### ShallowConvNet

In [31]:
class ShallowConvNet(nn.Module):
    """Shallow ConvNet model from Schirrmeister et al 2017.
    ... Parameters ............
    C: int
        Number of EEG input channels.
    N: int
        Number of EEG input time samples.
    nb_classes: int
        Number of classes to predict.
    NT: int
        Number of temporal filters.
    NS: int
        Number of spatial filters.
    tkerLen: int
        Length of the temporal filter.
    pool_tLen: int
        Length of temporal pooling filter.
    pool_tStep: int
        Length of stride of temporal pooling filters.
    batch_norm: bool
        Whether to use batch normalization.
    dropRate: float
        Dropout ratio.
    log_eps: float
        Lower bound applied to the pooled power before the logarithm, so that an
        all-zero pooling window yields log(log_eps) instead of -inf.
    ... Notes ............
    The network consumes tensors shaped (batch, 1, C, N) and returns raw class scores
    shaped (batch, nb_classes). The BatchNorm layer is always created (so the parameter
    layout does not depend on `batch_norm`) but is only applied when `batch_norm` is True.
    ... References ............
    https://arxiv.org/abs/1703.05051
    """

    def __init__(self, C, N, nb_classes, NT=40, NS=40, tkerLen=12, pool_tLen=35, pool_tStep=7, batch_norm=True,
                 dropRate=0.25, log_eps=1e-6):
        super(ShallowConvNet, self).__init__()

        self.batch_norm = batch_norm
        self.log_eps = log_eps

        self.conv1 = nn.Conv2d(1, NT, (1, tkerLen), bias=False)  # temporal convolution
        self.conv2 = nn.Conv2d(NT, NS, (C, 1), bias=False)  # spatial convolution over all C channels
        self.Bn1 = nn.BatchNorm2d(NS)
        self.AvgPool1 = nn.AvgPool2d((1, pool_tLen), stride=(1, pool_tStep))
        self.Drop1 = nn.Dropout(dropRate)
        # The classifier width depends on C, N and the kernel sizes, so it is measured with a probe.
        fc_inSize = self.get_size(C, N)[1]
        self.classifier = nn.Linear(fc_inSize, nb_classes, bias=True)

    def forward(self, x):
        """Map a (batch, 1, C, N) tensor to (batch, nb_classes) class scores."""
        x = self.conv1(x)
        x = self.conv2(x)
        if self.batch_norm:
            x = self.Bn1(x)
        # square -> average pool -> log: log of the mean power inside each pooling window
        x = x ** 2
        x = self.AvgPool1(x)
        # Clamp first: the pooled power can be exactly 0 (e.g. flat input), and log(0) = -inf
        # would turn the logits and the loss into inf/NaN.
        x = torch.log(torch.clamp(x, min=self.log_eps))
        x = self.Drop1(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def get_size(self, C, N):
        """Return the `torch.Size` (1, n_features) of the flattened feature map for a (1, 1, C, N) input.

        Only the shape-changing layers (both convolutions and the pooling) are applied.
        """
        with torch.no_grad():  # shape probe only, no autograd graph needed
            data = torch.ones((1, 1, C, N), device=self.conv1.weight.device)
            x = self.conv1(data)
            x = self.conv2(x)
            x = self.AvgPool1(x)
        x = x.view(x.size()[0], -1)
        return x.size()


### SCCNet

In [32]:
# (Bonus) Optional TODO: Advanced SCCNet model without permutation layer
class SCCNet_v2(nn.Module):
    """Advanced SCCNet variant that does not use a permutation layer.
    ... Parameters ............
    C: int
        Number of EEG input channels.
    N: int
        Number of EEG input time samples.
    nb_classes: int
        Number of classes to predict.
    Nu: int
        Height (number of input rows spanned) of the second convolution kernel.
        Defaults to C when None.
    Nt: int
        Length along the time axis of the first convolution kernel.
    Nc: int
        Number of kernels (output channels) of the second convolution.
    fs: float
        Sampling frequency of EEG input. The second convolution kernel and the pooling
        stride are int(0.1 * fs) samples long, the pooling window is int(0.5 * fs) samples.
    dropoutRate: float
        Dropout ratio.
    ... Notes ............
    Input tensors are shaped (batch, 1, C, N); the output holds raw class scores shaped
    (batch, nb_classes).
    ... References ............
    https://ieeexplore.ieee.org/document/8716937
    """

    def __init__(self, C, N, nb_classes, Nu=None, Nt=1, Nc=20, fs=1000.0, dropoutRate=0.5):
        super().__init__()
        if Nu is None:
            Nu = C
        step_len = int(fs * 0.1)  # shared by the conv2 kernel width and the pooling stride
        window_len = int(fs * 0.5)  # pooling window width

        # Block 1: single-kernel convolution along time
        self.conv1 = nn.Conv2d(1, 1, (1, Nt), bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=1)
        self.relu1 = nn.ReLU(inplace=True)

        # Block 2: Nc kernels covering Nu rows and step_len samples
        self.conv2 = nn.Conv2d(1, Nc, (Nu, step_len), bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=Nc)
        self.relu2 = nn.ReLU(inplace=True)

        self.dropout = nn.Dropout(p=dropoutRate)
        self.avgpool = nn.AvgPool2d(kernel_size=(1, window_len), stride=(1, step_len))

        # The classifier width is measured by pushing a dummy input through the shape-changing layers.
        _, n_features = self.get_size(C, N)
        self.classifier = nn.Linear(n_features, nb_classes, bias=True)

    def forward(self, x):
        """Map a (batch, 1, C, N) tensor to (batch, nb_classes) class scores."""
        # First block: conv -> batch norm -> ReLU -> dropout
        out = self.dropout(self.relu1(self.bn1(self.conv1(x))))

        # Second block: conv -> batch norm -> ReLU -> square
        out = self.relu2(self.bn2(self.conv2(out)))
        out = out ** 2

        # Temporal average pooling followed by dropout
        out = self.dropout(self.avgpool(out))

        # Flatten everything but the batch axis and classify
        return self.classifier(out.view(out.size(0), -1))

    def get_size(self, C, N):
        """Return the `torch.Size` (1, n_features) of the flattened feature map for a (1, 1, C, N) input.

        Only conv1, conv2 and the pooling layer are applied; normalisation, activation
        and dropout do not change the shape and are skipped.
        """
        probe = torch.ones((1, 1, C, N))
        probe = self.avgpool(self.conv2(self.conv1(probe)))
        return probe.view(probe.size(0), -1).size()


In [33]:
# TODO: finish the SCCNet
class SCCNet(nn.Module):
    """SCCNet model from Wei et al 2019.
    Note: the paper uses Nc both for the number of EEG input channels and for the number of
    spatial-temporal kernels. The parameter descriptions below state which notation of the
    paper each argument corresponds to.
    ... Parameters ............
    C: int
        Number of EEG input channels. (The Nc of the first and second paragraphs of paper section II.B)
    N: int
        Number of EEG input time samples.
    nb_classes: int
        Number of classes to predict.
    Nu: int
        Number of spatial kernels. Defaults to C when None.
    Nt: int
        Length of each spatial kernel along the time axis.
    Nc: int
        Number of spatial-temporal kernels. (The Nc of the third paragraph of paper section II.B)
    fs: float
        Sampling frequency of EEG input. The spatial-temporal kernel and the pooling stride
        are int(0.1 * fs) samples long, the pooling window is int(0.5 * fs) samples.
    dropoutRate: float
        Dropout ratio.
    ... Notes ............
    Input tensors are shaped (batch, 1, C, N); the output holds raw class scores shaped
    (batch, nb_classes).
    ... References ............
    https://ieeexplore.ieee.org/document/8716937
    """

    # Assignment rules: arguments may be added to this signature but the existing ones must stay,
    # and the layer sizes must follow the arguments (a static SCCNet structure is penalised).
    def __init__(self, C, N, nb_classes, Nu=None, Nt=1, Nc=20, fs=1000.0, dropoutRate=0.5):
        super().__init__()
        if Nu is None:
            Nu = C
        step_len = int(fs * 0.1)  # shared by the conv2 kernel width and the pooling stride
        window_len = int(fs * 0.5)  # pooling window width

        # Block 1: Nu spatial kernels, each covering all C channels and Nt samples
        self.conv1 = nn.Conv2d(1, Nu, (C, Nt), bias=False)
        # One feature only: batch norm runs after the permutation, when the tensor has a single channel.
        self.bn1 = nn.BatchNorm2d(num_features=1)
        self.relu1 = nn.ReLU(inplace=True)

        # Block 2: Nc spatial-temporal kernels covering all Nu rows and step_len samples
        self.conv2 = nn.Conv2d(1, Nc, (Nu, step_len), bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=Nc)
        self.relu2 = nn.ReLU(inplace=True)

        self.dropout = nn.Dropout(p=dropoutRate)
        self.avgpool = nn.AvgPool2d(kernel_size=(1, window_len), stride=(1, step_len))

        # The classifier width is measured by pushing a dummy input through the shape-changing layers.
        _, n_features = self.get_size(C, N)
        self.classifier = nn.Linear(n_features, nb_classes, bias=True)

    def forward(self, x):
        """Map a (batch, 1, C, N) tensor to (batch, nb_classes) class scores."""
        # First block: the Nu kernel outputs are swapped from the channel axis onto the
        # height axis, so the next convolution sees one channel with Nu rows.
        out = self.conv1(x).permute(0, 2, 1, 3)
        out = self.dropout(self.relu1(self.bn1(out)))

        # Second block: conv -> batch norm -> ReLU -> square
        out = self.relu2(self.bn2(self.conv2(out)))
        out = out ** 2

        # Temporal average pooling followed by dropout
        out = self.dropout(self.avgpool(out))

        # Flatten everything but the batch axis and classify
        return self.classifier(out.view(out.size(0), -1))

    def get_size(self, C, N):
        """Return the `torch.Size` (1, n_features) of the flattened feature map for a (1, 1, C, N) input.

        Only conv1, the permutation, conv2 and the pooling layer are applied; normalisation,
        activation and dropout do not change the shape and are skipped.
        """
        probe = torch.ones((1, 1, C, N))
        probe = self.conv1(probe).permute(0, 2, 1, 3)
        probe = self.avgpool(self.conv2(probe))
        return probe.view(probe.size(0), -1).size()


class Permute2d(nn.Module):
    """Module wrapper that reorders the axes of its input.

    `shape` is the target axis order; it is stored unchanged on the instance and applied
    to every tensor passed to `forward`.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return `x` with its axes rearranged into the order given by `self.shape`."""
        return x.permute(*self.shape)

## Train

In [34]:
# config training scheme, mode, hyperparam
eegmodel = SCCNet  # model class, should be EEGNet, ShallowConvNet, SCCNet, SCCNet_v2
eegmodel_name = eegmodel.__name__  # derived from the class so the name can never drift from the model
kwargs = dict(fs=125.0, dropoutRate=0.5)  # custom args for different EEG model
scheme = "si"  # "ind", "si", "sd", "sift"  (NOTE: not read by any code visible in this notebook)
epochs = 200
batch_size = 16
lr = 1e-3
savepath = "/kaggle/working/checkpoints"
os.makedirs(savepath, exist_ok=True)

subject_id = 6  # held-out subject: excluded from training/validation and used for testing

# Seed every RNG used below (weight init, dropout, DataLoader shuffling) so that
# Restart & Run All produces repeatable runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

### load data

In [None]:
N_CLASSES = 4  # class labels stored in the .mat files run from 1 to N_CLASSES
N_TRAIN_PER_CLASS = 54  # per subject / session / class: first 54 trials -> train, the rest -> validation
                        # (the original note describes this as a 75% / 25% split)


def load_session_split(sub_id, session, n_train=N_TRAIN_PER_CLASS):
    """Load one recording session of one subject and split it class by class.

    Parameters
    ----------
    sub_id: int
        Subject number used in the file name `A0{sub_id}{session}*.mat`.
    session: str
        Session letter in the file name ("T" or "E").
    n_train: int
        Number of leading trials of every class that go to the training part.

    Returns
    -------
    list of (x_train, y_train, x_valid, y_valid) tuples, one per class in label order 1..N_CLASSES.
    Labels are returned as stored in the file (1-based).
    """
    mat_data = loadmat(os.path.join(DATASET_DIR, f"A0{sub_id}{session}_output.mat"))
    mat_label = loadmat(os.path.join(DATASET_DIR, f"A0{sub_id}{session}.mat"))
    x_data, y_data = mat_data["data"], mat_label["classlabel"].squeeze()

    parts = []
    for c in range(1, N_CLASSES + 1):
        mask = y_data == c
        x_c, y_c = x_data[mask], y_data[mask]
        parts.append((x_c[:n_train], y_c[:n_train], x_c[n_train:], y_c[n_train:]))
    return parts


# Gather the pieces in lists and concatenate once (np.append inside a loop re-copies the
# whole array on every call). Order: all "T" sessions first, then all "E" sessions.
train_x, train_y, valid_x, valid_y = [], [], [], []
for session in ("T", "E"):
    for sub_id in range(1, 10):
        if sub_id == subject_id:
            continue  # the held-out subject never enters training or validation
        for x_tr, y_tr, x_va, y_va in load_session_split(sub_id, session):
            train_x.append(x_tr)
            train_y.append(y_tr)
            valid_x.append(x_va)
            valid_y.append(y_va)

x_train = np.concatenate(train_x, axis=0)
y_train = np.concatenate(train_y, axis=0)
x_valid = np.concatenate(valid_x, axis=0)
y_valid = np.concatenate(valid_y, axis=0)

# numpy array to tensor: add the singleton "image channel" axis expected by the models,
# shift labels from 1..N_CLASSES to 0..N_CLASSES-1 and one-hot encode them
x_train = torch.from_numpy(x_train).unsqueeze(1)
y_train = F.one_hot(torch.from_numpy(y_train).long() - 1, N_CLASSES)

x_valid = torch.from_numpy(x_valid).unsqueeze(1)
y_valid = F.one_hot(torch.from_numpy(y_valid).long() - 1, N_CLASSES)

print("train: {}, {}".format(x_train.size(), y_train.size()))
print("valid: {}, {}".format(x_valid.size(), y_valid.size()))

# build training and validation dataloader
trainset = torch.utils.data.TensorDataset(x_train, y_train)
validset = torch.utils.data.TensorDataset(x_valid, y_valid)
tra_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
# Validation stays shuffled on purpose: the arrays above are grouped by class, and
# evaluate_an_epoch normalises with per-batch statistics, so single-class batches would skew it.
val_loader = torch.utils.data.DataLoader(validset, batch_size=batch_size, shuffle=True, num_workers=2)

### training stage

In [None]:
# train an epoch, evaluate an epoch
# if you are familiar with Pytorch, you CAN custom these function
#  such as adding the lr_scheduler to optimize the training progress

def train_an_epoch(model, data_loader, loss_fn, optimizer, l2_lambda=0.0001):
    """Train `model` for one pass over `data_loader`.

    Parameters
    ----------
    model: nn.Module
        Network to optimise; it is switched to train mode.
    data_loader: iterable of (x_batch, y_batch)
        `y_batch` holds one score per class (argmax over dim 1 gives the class index).
    loss_fn: callable
        Called as `loss_fn(output, y_batch)` with float targets.
    optimizer: torch.optim.Optimizer
        Optimizer bound to the parameters of `model`.
    l2_lambda: float
        Weight of the explicit L2 penalty (sum of squared parameters) added to the loss.

    Returns
    -------
    (mean_loss, accuracy): mean of the per-batch losses (L2 penalty included, so it is not
    directly comparable to the penalty-free evaluation loss) and the fraction of correctly
    classified samples.
    """
    model.train()

    n_hit, n_total = 0, 0  # hit sample, total sample
    epoch_loss = np.zeros((len(data_loader),))
    for i, (x_batch, y_batch) in enumerate(data_loader):
        x_batch = x_batch.to(device, dtype=torch.float)
        y_batch = y_batch.to(device, dtype=torch.float)

        optimizer.zero_grad()
        output = model(x_batch)
        loss = loss_fn(output, y_batch)
        ## L2_regularization over every parameter of the model
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
        loss = loss + l2_lambda * l2_norm
        loss.backward()
        optimizer.step()

        epoch_loss[i] = loss.item()
        n_total += y_batch.size(0)
        n_hit += torch.sum(y_batch.argmax(dim=1) == output.argmax(dim=1)).item()
    return epoch_loss.mean(), n_hit / n_total  # return the loss and acc


def evaluate_an_epoch(model, data_loader, loss_fn):
    """Evaluate `model` on every batch of `data_loader` without updating its weights.

    Side effect (kept from the original design): every BatchNorm2d layer of `model` is
    permanently switched to batch statistics. Its running mean/variance buffers are removed
    and `track_running_stats` is disabled, so from the first call on the layers normalise
    each batch with that batch's own statistics, in eval mode as well. Results therefore
    depend on how samples are grouped into batches.

    Parameters
    ----------
    model: nn.Module
        Network to evaluate; it is switched to eval mode.
    data_loader: iterable of (x_batch, y_batch)
        `y_batch` holds one score per class (argmax over dim 1 gives the class index).
    loss_fn: callable
        Called as `loss_fn(output, y_batch)` with float targets.

    Returns
    -------
    (mean_loss, accuracy): mean of the per-batch losses and the fraction of correctly
    classified samples.
    """
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.track_running_stats = False
            module.running_mean = None
            module.running_var = None

    model.eval()
    n_hit, n_total = 0, 0  # hit sample, total sample
    epoch_loss = np.zeros((len(data_loader),))
    # No gradients are needed for evaluation; skipping the autograd graph saves memory and time.
    with torch.no_grad():
        for i, (x_batch, y_batch) in enumerate(data_loader):
            x_batch = x_batch.to(device, dtype=torch.float)
            y_batch = y_batch.to(device, dtype=torch.float)
            output = model(x_batch)
            loss = loss_fn(output, y_batch)

            epoch_loss[i] = loss.item()
            n_total += y_batch.size(0)
            n_hit += torch.sum(y_batch.argmax(dim=1) == output.argmax(dim=1)).item()
    return epoch_loss.mean(), n_hit / n_total  # return the loss and acc

In [None]:
# Input tensors are (trials, 1, channels, samples): size(2) is the channel count, size(3) the sample count.
if eegmodel_name in ("EEGNet", "ShallowConvNet"):
    model = eegmodel(x_train.size(2), x_train.size(3), 4)  # EEGNet, ShallowConvNet
elif eegmodel_name in ("SCCNet", "SCCNet_v2"):
    model = eegmodel(x_train.size(2), x_train.size(3), 4, **kwargs)  # SCCNet
else:
    # Fail here with a clear message instead of a NameError on `model` further down.
    raise ValueError(f"Unknown eegmodel_name: {eegmodel_name!r}")

# The training / evaluation functions move every batch to `device`, so the model has to live
# there too. Do it before building the optimizer so it tracks the moved parameters.
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()  # loss function, can be modified
opt_fn = torch.optim.Adam(model.parameters(), lr=lr)  # optimizer, CAN be modified

# dump the model structure
summary(model, input_size=(batch_size, *list(x_train.size()[1:])), device=device)

In [None]:
# if you are familiar with Pytorch, you CAN custom the following training loop

# Per-epoch history of training / validation loss and accuracy.
hist = dict(
    loss=np.zeros((epochs,)), val_loss=np.zeros((epochs,)),
    acc=np.zeros((epochs,)), val_acc=np.zeros((epochs,))
)
clock_ini = time.time()
for ep in range(epochs):
    loss, acc = train_an_epoch(model, tra_loader, loss_fn, opt_fn)
    val_loss, val_acc = evaluate_an_epoch(model, val_loader, loss_fn)
    print("Epoch {}: loss={:.4f}, acc={:.4f}, val_loss={:.4f}, val_acc={:.4f}".format(ep, loss, acc, val_loss, val_acc))
    hist["loss"][ep] = loss
    hist["acc"][ep] = acc
    hist["val_loss"][ep] = val_loss
    hist["val_acc"][ep] = val_acc

    # save the pre-trained weight in each epoch, CAN be modified
    # The checkpoint records the real epoch index, and the losses are stored as plain Python
    # floats so the file holds only tensors and builtin types (loadable by restricted unpicklers).
    checkpoint = dict(epoch=ep, state_dict=model.state_dict(), loss=float(loss), val_loss=float(val_loss))
    torch.save(checkpoint, os.path.join(savepath, f"{eegmodel_name}_MODEL-ep{ep}.pth"))
print("time spend: {:.2f} sec".format(time.time() - clock_ini))

In [None]:
# Acc curve, Loss curve (red = training, blue = validation)
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))

ax_acc.set_title("Acc Curve")
ax_acc.plot(hist["acc"], color="red", label="train")
ax_acc.plot(hist["val_acc"], color="blue", label="valid")
ax_acc.set_xlabel("epoch")
ax_acc.set_ylabel("accuracy")
ax_acc.legend()

ax_loss.set_title("Loss Curve")
ax_loss.plot(hist["loss"], color="red", label="train")
ax_loss.plot(hist["val_loss"], color="blue", label="valid")
ax_loss.set_xlabel("epoch")
ax_loss.set_ylabel("loss")
ax_loss.legend()

fig.tight_layout()
plt.show()

## Test

In [None]:
# Restore the checkpoint of the epoch with the lowest validation loss.
best_epoch = np.argmin(hist["val_loss"])  # TODO: determine the `BEST` epoch
print(best_epoch)
ckpt_name = "{}_MODEL-ep{}.pth".format(eegmodel_name, best_epoch)
test_model_path = os.path.join(savepath, ckpt_name)
checkpoint = torch.load(test_model_path, map_location="cpu")  # read the .pth file
model.load_state_dict(checkpoint["state_dict"])  # overwrite the current weights

# Test on the evaluation (E) session of the held-out subject `subject_id`.
mat = loadmat(os.path.join(DATASET_DIR, f"A0{subject_id}E_output.mat"))
mat_label = loadmat(os.path.join(DATASET_DIR, f"A0{subject_id}E.mat"))

# Trials get the singleton channel axis; labels are shifted from 1-based to 0-based.
x = torch.from_numpy(mat["data"]).unsqueeze(1)
y = torch.from_numpy(mat_label["classlabel"].squeeze()).long() - 1
print(np.unique(y))
y = F.one_hot(y, 4)

print("test: {}, {}".format(x.size(), y.size()))

# Unshuffled loader so the predictions keep the order of the trials in the file.
testset = torch.utils.data.TensorDataset(x, y)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
loss, acc = evaluate_an_epoch(model, test_loader, loss_fn)
print(loss, round(acc, 14))


## Analyzing
You need to do some further analysis, including:
- a confusion matrix,
- topographic maps of the spatial kernel weights in SCCNet.

Do these on your own in this section.

In [None]:
def evaluate_an_epoch_pred(model, data_loader, loss_fn):
    """Evaluate `model` on `data_loader` and also return the per-sample predictions.

    The model is switched to eval mode and no gradients are recorded. BatchNorm layers are
    used in whatever configuration the model currently has.

    Parameters
    ----------
    model: nn.Module
        Network to evaluate.
    data_loader: iterable of (x_batch, y_batch)
        `y_batch` holds one score per class (argmax over dim 1 gives the class index).
    loss_fn: callable
        Called as `loss_fn(output, y_batch)` with float targets.

    Returns
    -------
    (mean_loss, accuracy, pred, labels): mean of the per-batch losses, fraction of correctly
    classified samples, and two lists of int class indices (predicted / true) in loader order.
    """
    model.eval()
    n_hit, n_total = 0, 0  # hit sample, total sample
    epoch_loss = np.zeros((len(data_loader),))
    pred = []
    labels = []
    with torch.no_grad():
        for i, (x_batch, y_batch) in enumerate(data_loader):
            x_batch = x_batch.to(device, dtype=torch.float)
            y_batch = y_batch.to(device, dtype=torch.float)
            output = model(x_batch)

            ## save the label and prediction (one argmax per batch instead of one per sample)
            batch_pred = output.argmax(dim=1)
            batch_true = y_batch.argmax(dim=1)
            pred.extend(batch_pred.cpu().tolist())
            labels.extend(batch_true.cpu().tolist())

            loss = loss_fn(output, y_batch)
            epoch_loss[i] = loss.item()
            n_total += y_batch.size(0)
            n_hit += torch.sum(batch_true == batch_pred).item()
    return epoch_loss.mean(), n_hit / n_total, pred, labels  # return the loss, acc, predictions and labels


In [None]:
# mne doc: https://mne.tools/stable/python_reference.html
# "data_detail.json" in BCI_hw3_dataset/ provides you with all channel names in this dataset. It will help you to plot a topoplot.

## use evaluate that can return prediction and true label
cur_loss, cur_acc, cur_pred, cur_label = evaluate_an_epoch_pred(model, test_loader, loss_fn)

## construct a confusion matrix
print(f'Accuracy on subject {subject_id}: {round(cur_acc, 4)}')
# The targets in the loader are one-hot encoded, so their width is the number of classes.
# Passing the full class list keeps the matrix square and its rows/columns aligned with the
# tick labels even when a class is missing from the labels or from the predictions.
class_ids = np.arange(test_loader.dataset.tensors[1].shape[1])
cm = confusion_matrix(cur_label, cur_pred, labels=class_ids)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.get_cmap('Blues'))
fig.colorbar(im, ax=ax)
ax.set_xticks(class_ids)
ax.set_xticklabels(class_ids)
ax.set_yticks(class_ids)
ax.set_yticklabels(class_ids)
## add the number on the graph (white text on dark cells, black on light ones)
thresh = cm.max() / 2.
for i, j in np.ndindex(cm.shape):
    ax.text(j, i, format(cm[i, j], 'd'),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black")

ax.set_title(f'Confusion matrix, subject {subject_id}')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
plt.show()