In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Dataset, random_split
import numpy as np
from os.path import join
import mat73
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
class CreateDataset(Dataset):
    """Creates dataset for EEGNET, meaning move all relevant channel data into one matrix for x, and results into y

    Args:
        Dataset (_type_): Takes in data loaded from Matlab and formats appropriately.
    """
    
    def __init__(self, data, channels, crop=None):
        # Associates channel names with channel data in a dictionary
        datadicts = [dict(zip(np.squeeze(data["label"]),dat)) for dat in data["trial"]]

        x, y = [0]*len(data["trialinfo"]), np.array([0]*len(data["trialinfo"]))
        
        # Extract the y-values, i.e. which side the audio was played
        for i, trialinfo in enumerate(data["trialinfo"]):
            # side : left = 1, right = 0
            y[i] = int(trialinfo[0]["side"])==1
        
        self.y = y
        
        # Extract only information from relevant channels
        for i, dat in enumerate(datadicts):
            x[i] = [dat[ch] for ch in channels]
        x = np.array(x)
        
        # Include only specific section of time-series.
        if crop:
            x=x[:,:,crop[0]:crop[1]]
        
        self.x = x

    def __getitem__(self, index):
        feature = torch.tensor([self.x[index]], dtype=torch.float32)
        label = torch.tensor([self.y[index]], dtype=torch.float32)

        return feature, label
    
    def __len__(self):
        return len(self.x)
        

In [3]:
# https://towardsdatascience.com/convolutional-neural-networks-for-eeg-brain-computer-interfaces-9ee9f3dd2b81
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class EEGNET(nn.Module):
    def __init__(self, filter_sizing, dropout, D, channel_amount, receptive_field=128, mean_pool=15):
        #FIXME: D, filter_sizing and dropout choose on hyper parameter search
        super(EEGNET,self).__init__()
        self.temporal=nn.Sequential(
            nn.Conv2d(1,filter_sizing,kernel_size=[1,receptive_field],stride=1, bias=False,\
                padding='same'), 
            nn.BatchNorm2d(filter_sizing),
        )
        self.spatial=nn.Sequential(
            nn.Conv2d(filter_sizing,filter_sizing*D,kernel_size=[channel_amount,1],bias=False,\
                groups=filter_sizing),
            nn.BatchNorm2d(filter_sizing*D),
            nn.ELU(True),
        )

        self.seperable=nn.Sequential(
            nn.Conv2d(filter_sizing*D,filter_sizing*D,kernel_size=[1,16],\
                padding='same',groups=filter_sizing*D, bias=False),
            nn.Conv2d(filter_sizing*D,filter_sizing*D,kernel_size=[1,1], padding='same',groups=1, bias=False),
            nn.BatchNorm2d(filter_sizing*D),
            nn.ELU(True),
        )

        self.avgpool1 = nn.AvgPool2d([1, mean_pool], stride=[1, mean_pool], padding=0)   
        self.avgpool2 = nn.AvgPool2d([1, mean_pool], stride=[1, mean_pool], padding=0)
        self.dropout = nn.Dropout(dropout)
        self.view = nn.Sequential(Flatten())

        #FIXME: Figure out expression for endsize
        endsize = 48 
        self.fc2 = nn.Linear(endsize, 1)

    def forward(self,x):
        out = self.temporal(x)
        out = self.spatial(out)
        out = self.avgpool1(out)
        out = self.dropout(out)
        out = self.seperable(out)
        out = self.avgpool2(out)
        out = self.dropout(out)
        out = out.view(out.size(0), -1)
        prediction = self.fc2(out)
        return torch.sigmoid(prediction)

In [4]:
def evaluate_loss(model, criterion, dataloader):
    model.eval()
    total_loss = 0.0
    for batch_X, batch_y in dataloader:
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        total_loss += loss.item()
        
    return total_loss / len(dataloader)

In [5]:
def evaluate_acc(model, dataloader):
    model.eval()
    total_acc = 0.0
    for batch_X, batch_y in dataloader:
        outputs = model(batch_X)
        predictions = 1.0*(outputs>0.5)
        total_acc += (predictions==batch_y).sum()
        
    return total_acc / len(dataloader.dataset)

