In [1]:
from google.colab import drive
drive.mount('/content/drive')


ModuleNotFoundError: No module named 'google.colab'

In [None]:
%cd /content/drive/MyDrive/Colab Notebooks/DLBasics2023_colab/final

/content/drive/MyDrive/Colab Notebooks/DLBasics2023_colab/final


In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/DLBasics2023_colab/final/dl_lecture_competition_pub/

/content/drive/MyDrive/Colab Notebooks/DLBasics2023_colab/final/dl_lecture_competition_pub


In [None]:
!pip install -r requirements.txt



In [None]:
import os
import sys
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from termcolor import cprint
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

sys.path.append('/content/drive/MyDrive/Colab Notebooks/DLBasics2023_colab/final/dl_lecture_competition_pub/')
from src.datasets import ThingsMEGDataset
from src.utils import set_seed

class Args:
    seed = 42
    batch_size = 10
    num_workers = 0
    data_dir = 'data'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lr = 0.0001
    epochs = 200
    patience = 20

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :, :]
        return self.dropout(x)

class TransformerClassifier(nn.Module):
    def __init__(self, num_classes, seq_len, num_channels, d_model=128, nhead=8, num_layers=4, dropout=0.2):
        super().__init__()
        self.normalize = nn.BatchNorm1d(num_channels)
        self.embedding = nn.Linear(num_channels, d_model)
        self.bn1 = nn.BatchNorm1d(d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=seq_len)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=512, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, num_classes)
        self.seq_len = seq_len

    def forward(self, x):
        if x.shape[2] != self.seq_len:
            raise ValueError(f"Expected input sequence length {self.seq_len}, but got {x.shape[2]}")

        x = self.normalize(x)
        x = self.embedding(x.transpose(1, 2))
        x = self.bn1(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = x.permute(1, 0, 2)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        x = self.layer_norm(x)
        x = x.mean(dim=0)
        return self.fc(x)

def mixup_data(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def add_noise(x, noise_level=0.01):
    return x + noise_level * torch.randn_like(x)

def time_shift(x, max_shift=5):
    shift = torch.randint(-max_shift, max_shift + 1, (x.shape[0],))
    return torch.stack([torch.roll(x[i], shifts=shift[i].item(), dims=1) for i in range(x.shape[0])])

args = Args()
set_seed(args.seed)

loader_args = {"batch_size": args.batch_size, "num_workers": args.num_workers}

train_set = ThingsMEGDataset("train", args.data_dir)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)
val_set = ThingsMEGDataset("val", args.data_dir)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)
test_set = ThingsMEGDataset("test", args.data_dir)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)

model = TransformerClassifier(
    train_set.num_classes,
    seq_len=train_set.seq_len,
    num_channels=train_set.num_channels
).to(args.device)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

accuracy = Accuracy(
    task="multiclass", num_classes=train_set.num_classes, top_k=10
).to(args.device)

max_val_acc = 0
best_epoch = 0
logdir = 'outputs'

for epoch in range(args.epochs):
    print(f"Epoch {epoch+1}/{args.epochs}")

    model.train()
    train_loss, train_acc = [], []
    for X, y, subject_idxs in tqdm(train_loader, desc="Train"):
        X, y = X.to(args.device), y.to(args.device)
        X = add_noise(X)
        X = time_shift(X)
        X, y_a, y_b, lam = mixup_data(X, y)

        y_pred = model(X)
        loss = mixup_criterion(F.cross_entropy, y_pred, y_a, y_b, lam)
        train_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        acc = accuracy(y_pred, y)
        train_acc.append(acc.item())

    scheduler.step()

    model.eval()
    val_loss, val_acc = [], []
    with torch.no_grad():
        for X, y, subject_idxs in tqdm(val_loader, desc="Validation"):
            X, y = X.to(args.device), y.to(args.device)
            y_pred = model(X)
            val_loss.append(F.cross_entropy(y_pred, y).item())
            val_acc.append(accuracy(y_pred, y).item())

    print(f"Epoch {epoch+1}/{args.epochs} | train loss: {np.mean(train_loss):.3f} | train acc: {np.mean(train_acc):.3f} | val loss: {np.mean(val_loss):.3f} | val acc: {np.mean(val_acc):.3f}")
    torch.save(model.state_dict(), os.path.join(logdir, "model_last.pt"))

    if np.mean(val_acc) > max_val_acc:
        cprint("New best.", "cyan")
        torch.save(model.state_dict(), os.path.join(logdir, "model_best.pt"))
        max_val_acc = np.mean(val_acc)
        best_epoch = epoch

    if epoch - best_epoch >= args.patience:
        cprint("Early stopping", "red")
        break

model.load_state_dict(torch.load(os.path.join(logdir, "model_best.pt"), map_location=args.device))

preds = []
model.eval()
with torch.no_grad():
    for X, subject_idxs in tqdm(test_loader, desc="Test"):
        preds.append(model(X.to(args.device)).cpu())

