# Import & Config


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# from torch.optim import Adam
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split

import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger


import os

# import cv2
# import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from datetime import datetime
from pytz import timezone
import sys

sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "code"))
import dataset as D

In [None]:
root_path = "/home/choi/BrainDecoder/"
dataset_path = os.path.join(root_path, "dataset")
images_dataset_path = os.path.join(dataset_path, "imageNet_images")
eeg_dataset_path = os.path.join(dataset_path, "eeg")

config = {}

# Dataset


##### Dataset_1 (using custom dataset method)


In [None]:
# Merge all image folders to create list of images

import random

dir_list = list(os.walk(images_dataset_path))
# skip any unnecessary folders
start_idx = len(dir_list) - 40
images_list = []
for sub_dir in dir_list[start_idx:]:
    # images_list+=sub_dir[2]
    images_list.extend(sub_dir[2])

images_list = [image_name.replace(".JPEG", "") for image_name in images_list]

In [None]:
# Split images list to 8:1:1

# random.shuffle(images_list)

images_total_size = len(images_list)
train_size = int(images_total_size * 0.8)
val_size = int(images_total_size * 0.1)
test_size = images_total_size - train_size - val_size

train_images = images_list[:train_size]
val_images = images_list[train_size : train_size + val_size]
test_images = images_list[-1 * test_size :]

In [None]:
eeg_dataset_name = "eeg_5_95_std.pth"
eeg_dataset = torch.load(os.path.join(eeg_dataset_path, eeg_dataset_name))

In [None]:
class CustomDataset(Dataset):
    def __init__(self, images_list) -> None:
        super().__init__()
        self.x_data = []
        self.y_data = []
        # self.y_data = [image_name.split("_")[0] for image_name in images_list]

        for eeg_segment in eeg_dataset["dataset"]:
            for image_name in images_list:
                if eeg_dataset["images"][eeg_segment["image"]] == image_name:
                    self.x_data.append(eeg_segment["eeg"][:, 20:460])
                    # all_channel_list = np.array(eeg_segment['eeg'])
                    # self.x_data.append(torch.from_numpy(all_channel_list[:,40:480]))
                    # self.x_data.append(torch.FloatTensor([eeg_sequence[40:480] for eeg_sequence in eeg_segment['eeg']]))
                    class_id = image_name.split("_")[0]
                    self.y_data.append(lookup_dict[class_id])
                    # self.y_data.append(eeg_dataset['labels'][eeg_segment['label']])
                    break

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return len(self.x_data)

In [None]:
train_dataset = CustomDataset(train_images)
val_dataset = CustomDataset(val_images)
test_dataset = CustomDataset(test_images)

In [None]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(val_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)

##### Dataset_2 (using splitter method)


In [None]:
class EEGDataset(Dataset):
    def __init__(self, eeg_dataset_file_name="eeg_5_95_std.pth") -> None:
        super().__init__()
        loaded = torch.load(os.path.join(eeg_dataset_path, eeg_dataset_file_name))
        self.data = loaded["dataset"]
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        self.size = len(self.data)

    def __getitem__(self, idx):
        # t() -> transpose
        eeg = self.data[idx]["eeg"].t().to(torch.float)
        eeg = eeg[20:460, :]

        label = self.data[idx]["label"]
        return eeg, label

    def __len__(self):
        return self.size

