In [1]:
import torch
from train import train
from data_module import EEGDataModule
from model import ViTransformer, LSTM, ConvLSTM, DeepConvNet, RNN, LSTM, ShallowConvNet, EEGNet_Modified
from ATCNet import ATCNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (more slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
from preprocessing import mean_max_subsampling

In [5]:
# Settings handed to EEGDataModule(args=...): data location, split parameters
# and batch sizes.
# NOTE(review): train() in the training cell is not passed this dict, and its
# progress bar reports 100 iterations, so "train_epochs" may never reach the
# training loop — confirm against train.py.
train_params = dict(
    train_epochs=10,
    accumulate_grad_batches=1,
    test_size=0.2,          # 423 / (1692 + 423) == 0.2 matches the validation split size logged below
    random_state=42,
    data_dir="../../Data/",
    train_batch_size=64,
    eval_batch_size=32,
)

In [6]:
# Data module configured entirely from the train_params dict above.
dataset = EEGDataModule(args=train_params)

In [7]:
# Prepare the data with no transform applied; the module logs the training
# data and label shapes.
dataset.setup(transform=None)

INFO:data_module:Training data shape: (1692, 22, 1000)
INFO:data_module:Training labels shape: (1692,)


In [8]:
# Build the training and validation loaders from the prepared data module.
# (The module's log line labels both as "train" instances.)
train_dataloader = dataset.train_dataloader()
valid_dataloader = dataset.val_dataloader()

INFO:data_module:loaded 1692 train data instances
INFO:data_module:loaded 423 train data instances


In [9]:
# Fresh EEGNet_Modified instance; it is the model trained by the train() call below.
model = EEGNet_Modified()



In [10]:
from torchinfo import summary

# summary() is called without an example input, so the lazily-initialised
# LazyLinear head is shown as "--" and is NOT counted in the parameter totals.
model_summary = summary(model)
print(model_summary)

Layer (type:depth-idx)                   Param #
EEGNet_Modified                          --
├─Conv2d: 1-1                            512
├─BatchNorm2d: 1-2                       16
├─Conv2d: 1-3                            704
├─BatchNorm2d: 1-4                       64
├─ELU: 1-5                               --
├─AvgPool2d: 1-6                         --
├─AvgPool2d: 1-7                         --
├─Dropout: 1-8                           --
├─Dropout: 1-9                           --
├─Conv2d: 1-10                           16,384
├─BatchNorm2d: 1-11                      64
├─LazyLinear: 1-12                       --
Total params: 17,744
Trainable params: 17,744
Non-trainable params: 0


In [11]:
# Train the model with L2 weight decay; train() returns training and
# validation loss/accuracy histories.
# NOTE(review): train_params["train_epochs"] is not passed here and the
# progress bar shows 100 iterations — confirm how train() picks its epoch count.
# NOTE(review): the checkpoint loaded later from "./model/EEG_modified" is
# presumably written by train(); no save call is visible in this notebook.
loss_hist, acc_hist, val_loss_hist, val_acc_hist = train(model, train_dataloader, valid_dataloader, device, weight_decay=1e-5)

  self.padding, self.dilation, self.groups)
100%|██████████████████████████████| 100/100 [00:35<00:00,  2.79it/s, acc=0.985, val_acc=0.73]


In [16]:
# Re-create the model so the saved checkpoint (loaded next) is evaluated,
# not the in-memory weights left over from the training run above.
model = EEGNet_Modified()

In [17]:
# map_location="cpu" lets a checkpoint saved from GPU tensors load on any
# machine; the model is moved to `device` when it is evaluated.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer torch versions also accept weights_only=True).
model.load_state_dict(torch.load("./model/EEG_modified", map_location="cpu"))

<All keys matched successfully>

In [22]:
# Switch to inference mode (Dropout disabled, BatchNorm uses running stats)
# before measuring test accuracy.
model.eval()

EEGNet_Modified(
  (temporal_conv1): Conv2d(1, 8, kernel_size=(1, 64), stride=(1, 1), padding=same, bias=False)
  (batch_norm_1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (depth_wise_conv): Conv2d(8, 32, kernel_size=(22, 1), stride=(1, 1), groups=8, bias=False)
  (batch_norm_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu): ELU(alpha=1.0)
  (average_pool1): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
  (average_pool2): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
  (dropout1): Dropout(p=0.3, inplace=False)
  (dropout2): Dropout(p=0.3, inplace=False)
  (spatial_conv1): Conv2d(32, 32, kernel_size=(1, 16), stride=(1, 1), padding=same, bias=False)
  (batch_norm_3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (temp_linear): LazyLinear(in_features=0, out_features=4, bias=True)
)

In [28]:
def eval(dataloader, model, device=None):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    NOTE: this name shadows Python's built-in ``eval`` for the rest of the
    notebook; it is kept because later cells call it by this name.

    Args:
        dataloader: iterable yielding ``(x, y)`` batches. ``y`` is assumed to
            hold integer class labels — TODO confirm against data_module.
        model: ``torch.nn.Module`` whose output is assumed to be per-class
            scores of shape ``(batch, n_classes)`` (argmax is taken on dim 1).
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, so inputs always land where the model is;
            if the model has no parameters, batches are left where they are.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    # Accuracy must be measured in inference mode (no Dropout, BatchNorm
    # running stats); restore whatever mode the caller had afterwards.
    was_training = model.training
    model.eval()
    n_seen = 0
    n_correct = 0
    try:
        with torch.no_grad():
            for x, y in dataloader:
                if device is not None:
                    x = x.to(device)
                    y = y.to(device)
                preds = model(x).argmax(dim=1)
                n_seen += len(y)
                n_correct += (preds == y).sum().item()
    finally:
        model.train(was_training)

    if n_seen == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return n_correct / n_seen

In [29]:
# Loader over the held-out test split used for the model comparison below.
# (The data module's log line says "train" even though this is the test loader.)
test_dataloader = dataset.test_dataloader()

INFO:data_module:loaded 443 train data instances


In [30]:
# Test accuracy of the checkpointed EEGNet_Modified.
# NOTE(review): the saved output is an (accuracy, None) tuple, but eval() as
# defined above returns a single number — the output came from an earlier
# version of eval(); Restart & Run All to refresh it.
eval(test_dataloader, model.to(device))

(0.6997742663656885, None)

In [32]:
# Load the trained ATCNet checkpoint and report its test accuracy.
# map_location="cpu" keeps GPU-saved checkpoints loadable on CPU-only machines;
# the model is then moved to `device` for evaluation.
model = ATCNet()
model.load_state_dict(torch.load("./model/ATCNet", map_location="cpu"))
model.eval()
eval(test_dataloader, model.to(device))

(0.6704288939051919, None)

In [33]:
# Load the trained DeepConvNet checkpoint and report its test accuracy.
# map_location="cpu" keeps GPU-saved checkpoints loadable on CPU-only machines;
# the model is then moved to `device` for evaluation.
model = DeepConvNet()
model.load_state_dict(torch.load("./model/DeepConvNet", map_location="cpu"))
model.eval()
eval(test_dataloader, model.to(device))

(0.636568848758465, None)

In [34]:
# Load the trained ShallowConvNet checkpoint and report its test accuracy.
# map_location="cpu" keeps GPU-saved checkpoints loadable on CPU-only machines;
# the model is then moved to `device` for evaluation.
model = ShallowConvNet()
model.load_state_dict(torch.load("./model/ShallowConvNet", map_location="cpu"))
model.eval()
eval(test_dataloader, model.to(device))

(0.5507900677200903, None)