# Import


In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

import os
import numpy as np
import random
from datetime import datetime

import sys

sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "code"))
import dataset as D

In [2]:
root_path = "/home/choi/BrainDecoder/"
dataset_path = os.path.join(root_path, "dataset")
images_dataset_path = os.path.join(dataset_path, "imageNet_images")
eeg_dataset_path = os.path.join(dataset_path, "eeg")

config = {}

# Dataset


##### dataset (deprecated)

In [4]:
class EEGDataset(Dataset):
    def __init__(self, eeg_dataset_file_name="eeg_5_95_std.pth") -> None:
        super().__init__()
        loaded = torch.load(os.path.join(eeg_dataset_path, eeg_dataset_file_name))
        self.data = loaded["dataset"]
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        self.size = len(self.data)

    def __getitem__(self, idx):
        # t() -> transpose
        eeg = self.data[idx]["eeg"].t().to(torch.float)
        eeg = eeg[20:460, :]

        label = self.data[idx]["label"]
        return eeg, label

    def __len__(self):
        return self.size

In [5]:
class Splitter(Dataset):
    def __init__(self, dataset, split_name="train") -> None:
        super().__init__()
        self.dataset = dataset

        loaded = torch.load(
            os.path.join(eeg_dataset_path, "block_splits_by_image_all.pth")
        )
        self.target_data_indices = loaded["splits"][0][split_name]
        # filter data that is too short
        self.target_data_indices = [
            i
            for i in self.target_data_indices
            if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600
        ]

        self.size = len(self.target_data_indices)
        self.all_labels = np.array(self.get_all_labels())
        self.all_eegs = self.get_all_eegs()

    def __getitem__(self, idx):
        eeg, label = self.dataset[self.target_data_indices[idx]]
        return eeg, label

    def __len__(self):
        return self.size

    def get_all_labels(self):
        data = [self.dataset[idx] for idx in self.target_data_indices]
        return [item[1] for item in data]

    def get_all_eegs(self):
        data = [self.dataset[idx] for idx in self.target_data_indices]
        return [item[0] for item in data]

    def generate_data_points(self, anchor_labels, positive=True):
        eeg_shape = self.__getitem__(0)[0].size()
        eegs = torch.empty(0, eeg_shape[0], eeg_shape[1])
        labels = torch.empty(0)
        for anchor_label in anchor_labels:
            indices = (
                np.argwhere(self.all_labels == anchor_label.item())[:, 0]
                if positive
                else np.argwhere(self.all_labels != anchor_label.item())[:, 0]
            )
            data_idx = np.random.choice(indices)
            eeg = self.all_eegs[data_idx]
            eegs = torch.cat((eegs, eeg.unsqueeze(dim=0)))
            labels = torch.cat(
                (labels, torch.tensor(self.all_labels[data_idx]).unsqueeze(dim=0))
            )

        return eegs, labels

    def get_data(self, anchor_label, positive: bool = True):
        cnt = 0
        while True:
            idx = random.choice(self.target_data_indices)
            if positive and self.dataset[idx][1] == anchor_label:
                return self.dataset[idx]
            if not positive and self.dataset[idx][1] != anchor_label:
                return self.dataset[idx]

            if cnt >= 2000:
                raise Exception(f"get_data failed after {cnt} tries")
            cnt += 1

In [6]:
dataset = EEGDataset(eeg_dataset_file_name="eeg_5_95_std.pth")
loaders = {
    split: DataLoader(
        Splitter(dataset, split_name=split), batch_size=16, shuffle=True, drop_last=True
    )
    for split in ["train", "val", "test"]
}

##### dataset from py file

In [13]:
dataset = D.EEGDataset(eeg_dataset_file_name="eeg_signals_raw_with_mean_std.pth")

loaders = {
    split: DataLoader(
        D.Splitter(dataset, split_name=split),
        batch_size=config["batch-size"],
        shuffle=True if split == "train" else False,
        num_workers=23,
        drop_last=True,
    )
    for split in ["train", "val", "test"]
}

# Model


In [6]:
gpu_id = 2
device = f"cuda:{gpu_id}" if torch.cuda.is_available else "cpu"
print(device)

cuda:2


