# Notebook for training a TopoNet model on MEG data to predict the task being performed

Model architecture - https://arxiv.org/pdf/1611.08024

Training data - MEG data recorded while people were performing different tasks.

In [12]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from pymatreader import read_mat
import numpy as np

## First experiment
0.5 seconds before the actual experiment has started: the participant is informed what they need to pay attention to.
Only one participant is used for training and testing.

### Loading data

In [13]:
def load_first_500ms(file_path: str, n_samples: int = 500) -> list:
    """
    Loads trials from a file and takes the first ``n_samples`` samples of every sensor.

    The function name and the default of 500 samples correspond to 500ms only if the
    recording has one sample per millisecond (assumption - TODO confirm sampling rate).

    :param file_path: path to the ``.mat`` file. The label is the second-to-last
        ``_``-separated token of this path.
    :param n_samples: number of leading samples kept from each sensor.
    :return: list of (cut_trial, label) where label indicates what subject was attending to.
        ``cut_trial`` is a 1-D array with the kept samples of all sensors laid out
        one sensor after another.
    """
    label = file_path.split('_')[-2]

    data = read_mat(file_path)
    cut_data = []
    for trial in data['finalStruct']['trial']:
        pieces = [np.ravel(sensor[:n_samples]) for sensor in trial]
        if pieces:
            # One concatenation instead of repeated np.append calls, which copy the
            # whole growing array on every sensor. The empty float64 seed keeps the
            # dtype promotion identical to the previous np.append chain.
            # NOTE(review): the result is flat (sensors * n_samples,), not
            # (sensors, n_samples) - consumers have to reshape it themselves.
            cut_trial = np.concatenate([np.empty(0), *pieces])
        else:
            cut_trial = np.empty([0, n_samples])
        cut_data.append((cut_trial, label))
    return cut_data

In [15]:
SUBJECT_NUMBER = '01' # We can train our models on different participants

# Seed before the first random operation: random_split draws from the global RNG,
# so seeding after it would leave the train/test split different on every run.
torch.manual_seed(42) # for reproducibility

data = load_first_500ms(f'/mnt/diska/baldauf/Subject_{SUBJECT_NUMBER}_OnsetStim_BOT_scoutTimeSeriesNew.mat') + \
    load_first_500ms(f'/mnt/diska/baldauf/Subject_{SUBJECT_NUMBER}_OnsetStim_TOP_scoutTimeSeriesNew.mat')

# Since we have a small amount of data, we will only use train and test splits
test_size = int(0.1 * len(data))
train_dataset, test_dataset = random_split(data, [len(data) - test_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Number of training examples (displayed as the cell output)
len(data) - test_size

219

### Model classes

Here we implement EEGNet, but with some modifications, due to the nature of our data and the need to use topographical constraints.

In [None]:
class ConstrainedConv2d(nn.Conv2d):
    """
    Implementation of maximum norm constraint for Conv2D layer

    On every forward pass each output filter whose L2 norm exceeds ``max_norm`` is
    rescaled to have norm ``max_norm``; filters within the limit are used unchanged.
    The stored parameter itself is not modified, only the weights used in the
    convolution.

    :param max_norm: upper bound for the L2 norm of every output filter (keyword only).
    All other arguments are forwarded to ``nn.Conv2d``.
    """
    def __init__(self, *args, max_norm: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # An element-wise clamp would only cap individual positive weights and
        # leave large negative ones untouched; renorm bounds the norm of each
        # sub-tensor along dim 0, i.e. of each output filter.
        weight = torch.renorm(self.weight, p=2, dim=0, maxnorm=self.max_norm)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

class ConstrainedLinear(nn.Linear):
    """
    Implementation of maximum norm constraint for Linear layer

    On every forward pass each output unit whose weight vector has an L2 norm above
    ``max_norm`` is rescaled to have norm ``max_norm``; units within the limit are
    used unchanged. The stored parameter itself is not modified, only the weights
    used in the matrix product.

    :param max_norm: upper bound for the L2 norm of every output unit's weights
        (keyword only).
    All other arguments are forwarded to ``nn.Linear``.
    """
    def __init__(self, *args, max_norm: float = 0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # An element-wise clamp would only cap individual positive weights and
        # leave large negative ones untouched; renorm bounds the norm of each
        # row of the (out_features, in_features) weight matrix.
        weight = torch.renorm(self.weight, p=2, dim=0, maxnorm=self.max_norm)
        return F.linear(x, weight, self.bias)

class EEGNet(nn.Module):
    """
    EEGNet-style classifier with an additional 36-unit dense layer before the output layer.

    :param sensors: number of sensors (height of the input).
    :param samples: number of time samples per trial (width of the input).
    :param num_classes: number of output classes.
    :param f1: number of temporal filters in the first convolution.
    :param depth: number of spatial filters learned per temporal filter.
    :param f2: number of pointwise filters in the second block.
    :param dropout: dropout probability used in both blocks.
    """
    def __init__(self, sensors: int, samples: int, num_classes: int, f1: int, depth: int, f2: int, dropout: float):
        super().__init__()
        self.block1 = nn.Sequential(
            # Since we have only 500ms of data we set kernel length at 250
            # So we can capture patterns with frequency above 4Hz
            nn.Conv2d(in_channels=1, out_channels=f1, kernel_size=(1, 250), padding='same',
                      bias=False),
            nn.BatchNorm2d(f1),
            # Spatial depthwise convolution over all sensors. 'valid' padding collapses
            # the sensor axis to size 1; with 'same' it would stay at `sensors` and the
            # flattened feature count would no longer match the dense layer below.
            ConstrainedConv2d(in_channels=f1, out_channels=f1*depth, kernel_size=(sensors, 1), padding='valid',
                      groups=f1, bias=False),
            nn.BatchNorm2d(f1*depth),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 4)),
            nn.Dropout(dropout),
        )
        self.block2 = nn.Sequential(
            # Depthwise temporal convolution followed by a pointwise (1x1) convolution
            nn.Conv2d(in_channels=f1*depth, out_channels=f1*depth, kernel_size=(1, 16), padding='same',
                      groups=f1*depth, bias=False),
            nn.Conv2d(in_channels=f1*depth, out_channels=f2, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 8)),
            nn.Dropout(dropout),
            nn.Flatten(),
        )
        # Each average pooling floors the time axis, so the remaining length is
        # (samples // 4) // 8, which differs from (f2 * samples) // 32 in general.
        pooled_samples = (samples // 4) // 8
        # We have to add one dense layer in order to implement topographical constraints
        self.linear = ConstrainedLinear(in_features=f2*pooled_samples, out_features=36)
        self.classifier = nn.Linear(in_features=36, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes class scores for a batch of trials.

        :param x: tensor of shape (batch, 1, sensors, samples).
        :return: tensor of shape (batch, num_classes) with unnormalized scores.
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.linear(x)
        return self.classifier(x)