In [6]:
def train(model, criterion, optimizer, train_loader, valid_loader, n_epochs):
    train_loss_list = []
    valid_loss_list = []
    train_acc_list = []
    valid_acc_list = []
    for epoch in range(1, n_epochs):
        model.train()
        for batch_X, batch_y in train_loader:
            ypred = model.forward(batch_X)
            loss = criterion(ypred, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        train_loss = evaluate_loss(model, criterion, train_loader)
        valid_loss = evaluate_loss(model, criterion, valid_loader)
        train_acc = evaluate_acc(model, train_loader)
        valid_acc = evaluate_acc(model, valid_loader)
        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)
        train_acc_list.append(train_acc)
        valid_acc_list.append(valid_acc)

        print(f"| epoch {epoch:2d} | train loss {train_loss:.6f} | train acc {train_acc:.6f} | valid loss {valid_loss:.6f} | valid acc {valid_acc:.6f} |")

    return train_loss_list, valid_loss_list, train_acc_list, valid_acc_list

# Training

In [10]:
# Load data into dictionary
DataPath = join("neuro_data","dataSubj10.mat")
data_dict = mat73.loadmat(DataPath, use_attrdict=True)

RAWDATA = data_dict["data"]
BATCH_SIZE=16
## CROPPING for data, 2.6s-5.6s, region of interest with audio
FS = 512
LO = int(2.6*FS) #1331
HI = int(5.5*FS) #2816
# Dif : 1485 = 3 * 3 * 3 * 5 * 11

In [9]:
# fix random seed
np.random.seed(293210931)
torch.manual_seed(293210931)



dataset = CreateDataset(RAWDATA, ["T7","FT7","TP7","TP8","FT8","T8"], [LO, HI])
dat_train, dat_val, dat_test = random_split(dataset, [0.7,0.1,0.2])

train_loader = DataLoader(dat_train, batch_size=BATCH_SIZE)
val_loader = DataLoader(dat_val, batch_size=BATCH_SIZE)
test_loader = DataLoader(dat_test, batch_size=BATCH_SIZE)

# Set up elements
model = EEGNET(4,0.5,2,6,mean_pool=15)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

# Train network
train(model, criterion, optimizer, train_loader, val_loader, 31) #Should be val instead of test


| epoch  1 | train loss 0.695640 | train acc 0.493927 | valid loss 0.689675 | valid acc 0.571429 |
| epoch  2 | train loss 0.697051 | train acc 0.497976 | valid loss 0.689092 | valid acc 0.571429 |
| epoch  3 | train loss 0.697182 | train acc 0.514170 | valid loss 0.689598 | valid acc 0.585714 |
| epoch  4 | train loss 0.697189 | train acc 0.510121 | valid loss 0.690259 | valid acc 0.571429 |
| epoch  5 | train loss 0.695777 | train acc 0.534413 | valid loss 0.691431 | valid acc 0.557143 |
| epoch  6 | train loss 0.695291 | train acc 0.534413 | valid loss 0.691375 | valid acc 0.557143 |
| epoch  7 | train loss 0.694836 | train acc 0.526316 | valid loss 0.690873 | valid acc 0.557143 |
| epoch  8 | train loss 0.694142 | train acc 0.538462 | valid loss 0.691819 | valid acc 0.542857 |
| epoch  9 | train loss 0.693440 | train acc 0.550607 | valid loss 0.691394 | valid acc 0.528571 |
| epoch 10 | train loss 0.692434 | train acc 0.542510 | valid loss 0.691948 | valid acc 0.500000 |
| epoch 11

([0.6956404000520706,
  0.6970506645739079,
  0.6971815563738346,
  0.6971893757581711,
  0.6957773752510548,
  0.6952908001840115,
  0.6948364078998566,
  0.6941417679190636,
  0.6934401765465736,
  0.6924337595701218,
  0.6922488510608673,
  0.6921272836625576,
  0.6915441751480103,
  0.6912834495306015,
  0.69057197868824,
  0.689300786703825,
  0.6895701177418232,
  0.6886910535395145,
  0.6883344277739525,
  0.6873954087495804,
  0.6872167140245438,
  0.687377069145441,
  0.6867239363491535,
  0.6863992102444172,
  0.6863821819424629,
  0.6862100102007389,
  0.6862549334764481,
  0.6864646710455418,
  0.6862837001681328,
  0.6861419156193733],
 [0.6896751761436463,
  0.6890915274620056,
  0.6895977616310119,
  0.6902586817741394,
  0.6914306044578552,
  0.6913747191429138,
  0.690873384475708,
  0.6918186783790589,
  0.6913937211036683,
  0.6919482707977295,
  0.6913990497589111,
  0.6904971480369568,
  0.690432333946228,
  0.6915205955505371,
  0.6922442555427551,
  0.69250543117