In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
import os
from multiprocessing import Pool
from tqdm import tqdm
from functools import partial
from sklearn import preprocessing
import math

## Supervised Learning

In [2]:
class EEGDataset(Dataset):
    """Trial-level EEG classification dataset built from per-subject ``.npz`` files.

    Every ``*.npz`` file in ``dir`` is treated as one subject and must contain an
    ``event_datas`` array of shape ``(n_trials, ch, n_times)``. The trials of all
    subjects are concatenated, each channel of each trial is standardized along
    the time axis, and items are returned as a ``float32`` tensor of shape
    ``(1, ch, n_times)`` plus an ``int64`` label tensor.

    Args:
        dir: Directory (``str`` or ``Path``) containing the ``.npz`` files.
        classes: Number of classes; labels are derived from trial order
            (see :meth:`load_npz_data`).

    Attributes:
        raw_data: Unscaled trials, shape ``(N, ch, n_times)``.
        data: Standardized trials, ``float32``, shape ``(N, 1, ch, n_times)``.
        label: ``int64`` labels, shape ``(N,)``.
        len: Number of trials ``N``.

    Raises:
        FileNotFoundError: If ``dir`` contains no ``.npz`` files.
    """

    def __init__(self, dir, classes):
        dir = Path(dir)
        # Sorted so the trial order (and any index-based split of this dataset)
        # is reproducible; os.listdir returns entries in arbitrary order.
        subjects = sorted(
            Path(name).stem for name in os.listdir(dir)
            if not name.startswith(".") and name.endswith(".npz")
        )
        if not subjects:
            raise FileNotFoundError(f"No .npz files found in {dir}")
        print(subjects)
        print(f"Scanning all npz files in the {dir}")
        # NOTE(review): the pool pickles EEGDataset.load_npz_data; with the
        # "spawn" start method a class defined in a notebook's __main__ may not
        # be importable by the workers — confirm on non-fork platforms.
        with Pool() as pool:
            result = list(
                tqdm(
                    pool.imap(partial(self.load_npz_data, dir=dir, classes=classes), subjects),
                    total=len(subjects),
                    bar_format="{l_bar}{bar:10}{r_bar}",
                )
            )
        # concatenate (rather than np.array(...).reshape) also works when
        # subjects contribute different numbers of trials.
        self.raw_data = np.concatenate([trials for trials, _ in result], axis=0)
        self.label = np.concatenate([labels for _, labels in result]).astype(np.int64)
        # scale_data returns a new array, so raw_data keeps the unscaled trials.
        scaled = self.scale_data(self.raw_data)
        # Add a singleton "image channel" axis: (N, ch, n_times) -> (N, 1, ch, n_times).
        self.data = np.expand_dims(scaled, axis=1).astype(np.float32)
        self.len = len(self.label)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(trial, label)`` tensors for ``idx``; the trial shares memory with ``self.data``."""
        # Index the numpy arrays first so only the selected trial(s) are wrapped.
        return torch.from_numpy(self.data[idx]), torch.as_tensor(self.label[idx])

    @staticmethod
    def load_npz_data(subject, dir, classes):
        """Load one subject's trials and build their class labels.

        Args:
            subject: File stem; ``dir / (subject + ".npz")`` is loaded.
            dir: Directory (``str`` or ``Path``) holding the file.
            classes: Number of classes.

        Returns:
            ``[trials, labels]``: the ``event_datas`` array of shape
            ``(n_trials, ch, n_times)`` and an integer label array of length
            ``n_trials``.

        Raises:
            ValueError: If ``event_datas`` is not 3-D, or ``n_trials`` cannot be
                split evenly into ``classes`` classes.

        Labels are not read from the file: trials are assumed to be grouped by
        class in equal consecutive blocks (the first ``n_trials // classes``
        trials are class 0, and so on) — confirm this matches how the ``.npz``
        files were produced.
        """
        npz_fpath = Path(dir) / (subject + ".npz")
        # NpzFile keeps the file open; the context manager closes it.
        with np.load(npz_fpath) as npz:
            trials = npz["event_datas"]
        n_trials, _, _ = trials.shape
        if classes <= 0 or n_trials % classes:
            raise ValueError(
                f"{npz_fpath}: {n_trials} trials cannot be split evenly into {classes} classes"
            )
        # np.repeat needs an integer count; `/` would produce a float.
        labels = np.repeat(np.arange(classes), repeats=n_trials // classes)
        return [trials, labels]

    def scale_data(self, raw_data):
        """Standardize every channel of every trial along the time axis.

        Args:
            raw_data: Array of shape ``(N, ch, n_times)``.

        Returns:
            A new ``float64`` array of the same shape where each channel of each
            trial is centered and scaled to unit variance over time by
            ``sklearn.preprocessing.scale`` (zero-variance channels are only
            centered). ``raw_data`` itself is not modified.
        """
        # Work on a float copy: scaling in place would overwrite the caller's
        # array (and truncate if it had an integer dtype).
        scaled = np.array(raw_data, dtype=np.float64, copy=True)
        for i in range(len(scaled)):
            scaled[i] = preprocessing.scale(scaled[i], axis=1)
        return scaled

In [None]:
# Build one dataset from every subject's .npz file; the trials belong to 8 classes.
NPZ_DIR = Path("./train_course/npz_data")
N_CLASSES = 8
dataset = EEGDataset(dir=NPZ_DIR, classes=N_CLASSES)

In [4]:
def autopad(k, p=None, d=1):
    """Return the padding for a convolution with kernel ``k`` and dilation ``d``.

    An explicit ``p`` is returned unchanged. Otherwise the padding is half of the
    effective (dilated) kernel size ``d * (k - 1) + 1``, which preserves the
    spatial size for stride-1 convolutions with odd effective kernels.

    Args:
        k: Kernel size, an ``int`` or a sequence of per-dimension sizes.
        p: Explicit padding; when not ``None`` it is returned as-is.
        d: Dilation.

    Returns:
        An ``int`` for an ``int`` kernel, otherwise a ``list`` with one entry per
        kernel dimension (or ``p`` itself when given).
    """
    if p is not None:
        return p

    def _half_effective(size):
        effective = size if d <= 1 else d * (size - 1) + 1
        return effective // 2

    if isinstance(k, int):
        return _half_effective(k)
    return [_half_effective(size) for size in k]

In [5]:
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size (``int`` or tuple).
        s: Stride.
        p: Padding; ``None`` lets ``autopad`` use half the (dilated) kernel size.
        g: Convolution groups.
        d: Dilation.
        act: ``True`` for SiLU, an ``nn.Module`` to use that module, anything
            else for no activation (``nn.Identity``).
        has_bias: Whether the convolution has a bias term.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None,
                 g=1, d=1, act=True, has_bias=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d),
                              groups=g, dilation=d, bias=has_bias)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Forward pass that skips ``self.bn``.

        Presumably meant for use after BatchNorm has been folded into
        ``self.conv`` (no fusing code is visible here).
        """
        # The original body computed this but never returned it (always None).
        return self.act(self.conv(x))