preds = torch.cat(preds, dim=0).numpy()
np.save(os.path.join(logdir, "submission"), preds)
cprint(f"Submission {preds.shape} saved at {logdir}", "cyan")

Epoch 1/200


Train: 100%|██████████| 6573/6573 [02:46<00:00, 39.52it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 151.59it/s]


Epoch 1/200 | train loss: 7.547 | train acc: 0.008 | val loss: 7.510 | val acc: 0.009
New best.
Epoch 2/200


Train: 100%|██████████| 6573/6573 [02:45<00:00, 39.76it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 152.58it/s]


Epoch 2/200 | train loss: 7.508 | train acc: 0.010 | val loss: 7.495 | val acc: 0.013
New best.
Epoch 3/200


Train: 100%|██████████| 6573/6573 [02:45<00:00, 39.70it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 152.15it/s]


Epoch 3/200 | train loss: 7.482 | train acc: 0.012 | val loss: 7.460 | val acc: 0.015
New best.
Epoch 4/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.88it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.00it/s]


Epoch 4/200 | train loss: 7.430 | train acc: 0.014 | val loss: 7.420 | val acc: 0.019
New best.
Epoch 5/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.95it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.38it/s]


Epoch 5/200 | train loss: 7.382 | train acc: 0.016 | val loss: 7.391 | val acc: 0.021
New best.
Epoch 6/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.92it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 151.72it/s]


Epoch 6/200 | train loss: 7.335 | train acc: 0.020 | val loss: 7.378 | val acc: 0.024
New best.
Epoch 7/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.90it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 152.93it/s]


Epoch 7/200 | train loss: 7.291 | train acc: 0.021 | val loss: 7.366 | val acc: 0.024
Epoch 8/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.93it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.43it/s]


Epoch 8/200 | train loss: 7.257 | train acc: 0.023 | val loss: 7.358 | val acc: 0.025
New best.
Epoch 9/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.90it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 153.78it/s]


Epoch 9/200 | train loss: 7.233 | train acc: 0.024 | val loss: 7.353 | val acc: 0.027
New best.
Epoch 10/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 39.92it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 152.70it/s]


Epoch 10/200 | train loss: 7.222 | train acc: 0.025 | val loss: 7.354 | val acc: 0.026
Epoch 11/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.17it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.44it/s]


Epoch 11/200 | train loss: 7.292 | train acc: 0.021 | val loss: 7.356 | val acc: 0.025
Epoch 12/200


Train: 100%|██████████| 6573/6573 [02:41<00:00, 40.75it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 157.75it/s]


Epoch 12/200 | train loss: 7.244 | train acc: 0.023 | val loss: 7.327 | val acc: 0.030
New best.
Epoch 13/200


Train: 100%|██████████| 6573/6573 [02:41<00:00, 40.74it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.76it/s]


Epoch 13/200 | train loss: 7.199 | train acc: 0.026 | val loss: 7.321 | val acc: 0.029
Epoch 14/200


Train: 100%|██████████| 6573/6573 [02:41<00:00, 40.73it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.47it/s]


Epoch 14/200 | train loss: 7.147 | train acc: 0.029 | val loss: 7.317 | val acc: 0.031
New best.
Epoch 15/200


Train: 100%|██████████| 6573/6573 [02:41<00:00, 40.74it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.88it/s]


Epoch 15/200 | train loss: 7.096 | train acc: 0.032 | val loss: 7.316 | val acc: 0.033
New best.
Epoch 16/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 40.06it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.86it/s]


Epoch 16/200 | train loss: 7.042 | train acc: 0.036 | val loss: 7.307 | val acc: 0.033
New best.
Epoch 17/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.11it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.10it/s]


Epoch 17/200 | train loss: 6.988 | train acc: 0.040 | val loss: 7.292 | val acc: 0.036
New best.
Epoch 18/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.11it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 153.98it/s]


Epoch 18/200 | train loss: 6.933 | train acc: 0.044 | val loss: 7.289 | val acc: 0.038
New best.
Epoch 19/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 40.06it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 152.47it/s]


Epoch 19/200 | train loss: 6.882 | train acc: 0.049 | val loss: 7.302 | val acc: 0.036
Epoch 20/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 40.01it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 153.98it/s]


Epoch 20/200 | train loss: 6.833 | train acc: 0.053 | val loss: 7.300 | val acc: 0.038
New best.
Epoch 21/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 40.06it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.00it/s]


Epoch 21/200 | train loss: 6.778 | train acc: 0.059 | val loss: 7.308 | val acc: 0.040
New best.
Epoch 22/200


Train: 100%|██████████| 6573/6573 [02:44<00:00, 40.06it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 153.91it/s]


Epoch 22/200 | train loss: 6.730 | train acc: 0.064 | val loss: 7.304 | val acc: 0.039
Epoch 23/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.10it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.18it/s]