In [None]:
class Splitter(Dataset):
    def __init__(self, dataset, split_name="train") -> None:
        super().__init__()
        self.dataset = dataset

        loaded = torch.load(
            os.path.join(eeg_dataset_path, "block_splits_by_image_all.pth")
        )
        self.target_data_indices = loaded["splits"][0][split_name]
        # filter data that is too short
        self.target_data_indices = [
            i
            for i in self.target_data_indices
            if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600
        ]

        self.size = len(self.target_data_indices)
        self.all_labels = np.array(self.get_all_labels())
        self.all_eegs = self.get_all_eegs()

    def __getitem__(self, idx):
        eeg, label = self.dataset[self.target_data_indices[idx]]
        return eeg, label

    def __len__(self):
        return self.size

    def get_all_labels(self):
        data = [self.dataset[idx] for idx in self.target_data_indices]
        return [item[1] for item in data]

    def get_all_eegs(self):
        data = [self.dataset[idx] for idx in self.target_data_indices]
        return [item[0] for item in data]

    def generate_data_points(self, anchor_labels, positive=True):
        eeg_shape = self.__getitem__(0)[0].size()
        eegs = torch.empty(0, eeg_shape[0], eeg_shape[1])
        labels = torch.empty(0)
        for anchor_label in anchor_labels:
            indices = (
                np.argwhere(self.all_labels == anchor_label.item())[:, 0]
                if positive
                else np.argwhere(self.all_labels != anchor_label.item())[:, 0]
            )
            data_idx = np.random.choice(indices)
            eeg = self.all_eegs[data_idx]
            eegs = torch.cat((eegs, eeg.unsqueeze(dim=0)))
            labels = torch.cat(
                (labels, torch.tensor(self.all_labels[data_idx]).unsqueeze(dim=0))
            )

        return eegs, labels

    def get_data(self, anchor_label, positive: bool = True):
        cnt = 0
        while True:
            idx = random.choice(self.target_data_indices)
            if positive and self.dataset[idx][1] == anchor_label:
                return self.dataset[idx]
            if not positive and self.dataset[idx][1] != anchor_label:
                return self.dataset[idx]

            if cnt >= 2000:
                raise Exception(f"get_data failed after {cnt} tries")
            cnt += 1

In [None]:
dataset = EEGDataset(eeg_dataset_file_name="eeg_signals_raw_with_mean_std.pth")
loaders = {
    split: DataLoader(
        Splitter(dataset, split_name=split),
        batch_size=config["batch-size"],
        shuffle=True,
        drop_last=True,
        num_workers=1,
    )
    for split in ["train", "val", "test"]
}

##### Dataset 3 import from py file


In [36]:
dataset = D.EEGDataset(eeg_dataset_file_name="eeg_signals_raw_with_mean_std.pth")

loaders = {
    split: DataLoader(
        D.Splitter(dataset, split_name=split),
        batch_size=config["batch-size"],
        shuffle=True if split == "train" else False,
        num_workers=23,
        drop_last=True,
    )
    for split in ["train", "val", "test"]
}

# Model 만들기


In [6]:
gpu_id = 2
device = f"cuda:{gpu_id}" if torch.cuda.is_available else "cpu"
print(device)

seed = 1563423

cuda:2