In [20]:
class FeatureExtractor_ContrastiveLearning_NN(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()

        # Triplet loss
        def dist_fn(x1, x2):
            return torch.sum(torch.pow(torch.subtract(x1, x2), 2), dim=1)

        self.loss_fn = nn.TripletMarginWithDistanceLoss(
            distance_function=dist_fn, margin=config["margin"]
        )

        # model
        self.input_size = 128
        self.hidden_size = 128
        self.lstm_layers = config["lstm-layer"]
        self.out_size = 128

        self.lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            num_layers=self.lstm_layers,
            batch_first=True,
        )
        self.output = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=self.out_size),
            nn.ReLU(),
        )

    def forward(self, input):
        input = input.to(device)

        lstm_out, _ = self.lstm(input)
        res = self.output(lstm_out[:, -1, :])
        return res

    def training_step(self, batch, batch_idx):
        anchor_eeg, anchor_label, _ = batch

        positive_eeg, positive_label = loaders["train"].dataset.generate_data_points(
            anchor_label, positive=True
        )
        negative_eeg, negative_label = loaders["train"].dataset.generate_data_points(
            anchor_label, positive=False
        )

        anchor_feature = self(anchor_eeg)
        positive_feature = self(positive_eeg)
        negative_feature = self(negative_eeg)

        loss = self.loss_fn(anchor_feature, positive_feature, negative_feature)

        self.log_dict(
            # {"train_loss": loss, "lr": self.scheduler.get_last_lr()[0]},
            {
                "train_loss": loss,
            },
            on_epoch=True,
            prog_bar=True,
            batch_size=config["batch-size"],
        )
        return loss

    def validation_step(self, batch, batch_idx):
        anchor_eeg, anchor_label, _ = batch

        positive_eeg, positive_label = loaders["val"].dataset.generate_data_points(
            anchor_label, positive=True
        )
        negative_eeg, negative_label = loaders["val"].dataset.generate_data_points(
            anchor_label, positive=False
        )

        anchor_feature = self(anchor_eeg)
        positive_feature = self(positive_eeg)
        negative_feature = self(negative_eeg)

        loss = self.loss_fn(anchor_feature, positive_feature, negative_feature)

        self.log_dict(
            {"val_loss": loss},
            on_epoch=True,
            prog_bar=True,
            batch_size=config["batch-size"],
        )

    def create_optimizer(self):
        if config["optimizer"] == "Adam":
            return optim.Adam(
                self.parameters(),
                lr=config["lr"],
                weight_decay=config["weight-decay"],
                betas=config["betas"],
            )
        elif config["optimizer"] == "AdamW":
            return optim.AdamW(
                self.parameters(),
                lr=config["lr"],
                weight_decay=config["weight-decay"],
            )
        elif config["optimizer"] == "SGD":
            return optim.SGD(
                self.parameters(),
                lr=config["lr"],
                weight_decay=config["weight-decay"],
            )
        else:
            raise Exception("optimizer config error")

    def create_scheduler(self, optimizer):
        if config["scheduler"] == "LambdaLR":
            return optim.lr_scheduler.LambdaLR(
                optimizer, lambda epoch: config["lambda-factor"] ** epoch
            )
        else:
            raise Exception("scheduler config error")

    def configure_optimizers(self):
        optimizer = self.create_optimizer()
        scheduler = self.create_scheduler(optimizer)
        self.scheduler = scheduler
        return [optimizer], [scheduler]
        # return [optimizer]

# Training


In [18]:
config = {
    "optimizer": "Adam",  # ("Adam", "AdamW", "SGD")
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    # "scheduler": "LambdaLR",
    "scheduler": "None",
    # "lambda-factor": 0.975,
    "lambda-factor": "None",
    "weight-decay": 0,
    "margin": 2.0,
    "lstm-layer": 3,
    "batch-size": 16,
}

In [21]:
model = FeatureExtractor_ContrastiveLearning_NN()
model.to(device)

now = datetime.now()
now_hm = now.strftime("%H:%M")