In [6]:
class Conv2dWithConstraint(nn.Conv2d):
    """``nn.Conv2d`` whose filters are max-norm constrained.

    On every forward pass, each output filter (slice along ``dim=0`` of the
    weight) is rescaled so its L2 norm is at most ``max_norm`` before the
    convolution runs. All other arguments are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm

    def forward(self, x):
        clipped = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        self.weight.data = clipped
        return super().forward(x)

In [7]:
class ConvWithConstraint(nn.Module):
    """Max-norm constrained Conv2d -> BatchNorm2d -> activation.

    Same layout as ``Conv``, but the convolution is a ``Conv2dWithConstraint``
    whose per-filter L2 norm is capped at ``max_norm`` on each forward pass.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size (``int`` or tuple).
        max_norm: Maximum L2 norm of each output filter.
        s: Stride.
        p: Padding; ``None`` lets ``autopad`` use half the (dilated) kernel size.
        g: Convolution groups.
        d: Dilation.
        act: ``True`` for SiLU, an ``nn.Module`` to use that module, anything
            else for no activation (``nn.Identity``).
        has_bias: Whether the convolution has a bias term.
    """

    def __init__(self, c1, c2, k=1, max_norm=1.0,
                 s=1, p=None, g=1, d=1, act=True, has_bias=True):
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = Conv2dWithConstraint(
            c1, c2, k,
            max_norm=max_norm,
            stride=s,
            padding=padding,
            groups=g,
            dilation=d,
            bias=has_bias,
        )
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

    def forward_fuse(self, x):
        """Forward pass that skips ``self.bn``."""
        out = self.conv(x)
        return self.act(out)

In [8]:
class DWConvWithConstraint(ConvWithConstraint):
    """Grouped ("depthwise") max-norm constrained conv block.

    A ``ConvWithConstraint`` with ``groups = gcd(c1, c2)``, zero padding by
    default, and the convolution bias left at the parent's default. Note the
    positional order differs from the parent: ``max_norm`` comes before ``k``.
    """

    def __init__(self, c1, c2, max_norm=1.0, k=1, act=True, s=1, p=0, d=1):
        groups = math.gcd(c1, c2)
        super().__init__(
            c1, c2, k,
            max_norm=max_norm,
            s=s,
            p=p,
            g=groups,
            d=d,
            act=act,
        )

In [9]:
class SeparableConv(nn.Module):
    """Depthwise-separable conv block.

    Depthwise ``Conv2d`` (one filter per input channel) -> pointwise 1x1
    ``Conv2d`` -> BatchNorm2d -> activation.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Depthwise kernel size (``int`` or tuple).
        s: Stride, applied once by the depthwise conv.
        p: Depthwise padding; ``None`` lets ``autopad`` use half the (dilated)
            kernel size.
        d: Depthwise dilation.
        act: ``True`` for SiLU, an ``nn.Module`` to use that module, anything
            else for no activation (``nn.Identity``).
        has_bias: Whether both convolutions have a bias term.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None,
                 d=1, act=True, has_bias=True):
        super().__init__()
        self.dw = nn.Conv2d(c1, c1, k, stride=s, padding=autopad(k, p, d),
                            dilation=d, groups=c1, bias=has_bias)
        # The 1x1 pointwise conv only mixes channels. Striding it as well would
        # downsample by s**2 in total instead of s.
        self.pw = nn.Conv2d(c1, c2, 1, stride=1, bias=has_bias)
        self.bn = nn.BatchNorm2d(c2)

        self.act = nn.SiLU() if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.pw(self.dw(x))))