Epoch 23/200 | train loss: 6.691 | train acc: 0.066 | val loss: 7.319 | val acc: 0.039
Epoch 24/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.11it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 153.69it/s]


Epoch 24/200 | train loss: 6.651 | train acc: 0.072 | val loss: 7.313 | val acc: 0.041
New best.
Epoch 25/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.11it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.59it/s]


Epoch 25/200 | train loss: 6.615 | train acc: 0.075 | val loss: 7.318 | val acc: 0.040
Epoch 26/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.24it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.96it/s]


Epoch 26/200 | train loss: 6.588 | train acc: 0.077 | val loss: 7.318 | val acc: 0.041
New best.
Epoch 27/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.37it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.16it/s]


Epoch 27/200 | train loss: 6.568 | train acc: 0.079 | val loss: 7.316 | val acc: 0.041
New best.
Epoch 28/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.24it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.84it/s]


Epoch 28/200 | train loss: 6.543 | train acc: 0.081 | val loss: 7.315 | val acc: 0.042
New best.
Epoch 29/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.34it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.72it/s]


Epoch 29/200 | train loss: 6.535 | train acc: 0.085 | val loss: 7.323 | val acc: 0.041
Epoch 30/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.33it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.78it/s]


Epoch 30/200 | train loss: 6.531 | train acc: 0.081 | val loss: 7.329 | val acc: 0.041
Epoch 31/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.31it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.25it/s]


Epoch 31/200 | train loss: 6.703 | train acc: 0.064 | val loss: 7.379 | val acc: 0.037
Epoch 32/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.36it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 156.29it/s]


Epoch 32/200 | train loss: 6.651 | train acc: 0.067 | val loss: 7.376 | val acc: 0.039
Epoch 33/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.32it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.60it/s]


Epoch 33/200 | train loss: 6.597 | train acc: 0.072 | val loss: 7.407 | val acc: 0.039
Epoch 34/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.36it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 157.29it/s]


Epoch 34/200 | train loss: 6.530 | train acc: 0.078 | val loss: 7.443 | val acc: 0.039
Epoch 35/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.34it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.71it/s]


Epoch 35/200 | train loss: 6.474 | train acc: 0.085 | val loss: 7.454 | val acc: 0.038
Epoch 36/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.32it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.65it/s]


Epoch 36/200 | train loss: 6.411 | train acc: 0.090 | val loss: 7.485 | val acc: 0.041
Epoch 37/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.32it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.66it/s]


Epoch 37/200 | train loss: 6.334 | train acc: 0.097 | val loss: 7.537 | val acc: 0.038
Epoch 38/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.34it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.00it/s]


Epoch 38/200 | train loss: 6.274 | train acc: 0.106 | val loss: 7.577 | val acc: 0.040
Epoch 39/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.28it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.18it/s]


Epoch 39/200 | train loss: 6.195 | train acc: 0.115 | val loss: 7.624 | val acc: 0.041
Epoch 40/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.34it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.43it/s]


Epoch 40/200 | train loss: 6.130 | train acc: 0.119 | val loss: 7.671 | val acc: 0.040
Epoch 41/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.22it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 153.77it/s]


Epoch 41/200 | train loss: 6.050 | train acc: 0.131 | val loss: 7.693 | val acc: 0.042
Epoch 42/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.33it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.45it/s]


Epoch 42/200 | train loss: 5.978 | train acc: 0.140 | val loss: 7.758 | val acc: 0.040
Epoch 43/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.32it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.49it/s]


Epoch 43/200 | train loss: 5.900 | train acc: 0.147 | val loss: 7.794 | val acc: 0.040
Epoch 44/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.31it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.66it/s]


Epoch 44/200 | train loss: 5.839 | train acc: 0.156 | val loss: 7.858 | val acc: 0.040
Epoch 45/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.33it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.34it/s]


Epoch 45/200 | train loss: 5.771 | train acc: 0.165 | val loss: 7.906 | val acc: 0.039
Epoch 46/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.34it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 154.49it/s]


Epoch 46/200 | train loss: 5.698 | train acc: 0.175 | val loss: 7.962 | val acc: 0.040
Epoch 47/200


Train: 100%|██████████| 6573/6573 [02:43<00:00, 40.31it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.53it/s]


Epoch 47/200 | train loss: 5.637 | train acc: 0.178 | val loss: 8.019 | val acc: 0.037
Epoch 48/200


Train: 100%|██████████| 6573/6573 [02:42<00:00, 40.34it/s]
Validation: 100%|██████████| 1644/1644 [00:10<00:00, 155.84it/s]


Epoch 48/200 | train loss: 5.566 | train acc: 0.189 | val loss: 8.056 | val acc: 0.039
Early stopping


Test: 100%|██████████| 1644/1644 [00:08<00:00, 184.54it/s]


Submission (16432, 1854) saved at outputs
