In [138]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Dataset, random_split
import numpy as np
from os.path import join
import mat73
import matplotlib.pyplot as plt

In [139]:
# Load data into dictionary
DataPath = join("neuro_data","dataSubj10.mat")
data_dict = mat73.loadmat(DataPath, use_attrdict=True)
rawdata = data_dict["data"]

In [162]:
class CreateDataset(Dataset):
    """Creates dataset for EEGNET, meaning move all relevant channel data into one matrix for x, and results into y

    Args:
        Dataset (_type_): _description_
    """
    def __init__(self, data, channels, crop=None):
        # Associates channel names with channel data in a dictionary
        datadicts = [dict(zip(np.squeeze(data["label"]),dat)) for dat in data["trial"]]

        x, y = [0]*len(data["trialinfo"]), np.array([0]*len(data["trialinfo"]))
        
        # Extract the y-values, i.e. which side the audio was played
        for i, trialinfo in enumerate(data["trialinfo"]):
            # side : left = 1, right = 0
            y[i] = int(trialinfo[0]["side"])==1
        
        self.y = y
        
        # Extract only information from relevant channels
        for i, dat in enumerate(datadicts):
            x[i] = [dat[ch] for ch in channels]
        x = np.array(x)
        
        if crop:
            x=x[:,:,crop[0]:crop[1]]
        
        self.x = x

    def __getitem__(self, index):
        feature = torch.tensor([self.x[index]], dtype=torch.float32)
        label = torch.tensor([self.y[index]], dtype=torch.float32)

        return feature, label
    
    def __len__(self):
        return len(self.x)
        

In [172]:
# https://towardsdatascience.com/convolutional-neural-networks-for-eeg-brain-computer-interfaces-9ee9f3dd2b81
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class EEGNET(nn.Module):
    def __init__(self, filter_sizing, dropout, D, channel_amount, receptive_field=128, mean_pool=15):
        #FIXME: D, filter_sizing and dropout choose on hyper parameter search
        super(EEGNET,self).__init__()
        self.temporal=nn.Sequential(
            nn.Conv2d(1,filter_sizing,kernel_size=[1,receptive_field],stride=1, bias=False,\
                padding='same'), 
            nn.BatchNorm2d(filter_sizing),
        )
        self.spatial=nn.Sequential(
            nn.Conv2d(filter_sizing,filter_sizing*D,kernel_size=[channel_amount,1],bias=False,\
                groups=filter_sizing),
            nn.BatchNorm2d(filter_sizing*D),
            nn.ELU(True),
        )

        self.seperable=nn.Sequential(
            nn.Conv2d(filter_sizing*D,filter_sizing*D,kernel_size=[1,16],\
                padding='same',groups=filter_sizing*D, bias=False),
            nn.Conv2d(filter_sizing*D,filter_sizing*D,kernel_size=[1,1], padding='same',groups=1, bias=False),
            nn.BatchNorm2d(filter_sizing*D),
            nn.ELU(True),
        )

        self.avgpool1 = nn.AvgPool2d([1, mean_pool], stride=[1, mean_pool], padding=0)   
        self.avgpool2 = nn.AvgPool2d([1, mean_pool], stride=[1, mean_pool], padding=0)
        self.dropout = nn.Dropout(dropout)
        self.view = nn.Sequential(Flatten())

        endsize = 48
        self.fc2 = nn.Linear(endsize, 1)

    def forward(self,x):
        out = self.temporal(x)
        out = self.spatial(out)
        out = self.avgpool1(out)
        out = self.dropout(out)
        out = self.seperable(out)
        out = self.avgpool2(out)
        out = self.dropout(out)
        out = out.view(out.size(0), -1)
        prediction = self.fc2(out)
        return torch.sigmoid(prediction)

In [142]:
def evaluate_loss(model, criterion, dataloader):
    model.eval()
    total_loss = 0.0
    for batch_X, batch_y in dataloader:
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        total_loss += loss.item()
        
    return total_loss / len(dataloader)

In [143]:
def evaluate_acc(model, dataloader):
    model.eval()
    total_acc = 0.0
    for batch_X, batch_y in dataloader:
        outputs = model(batch_X)
        predictions = 1.0*(outputs>0.5)
        total_acc += (predictions==batch_y).sum()
        
    return total_acc / len(dataloader.dataset)