logger = TensorBoardLogger(
    save_dir="/home/choi/BrainDecoder/lightning_logs/ContrastiveLossFeatureLearning",
    name=f"{now_hm}_{config['optimizer']}_{config['lr']}_{config['scheduler']}_margin_{config['margin']}_weight-decay_{config['weight-decay']}_lambda-factor_{config['lambda-factor']}",
    version=now.strftime("%Y-%m-%d %H:%M:%S"),
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")

trainer = Trainer = L.Trainer(
    max_epochs=500,
    logger=logger,
    callbacks=[lr_monitor],
    accelerator="gpu",
    devices=[gpu_id],
)
trainer.fit(model, train_dataloaders=loaders["train"], val_dataloaders=loaders["val"])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name    | Type                          | Params
----------------------------------------------------------
0 | loss_fn | TripletMarginWithDistanceLoss | 0     
1 | lstm    | LSTM                          | 396 K 
2 | output  | Sequential                    | 16.5 K
----------------------------------------------------------
412 K     Trainable params
0         Non-trainable params
412 K     Total params
1.651     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 53: 100%|██████████| 497/497 [00:45<00:00, 10.83it/s, v_num=2:57, train_loss_step=1.200, val_loss=0.654, train_loss_epoch=0.638]  

/home/choi/Downloads/miniconda3/envs/braindecoder/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


# Classifier Model


In [22]:
path = "/home/choi/BrainDecoder/lightning_logs/ContrastiveLossFeatureLearning/02:02_Adam_0.001_None_margin_2.0_weight-decay_0_lambda-factor_None/2023-12-18 02:02:57/checkpoints/epoch=52-step=26341.ckpt"

CKPT_PATH = os.path.join(root_path, path)

In [32]:
from torchmetrics.functional import accuracy

classifier_config = {}


class EEG_Classifier(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()

        # self.feature_extractor = FeatureExtractor_ContrastiveLearning_NN()
        self.feature_extractor = (
            FeatureExtractor_ContrastiveLearning_NN.load_from_checkpoint(CKPT_PATH)
        )
        self.feature_extractor.requires_grad_(False)
        self.classifier = nn.Linear(128, 40)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input):
        # for p in self.feature_extractor.parameters():
        #     p.requires_grad = False
        # with torch.no_grad():
        #     input = self.feature_extractor(input)
        input = self.feature_extractor(input)
        res = self.classifier(input)
        return res

    def training_step(self, batch, batch_idx):
        eegs, labels, _ = batch
        eegs = eegs.to(device)
        labels = labels.to(device)

        out = self(eegs)

        loss = self.loss_fn(out, labels)
        self.log_dict(
            {"train_loss": loss, "lr": self.scheduler.get_last_lr()[0]},
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, prefix="val")

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, prefix="test")

    def _shared_eval(self, batch, prefix="val"):
        eegs, labels, _ = batch
        out = self(eegs)
        loss = self.loss_fn(out, labels)
        preds = torch.argmax(out, dim=1)
        acc = accuracy(preds, labels, "multiclass", num_classes=preds.shape[0])
        self.log_dict({f"{prefix}_loss": loss, f"{prefix}_acc": acc}, prog_bar=True)

    def predict_step(self, batch):
        return self(batch)

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(),
            lr=classifier_config["lr"],
            weight_decay=classifier_config["weight-decay"],
            betas=classifier_config["betas"],
        )
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, (lambda epoch: classifier_config["lambda-factor"] ** epoch)
        )
        self.scheduler = scheduler
        return [optimizer], [scheduler]

# Train Classifier


In [33]:
classifier_config = {
    "optimizer": "Adam",  # ("Adam", "AdamW", "SGD")
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "scheduler": "LambdaLR",
    "lambda-factor": 0.95,
    "weight-decay": 0,
}

In [34]:
model = EEG_Classifier()
model.to(device)

now = datetime.now()
now_time = now.strftime("%H:%M")

logger = TensorBoardLogger(
    save_dir="/home/choi/BrainDecoder/lightning_logs/ContrastiveLossClassification",
    name=f"{now_time}_{config['optimizer']}_{config['lr']}__weight-decay_{config['weight-decay']}_{config['scheduler']}_lambda-factor_{config['lambda-factor']}",
    version=now.strftime("%Y-%m-%d %H:%M:%S"),
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")

trainer = L.Trainer(
    max_epochs=200,
    logger=logger,
    callbacks=[lr_monitor],
    accelerator="gpu",
    devices=[gpu_id],
)
trainer.fit(model, train_dataloaders=loaders["train"], val_dataloaders=loaders["val"])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name              | Type                                    | Params
------------------------------------------------------------------------------
0 | feature_extractor | FeatureExtractor_ContrastiveLearning_NN | 412 K 
1 | classifier        | Linear                                  | 5.2 K 
2 | loss_fn           | CrossEntropyLoss                        | 0     
------------------------------------------------------------------------------
5.2 K     Trainable params
412 K     Non-trainable params
417 K     Total params
1.672     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 51:   0%|          | 0/497 [00:00<?, ?it/s, v_num=4:43, train_loss_step=2.550, lr_step=7.69e-5, val_loss=2.460, val_acc=0.130, train_loss_epoch=2.420, lr_epoch=7.69e-5]            

Exception ignored in: <function _releaseLock at 0x7f1f2e95e340>
Traceback (most recent call last):
  File "/home/choi/Downloads/miniconda3/envs/braindecoder/lib/python3.11/logging/__init__.py", line 237, in _releaseLock
    def _releaseLock():
    
KeyboardInterrupt: 


Epoch 52:  57%|█████▋    | 282/497 [00:09<00:07, 29.37it/s, v_num=4:43, train_loss_step=2.520, lr_step=6.94e-5, val_loss=2.460, val_acc=0.129, train_loss_epoch=2.420, lr_epoch=7.31e-5]

/home/choi/Downloads/miniconda3/envs/braindecoder/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


# Test Classifier


In [118]:
trainer.test(model, dataloaders=loaders["test"])

/Users/ms/anaconda3/envs/neuroimagen/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/ms/anaconda3/envs/neuroimagen/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 124/124 [00:04<00:00, 25.95it/s]


[{'test_loss': 3.7022695541381836, 'test_acc': 0.019657257944345474}]