In [10]:
class LinearWithConstraint(nn.Linear):
    """``nn.Linear`` whose weight rows are max-norm constrained.

    On every forward pass, each row of the weight (one per output unit) is
    rescaled so its L2 norm is at most ``max_norm``.

    ``max_norm`` is keyword-only. Positional arguments go straight to
    ``nn.Linear(in_features, out_features, bias, ...)``, so a third positional
    value is taken as ``bias``, not as the norm limit.
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_norm = max_norm

    def forward(self, x):
        weight = self.weight
        weight.data = torch.renorm(weight.data, p=2, dim=0, maxnorm=self.max_norm)
        return super().forward(x)

In [11]:
class EEGNet(nn.Module):
    """EEGNet-style classifier.

    Expects input of shape ``(batch, 1, n_channels, n_times)`` (the layout
    produced by ``EEGDataset``) and returns logits of shape
    ``(batch, n_classes)``.

    Pipeline: temporal conv (1x31) -> spatial grouped conv spanning all
    electrodes (max-norm 2.0) -> avg-pool (1, 4) -> dropout -> separable conv
    (1x15) -> avg-pool (1, 4) -> dropout -> flatten -> linear layer
    (max-norm 0.5). The same ``nn.Dropout(0.65)`` module is used after both
    pools.

    Args:
        n_classes: Number of output classes.
        n_channels: Number of EEG channels (height of the spatial kernel).
        n_times: Samples per trial, used to size the linear layer
            (``16 * (n_times // 4 // 4)`` input features).
    """

    def __init__(self, n_classes=8, n_channels=64, n_times=256):
        super().__init__()
        self.temporal_conv = Conv(1, 8, (1, 31))
        # groups = gcd(8, 16) = 8, so each temporal feature map gets 2 spatial filters.
        self.spatial_conv = DWConvWithConstraint(8, 16, max_norm=2.0, k=(n_channels, 1))
        self.separable_conv = SeparableConv(16, 16, (1, 15))

        # Both odd-kernel convs keep the time length (autopad). The spatial conv
        # collapses the electrode axis to 1 when the input height equals
        # n_channels, and each (1, 4) pool floors the time length by 4.
        n_features = 16 * ((n_times // 4) // 4)
        # max_norm must be passed by keyword: a third positional argument is
        # nn.Linear's `bias`, so a positional 0.5 would leave max_norm at 1.0.
        self.linear = LinearWithConstraint(n_features, n_classes, max_norm=0.5)

        self.avg_pool = nn.AvgPool2d((1, 4))
        self.dropout = nn.Dropout(0.65)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.spatial_conv(self.temporal_conv(x))
        x = self.dropout(self.avg_pool(x))
        x = self.separable_conv(x)
        x = self.dropout(self.avg_pool(x))
        return self.linear(self.flatten(x))

In [None]:
# Seed PyTorch's RNG so the model's random weight initialisation is reproducible.
SEED = 0
torch.manual_seed(SEED)
eegnet = EEGNet()

In [None]:
from d2lightrainer.SupervisedLearning.trainer_config import CLSTrainerConfig
from d2lightrainer.SupervisedLearning.trainer import CLSTrainer

In [None]:
# Settings applied on top of a fresh CLSTrainerConfig for this run.
new_param_dict = dict(
    batch_size=32,
    nominal_batch_size=64,
    save_dir="runs_eeg",
)
cls_cfg = CLSTrainerConfig()
cls_cfg.update(**new_param_dict)

In [None]:
# Train `eegnet` on `dataset` using `cls_cfg`. How the data is split, batched
# and validated is decided inside CLSTrainer (not visible in this notebook).
cls_trainer = CLSTrainer(eegnet, dataset, cls_cfg)
cls_trainer.train()