In [29]:
# with classifier attached
class FeatureExtractorNN(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        # seed_everything(seed,workers=True)

        self.input_size = 128
        self.hidden_size = 128
        self.lstm_layers = config["lstm-layer"]
        self.out_size = 128

        # self.lstm = nn.LSTM(input_size=128,hidden_size=128,num_layers=128)
        self.lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            num_layers=self.lstm_layers,
            batch_first=True,
        )
        self.output = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=self.out_size),
            nn.ReLU(),
        )
        self.classifer = nn.Sequential(
            nn.Linear(in_features=self.out_size, out_features=40),
            # don't use softmax with cross entropy loss
            # nn.Softmax(dim=1)
        )

        self.loss_fn = nn.CrossEntropyLoss()
        # self.loss_fn = nn.NLLLoss()
        self.training_step_outputs = {"correct_num": 0, "loss_sum": 0}
        self.validation_step_outputs = {"correct_num": 0, "loss_sum": 0}

    def forward(self, input):
        batch_size = input.size(0)
        lstm_init = (
            torch.zeros(self.lstm_layers, batch_size, self.hidden_size),
            torch.zeros(self.lstm_layers, batch_size, self.hidden_size),
        )
        lstm_init = (lstm_init[0].to(device), lstm_init[0].to(device))

        # dont need to transpose because already transposed when creating dataset
        # input = input.transpose(1,2)

        lstm_out, _ = self.lstm(input, lstm_init)
        # tmp_out = lstm_out[:,-1,:] if input.dim()==3 else lstm_out[-1,:]
        tmp_out = lstm_out[:, -1, :]
        out = self.output(tmp_out)
        # print("out shape",out.shape)
        res = self.classifer(out)

        return res

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        x = x.to(device)
        y = y.to(device)
        out = self(x)
        loss = self.loss_fn(out, y)

        self.log_dict({"train_loss": loss}, prog_bar=True, on_epoch=True)
        preds = out.argmax(dim=1)
        self.training_step_outputs["correct_num"] += (preds == y).sum()
        self.training_step_outputs["loss_sum"] += loss
        return loss

    def on_train_epoch_end(self) -> None:
        num_correct = self.training_step_outputs["correct_num"]
        acc = num_correct / loaders["train"].dataset.__len__()
        loss = self.training_step_outputs["loss_sum"] / loaders["train"].__len__()
        print("\n")
        # print("EPOCH:",self.current_epoch)
        print(
            f"Training accuracy: {acc.item()} ({num_correct.item()}/{loaders['train'].dataset.__len__()} correct)"
        )
        print("Training loss (average):", loss.item())
        # print("\n")
        self.training_step_outputs["correct_num"] = 0
        self.training_step_outputs["loss_sum"] = 0

        print("Learning rate:", self.scheduler.get_last_lr(), "\n")

        self.log_dict({"train_acc_epoch": acc.item()})

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        x = x.to(device)
        y = y.to(device)
        out = self(x)
        loss = self.loss_fn(out, y)

        self.log_dict({"val_loss": loss}, prog_bar=True, on_epoch=True)
        preds = out.argmax(dim=1)
        self.validation_step_outputs["correct_num"] += (preds == y).sum()
        self.validation_step_outputs["loss_sum"] += loss
        # return loss

    def on_validation_epoch_end(self) -> None:
        num_correct = self.validation_step_outputs["correct_num"]
        acc = num_correct / loaders["val"].dataset.__len__()
        loss = self.validation_step_outputs["loss_sum"] / loaders["val"].__len__()
        print("\n")
        # print("EPOCH:",self.current_epoch)
        print(
            f"Validation accuracy: {acc.item()} ({num_correct.item()}/{loaders['val'].dataset.__len__()} correct)"
        )
        print("Validation loss (average):", loss.item())
        print("\n")
        self.validation_step_outputs["correct_num"] = 0
        self.validation_step_outputs["loss_sum"] = 0
        self.log_dict({"val_acc_epoch": acc.item()})

    def test_step(self, batch, batch_idx):
        x, y, _ = batch

        out = self(x)
        loss = self.loss_fn(out, y)

        y_hat = torch.argmax(out, dim=1)
        # print("OUT,YHAT:",out,y_hat)
        test_acc = torch.sum(y == y_hat).item() / (len(y) * 1.0)

        self.log_dict(
            {"test_loss": loss, "test_acc": test_acc}, prog_bar=True, on_epoch=True
        )
        # print("   ||   test loss:",loss.item(), "   ||   test accuracy:",test_acc )

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(),
            lr=config["lr"],
            weight_decay=config["weight-decay"],
            betas=config["betas"],
        )
        # scheduler = optim.lr_scheduler.LambdaLR(
        #     optimizer, lambda epoch: config["lambda-factor"] ** epoch
        # )
        scheduler = optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=1e-5,
            max_lr=1e-2,
            step_size_up=1000,
            step_size_down=None,
            mode="exp_range",
            gamma=0.995,
            cycle_momentum=False,
        )
        self.scheduler = scheduler
        return [optimizer], [scheduler]
        # return [optimizer]

# Training


In [30]:
config = {
    "batch-size": 16,
    "optimizer": "Adam",  # ("Adam", "AdamW", "SGD")
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    # "scheduler": "LambdaLR",
    "scheduler": "CyclicLR",
    "lambda-factor": 0.95,
    "weight-decay": 0.001,
    "lstm-layer": 3,
}

