In [1]:
# !pip install torchmetrics
# !pip install omegaconf
# !pip install wandb
# !pip install einops
import os, sys
import numpy as np
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from omegaconf import DictConfig
import wandb
from termcolor import cprint
from tqdm import tqdm
import os
import numpy as np
import torch
from typing import Tuple
from termcolor import cprint
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import zipfile
import random
import numpy as np
import torch

In [2]:
class ThingsMEGDataset(torch.utils.data.Dataset):
    def __init__(self, split: str, data_dir: str = "data") -> None:
        super().__init__()

        assert split in ["train", "val", "test"], f"Invalid split: {split}"
        self.split = split
        self.num_classes = 1854

        self.X = torch.load(os.path.join(data_dir, f"{split}_X.pt"))
        self.subject_idxs = torch.load(os.path.join(data_dir, f"{split}_subject_idxs.pt"))

        if split in ["train", "val"]:
            self.y = torch.load(os.path.join(data_dir, f"{split}_y.pt"))
            assert len(torch.unique(self.y)) == self.num_classes, "Number of classes do not match."

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, i):
        if hasattr(self, "y"):
            return self.X[i], self.y[i], self.subject_idxs[i]
        else:
            return self.X[i], self.subject_idxs[i]

    @property
    def num_channels(self) -> int:
        return self.X.shape[1]

    @property
    def seq_len(self) -> int:
        return self.X.shape[2]

In [3]:
import torch
import torch.nn as nn


class EEGNet2d(nn.Module):
    """
    four block:
    1. conv2d
    2. depthwiseconv2d
    3. separableconv2d
    4. classify
    """
    def __init__(self, batch_size=4, num_class=2):
        super(EEGNet2d, self).__init__()
        self.batch_size = batch_size
        # 1. conv2d
        self.block1 = nn.Sequential()
        self.block1_conv = nn.Conv2d(in_channels=1,
                                     out_channels=8,
                                     kernel_size=(1, 64),
                                     padding=(0, 32),
                                     bias=False
                                     )
        self.block1.add_module('conv1', self.block1_conv)
        self.block1.add_module('norm1', nn.BatchNorm2d(8))

        # 2. depthwiseconv2d
        self.block2 = nn.Sequential()
        # [N, 8, 64, 128] -> [N, 16, 1, 128]
        self.block2.add_module('conv2', nn.Conv2d(in_channels=8,
                                                  out_channels=16,
                                                  kernel_size=(271, 1),
                                                  groups=2,
                                                  bias=False))
        self.block2.add_module('norm3', nn.BatchNorm2d(16))
        self.block2.add_module('act1', nn.ELU())
        # [N, 16, 1, 128] -> [N, 16, 1, 32]
        self.block2.add_module('pool1', nn.AvgPool2d(kernel_size=(1, 4)))
        self.block2.add_module('drop1', nn.Dropout(p=0.5))

        # 3. separableconv2d
        self.block3 = nn.Sequential()
        self.block3.add_module('conv3', nn.Conv2d(in_channels=16,
                                                  out_channels=32,
                                                  kernel_size=(1, 16),
                                                  padding=(0, 8),
                                                  groups=16,
                                                  bias=False
                                                  ))
        self.block3.add_module('conv4', nn.Conv2d(in_channels=32,
                                                  out_channels=64,
                                                  kernel_size=(1, 1),
                                                  bias=False))
        self.block3.add_module('norm2', nn.BatchNorm2d(64))
        self.block3.add_module('act2', nn.ELU())
        self.block3.add_module('pool2', nn.AvgPool2d(kernel_size=(1, 8)))
        self.block3.add_module('drop2', nn.Dropout(p=0.5))

        # 4. classify
        self.classify = nn.Sequential(nn.Linear(512, num_class))


    def forward(self, x):
        # [B, 64, 128] -> [B, 1, 64, 128]
        if len(x.shape) == 3:
            x = x.unsqueeze(1) 


        # [B, 1, 64, 128] -> [B, 8, 64, 128]
        x = self.block1(x)

        # [B, 8, 64, 128] -> [B, 16, 1, 128] -> [B, 16, 1, 32]
        x = self.block2(x)


        # [B, 16, 1, 31] -> [B, 16, 1, 4]
        x = self.block3(x)

        # [B, 16, 1, 4] -> [B, 64]
        x = x.view(x.size(0), -1)

        # [B, 64] -> [B, num_class]
        x = self.classify(x)



        return x

