In [1]:
import torch
from torch.utils.data import DataLoader
from torchsummary import summary

from utils import TrainDataset, EEGNetModel

## Dataset
The dataset contains the EEG recordings of 15 subjects. For each subject, there are 15 different recordings, each collected while the subject watched a different movie clip. Each clip is associated with an emotional state among {sad: -1, neutral: 0, happy: 1}. Each EEG recording comprises 62 channels.

N.B. Recordings of the same movie clip have the same length, while recordings of different clips generally have different lengths. This is not a problem: we split every signal into fixed-length sub-windows (1000 time points, i.e. 5 s at 200 Hz), so every sample fed to the model has the same size regardless of the original recording length.

Data have been preprocessed by downsampling the signals to 200 Hz, segmenting them so that each recording matches the duration of its movie clip, and applying a 0–75 Hz band-pass filter. Since recordings are about 4 minutes long and sampled at 200 Hz, each contains roughly 48k time points (240 s × 200 Hz = 48,000).

In [10]:
# Build the windowed training set and a reproducible, shuffled loader over it.
DATA_DIR = "data/Preprocessed_EEG"
# Appears to be the window length in time points: it matches the last dimension of
# each batch (shape [16, 1, 62, 1000] below) and the model's input size.
WINDOW_SIZE = 1000
BATCH_SIZE = 16
SEED = 42

# NOTE(review): the last two positional arguments (100, False) are passed through
# unchanged; 100 is presumably the step between consecutive windows and False some
# boolean option — TODO confirm against TrainDataset's signature and name them.
dataset = TrainDataset(DATA_DIR, f"{DATA_DIR}/label.mat", WINDOW_SIZE, 100, False)

# A seeded generator makes the shuffle order (and every downstream output that
# depends on which batch comes first) reproducible across kernel restarts.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

Loading data files: 100%|██████████| 45/45 [01:05<00:00,  1.45s/it]


In [7]:
# Build the model from the shape of a real sample instead of hardcoded numbers, so the
# architecture cannot drift out of sync with the window length used by TrainDataset.
# Each sample is (1, channels, time) — e.g. (1, 62, 1000) — since batches come out
# of the default collate as (batch, 1, channels, time).
sample_shape = tuple(dataset[0][0].shape)
model = EEGNetModel(input_size=sample_shape[1:])

# NOTE(review): the torch.Size lines printed before the summary table look like
# debug print() calls inside EEGNetModel.forward — consider removing them.
summary(model, sample_shape, batch_size=-1, device="cpu")

torch.Size([2, 1, 62, 1000])
torch.Size([2, 16, 62, 1000])
torch.Size([2, 32, 56, 1000])
torch.Size([2, 32, 28, 200])
torch.Size([2, 32, 28, 200])
torch.Size([2, 32, 14, 40])
torch.Size([2, 17920])
torch.Size([2, 128])
torch.Size([2, 3])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 62, 1000]           1,040
       BatchNorm2d-2         [-1, 16, 62, 1000]              32
            Conv2d-3         [-1, 32, 56, 1000]           3,616
       BatchNorm2d-4         [-1, 32, 56, 1000]              64
              ReLU-5         [-1, 32, 56, 1000]               0
         AvgPool2d-6          [-1, 32, 28, 200]               0
           Dropout-7          [-1, 32, 28, 200]               0
            Conv2d-8          [-1, 32, 28, 200]          16,416
            Conv2d-9          [-1, 32, 28, 200]           1,056
      BatchNorm2d-10          [-1, 32, 28, 200]          

In [13]:
# Draw one shuffled batch from the loader to sanity-check shapes and dtypes.
# batch1[0] holds the EEG windows and batch1[1] the labels (see the next cell).
batch1 = next(iter(dataloader))

In [14]:
# Report shape and dtype of the EEG windows (element 0) and the labels (element 1),
# one line each.
# NOTE(review): labels come back as torch.int16, and the markdown above lists them
# as {-1, 0, 1}; nn.CrossEntropyLoss expects int64 class indices in [0, 3) for this
# 3-output model — confirm TrainDataset remaps them, and cast with .long() before
# computing the loss.
print("\n".join(f"{batch1[idx].shape} {batch1[idx].dtype}" for idx in (0, 1)))

torch.Size([16, 1, 62, 1000]) torch.float32
torch.Size([16]) torch.int16


In [None]:
# Smoke-test a forward pass on the inspected batch without side effects on the model.
# In train mode, BatchNorm layers would update their running statistics from this
# batch and Dropout would randomize the output; eval() prevents both, and no_grad()
# skips building an autograd graph. The previous train/eval mode is restored so later
# training cells behave exactly as before.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        logits = model(batch1[0])
finally:
    model.train(was_training)
logits