In [33]:
model = FeatureExtractorNN()
# model = FeatureExtractorNN.load_from_checkpoint(PATH)
model.to(device)

now = datetime.now(tz=timezone("Asia/Tokyo"))
now_time = now.strftime("%H:%M")

lr_monitor = LearningRateMonitor(logging_interval="epoch")
logger = TensorBoardLogger(
    "/home/choi/BrainDecoder/lightning_logs/FeatureExtractionClassification",
    name=f"{now_time}_{config['optimizer']}_{config['lr']}_{config['scheduler']}_weight-decay_{config['weight-decay']}_lambda-factor_{config['lambda-factor']}",
    version=now.strftime("%Y-%m-%d %H:%M:%S"),
)

trainer = L.Trainer(
    max_epochs=500,
    callbacks=[lr_monitor],
    logger=logger,
    accelerator="gpu",
    devices=[gpu_id],
)
trainer.fit(model, train_dataloaders=loaders["train"], val_dataloaders=loaders["val"])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name      | Type             | Params
-----------------------------------------------
0 | lstm      | LSTM             | 396 K 
1 | output    | Sequential       | 16.5 K
2 | classifer | Sequential       | 5.2 K 
3 | loss_fn   | CrossEntropyLoss | 0     
-----------------------------------------------
417 K     Trainable params
0         Non-trainable params
417 K     Total params
1.672     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/choi/Downloads/miniconda3/envs/braindecoder/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  7.71it/s]

Validation accuracy: 0.0 (0/1994 correct)
Validation loss (average): 0.059094130992889404


Epoch 0: 100%|██████████| 498/498 [00:09<00:00, 53.76it/s, v_num=9:37, train_loss_step=3.670]

Validation accuracy: 0.020060181617736816 (40/1994 correct)
Validation loss (average): 3.69327449798584


Epoch 0: 100%|██████████| 498/498 [00:14<00:00, 33.54it/s, v_num=9:37, train_loss_step=3.670, val_loss=3.690, train_loss_epoch=3.690]

Training accuracy: 0.02965196780860424 (236/7959 correct)
Training loss (average): 3.688868522644043
Learning rate: [1.9940049999998908e-05] 

Epoch 1: 100%|██████████| 498/498 [00:13<00:00, 37.45it/s, v_num=9:37, train_loss_step=3.690, val_loss=3.690, train_loss_epoch=3.690]

Validation accuracy: 0.02306920848786831 (46/1994 correct)
Validation loss (average): 3.6927762031555176


Epoch 1: 100%|██████████| 498/498 [00:18<00:00, 26.44it/s, v_num=9:37, train_loss_step=3.690, val_loss=3.690, 

/home/choi/Downloads/miniconda3/envs/braindecoder/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


# Testing


In [34]:
trainer.test(model, dataloaders=loaders["test"])

/Users/ms/anaconda3/envs/braindecoder/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/ms/anaconda3/envs/braindecoder/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 125/125 [00:03<00:00, 32.05it/s]


/Users/ms/anaconda3/envs/braindecoder/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test_loss': 3.7041985988616943, 'test_acc': 0.01308505330234766}]

In [36]:
# idx = 6
# query = val_dataset[idx][0].unsqueeze(dim=0)
# # print(query.shape)
# query.to(device)
# pred = model(query)
# # print(pred)
# pred = torch.argmax(pred, dim=1)
# # pred = pred.max(dim=1)
# # print(pred)
# print("predicted: ", id_to_name[lookup_dict[pred.item()]])
# print("answer: ", id_to_name[lookup_dict[val_dataset[idx][1]]])

In [37]:
# # Calculate test accuracy
# num_correct = 0
# model.to(device)
# for x, y in test_loader:
#     x = x.to(device)
#     y = y.to(device)
#     out = model(x)
#     y_hat = out.argmax(dim=1)
#     # print(y==y_hat)
#     num_correct += (y == y_hat).sum()
# acc = num_correct / len(test_loader.dataset)
# print("Accuracy:", acc.item())