In [4]:
def set_seed(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [5]:
config = {
    "batch_size": 128,
    "epochs" : 150,
    "lr" : 0.001,
    "device" : "cuda:0",
    "num_workers" : 1,
    "seed" : 1234,
    "use_wandb" : True,
    "data_dir" : './data',
}

In [6]:
def run(args: DictConfig):
    set_seed(args.seed)

    logdir = '/content/MEG_data/log'

    # ------------------
    #    Dataloader
    # ------------------
    loader_args = {"batch_size": args.batch_size, "num_workers": args.num_workers}

    train_set = ThingsMEGDataset("train", args.data_dir)
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)
    val_set = ThingsMEGDataset("val", args.data_dir)
    val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)
    test_set = ThingsMEGDataset("test", args.data_dir)
    test_loader = torch.utils.data.DataLoader(
        test_set, shuffle=False, batch_size=args.batch_size, num_workers=args.num_workers
    )

    # ------------------
    #       Model
    # ------------------
    model = EEGNet2d(args.batch_size, train_set.num_classes,).to(args.device)

    # ------------------
    #     Optimizer
    # ------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,weight_decay=1e-5)

    # ------------------
    # Learning Rate Scheduler
    # ------------------
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    # ------------------
    #   Start training
    # ------------------
    max_val_acc = 0
    accuracy = Accuracy(
        task="multiclass", num_classes=train_set.num_classes, top_k=10
    ).to(args.device)

    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")

        train_loss, train_acc, val_loss, val_acc = [], [], [], []

        model.train()
        for X, y, subject_idxs in tqdm(train_loader, desc="Train"):
            X, y = X.to(args.device), y.to(args.device)
            y_pred = model(X)
            loss = F.cross_entropy(y_pred, y)
            train_loss.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = accuracy(y_pred, y)
            train_acc.append(acc.item())

        model.eval()
        for X, y, subject_idxs in tqdm(val_loader, desc="Validation"):
            X, y = X.to(args.device), y.to(args.device)

            with torch.no_grad():
                y_pred = model(X)

            val_loss.append(F.cross_entropy(y_pred, y).item())
            val_acc.append(accuracy(y_pred, y).item())

        print(f"Epoch {epoch+1}/{args.epochs} | train loss: {np.mean(train_loss):.3f} | train acc: {np.mean(train_acc):.3f} | val loss: {np.mean(val_loss):.3f} | val acc: {np.mean(val_acc):.3f}")
        torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))

        if np.mean(val_acc) > max_val_acc:
            cprint("New best.", "cyan")
            torch.save(model.state_dict(), os.path.join(logdir, "model_best.pt"))
            max_val_acc = np.mean(val_acc)


    # ----------------------------------
    #  Start evaluation with best model
    # ----------------------------------
    model.load_state_dict(torch.load(os.path.join(logdir, "model_best.pt"), map_location=args.device))

    preds = []
    model.eval()
    for X, subject_idxs in tqdm(test_loader, desc="Validation"):
        preds.append(model(X.to(args.device)).detach().cpu())

    preds = torch.cat(preds, dim=0).numpy()
    np.save(os.path.join(logdir, "submission"), preds)
    cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")


In [7]:
args = DictConfig(config)
run(args)

Epoch 1/150


Train:   0%|          | 0/514 [00:53<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 14568) exited unexpectedly