In [144]:
def train(model, criterion, optimizer, train_loader, valid_loader):
    train_loss_list = []
    valid_loss_list = []
    train_acc_list = []
    valid_acc_list = []
    for epoch in range(1, 16):
        model.train()
        for batch_X, batch_y in train_loader:
            ypred = model.forward(batch_X)
            loss = criterion(ypred, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        train_loss = evaluate_loss(model, criterion, train_loader)
        valid_loss = evaluate_loss(model, criterion, valid_loader)
        train_acc = evaluate_acc(model, train_loader)
        valid_acc = evaluate_acc(model, valid_loader)
        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)
        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

        print(f"| epoch {epoch:2d} | train loss {train_loss:.6f} | train acc {train_acc:.6f} | valid loss {valid_loss:.6f} | valid acc {valid_acc:.6f} |")

    return train_loss_list, valid_loss_list, train_acc_list, valid_acc_list

In [164]:
fs = 512
lo = int(2.6*fs)
hi = int(5.5*fs)
crop = [lo, hi]
print(hi-lo)

1485


# Training

In [176]:
# fix random seed
np.random.seed(293210931)
torch.manual_seed(293210931)

## CROPPING for data, 2.6s-5.6s
fs = 512
lo = int(2.6*fs)
hi = int(5.5*fs)
crop = [lo, hi]

dataset = CreateDataset(rawdata, ["T7","FT7","TP7","TP8","FT8","T8"], crop)
dat_train, dat_val, dat_test = random_split(dataset, [0.7,0.1,0.2])

train_loader = DataLoader(dat_train, batch_size=16)
val_loader = DataLoader(dat_val, batch_size=16)
test_loader = DataLoader(dat_test, batch_size=16)

# Set up elements
model = EEGNET(4,0.5,2,6,mean_pool=15)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0002)

# Train network
train(model, criterion, optimizer, train_loader, val_loader) #Should be val instead of test


| epoch  1 | train loss 0.694393 | train acc 0.502024 | valid loss 0.691742 | valid acc 0.428571 |
| epoch  2 | train loss 0.694853 | train acc 0.530364 | valid loss 0.688271 | valid acc 0.485714 |
| epoch  3 | train loss 0.694596 | train acc 0.546559 | valid loss 0.683260 | valid acc 0.600000 |
| epoch  4 | train loss 0.694746 | train acc 0.530364 | valid loss 0.681262 | valid acc 0.600000 |
| epoch  5 | train loss 0.692692 | train acc 0.538462 | valid loss 0.674889 | valid acc 0.628571 |
| epoch  6 | train loss 0.692112 | train acc 0.546559 | valid loss 0.673521 | valid acc 0.600000 |
| epoch  7 | train loss 0.691525 | train acc 0.546559 | valid loss 0.672336 | valid acc 0.600000 |
| epoch  8 | train loss 0.690126 | train acc 0.554656 | valid loss 0.666643 | valid acc 0.628571 |
| epoch  9 | train loss 0.689024 | train acc 0.558704 | valid loss 0.663479 | valid acc 0.600000 |
| epoch 10 | train loss 0.688084 | train acc 0.570850 | valid loss 0.659576 | valid acc 0.657143 |
| epoch 11

([0.694392554461956,
  0.6948527991771698,
  0.6945959888398647,
  0.6947462521493435,
  0.6926921978592873,
  0.6921121552586555,
  0.6915251761674881,
  0.690125547349453,
  0.6890236996114254,
  0.6880837343633175,
  0.6879018470644951,
  0.6876201964914799,
  0.6870603635907173,
  0.6869524121284485,
  0.6863409765064716],
 [0.6917416850725809,
  0.6882707675298055,
  0.6832601626714071,
  0.6812621156374613,
  0.6748885909716288,
  0.6735211412111918,
  0.6723364591598511,
  0.6666430632273356,
  0.6634794473648071,
  0.6595760782559713,
  0.6627465486526489,
  0.6617307265599569,
  0.65932563940684,
  0.6649388074874878,
  0.6632795333862305],
 [tensor(0.5020),
  tensor(0.5304),
  tensor(0.5466),
  tensor(0.5304),
  tensor(0.5385),
  tensor(0.5466),
  tensor(0.5466),
  tensor(0.5547),
  tensor(0.5587),
  tensor(0.5709),
  tensor(0.5709),
  tensor(0.5870),
  tensor(0.5749),
  tensor(0.5709),
  tensor(0.5870)],
 [tensor(0.4286),
  tensor(0.4857),
  tensor(0.6000),
  tensor(0.6000),