# Import


In [2]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

import os
import numpy as np
import random

# Dataset


In [3]:
# Project root. Defaults to the original author's checkout; set the
# NEUROIMAGEN_ROOT environment variable to run the notebook on another machine
# without editing this cell.
root_path = os.environ.get("NEUROIMAGEN_ROOT", "/Users/ms/cs/ML/NeuroImagen/")
# Every data directory is derived from the root so only one value needs changing.
dataset_path = os.path.join(root_path, "dataset")
images_dataset_path = os.path.join(dataset_path, "imageNet_images")
# Directory holding the EEG recordings file and the split-definition file.
eeg_dataset_path = os.path.join(dataset_path, "eeg")

In [4]:
class EEGDataset(Dataset):
    """EEG recordings read from a single ``.pth`` file in ``eeg_dataset_path``.

    The file must deserialize to a mapping with the keys ``"dataset"``,
    ``"labels"`` and ``"images"``. Each element of ``"dataset"`` is indexed
    with ``"eeg"`` (a 2-D tensor) and ``"label"``.
    """

    def __init__(self, eeg_dataset_file_name="eeg_5_95_std.pth") -> None:
        super().__init__()
        file_path = os.path.join(eeg_dataset_path, eeg_dataset_file_name)
        payload = torch.load(file_path)
        self.data, self.labels, self.images = (
            payload[key] for key in ("dataset", "labels", "images")
        )
        self.size = len(self.data)

    def __len__(self):
        """Return the number of recordings in the file."""
        return self.size

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for recording ``idx``.

        The stored tensor is transposed and rows 20..459 of the transposed
        tensor are kept.
        """
        record = self.data[idx]
        # presumably (channels, time) -> (time, channels), then a fixed time
        # window — TODO confirm against the dataset description
        window = record["eeg"].t()[20:460, :]
        return window, record["label"]

In [26]:
class Splitter(Dataset):
    """View of one named split of an EEG dataset, with triplet-sampling helpers.

    Split membership is read from ``block_splits_by_image_all.pth`` inside
    ``eeg_dataset_path``; only the first entry of its ``"splits"`` list is
    used. Recordings whose raw tensor has a second dimension outside
    ``[450, 600]`` are excluded.

    Attributes:
        dataset: The wrapped dataset. It must expose ``.data`` and return
            ``(eeg, label)`` from ``__getitem__``.
        target_data_indices: Indices into ``dataset`` that belong to this
            split, after filtering.
        size: Number of samples in the split.
        all_labels: ``numpy`` array of labels; position ``k`` corresponds to
            ``target_data_indices[k]``.
        all_eegs: List of EEG tensors aligned with ``all_labels``.
    """

    def __init__(self, dataset, split_name="train") -> None:
        super().__init__()
        self.dataset = dataset

        loaded = torch.load(
            os.path.join(eeg_dataset_path, "block_splits_by_image_all.pth")
        )
        # Keep only recordings whose raw length lies in [450, 600].
        self.target_data_indices = [
            i
            for i in loaded["splits"][0][split_name]
            if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600
        ]
        self.size = len(self.target_data_indices)

        # Fetch every sample of the split exactly once and cache both parts,
        # so the label and EEG caches are guaranteed to be position-aligned.
        samples = [self.dataset[i] for i in self.target_data_indices]
        self.all_labels = np.array([label for _, label in samples])
        self.all_eegs = [eeg for eeg, _ in samples]

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for the ``idx``-th sample of this split."""
        eeg, label = self.dataset[self.target_data_indices[idx]]
        return eeg, label

    def __len__(self):
        """Return the number of samples in this split."""
        return self.size

    def get_all_labels(self):
        """Return the labels of every sample in this split, in split order."""
        return [self.dataset[idx][1] for idx in self.target_data_indices]

    def get_all_eegs(self):
        """Return the EEG tensors of every sample in this split, in split order."""
        return [self.dataset[idx][0] for idx in self.target_data_indices]

    def generate_data_points(self, anchor_labels, positive=True):
        """Sample one EEG from this split for each anchor label.

        Args:
            anchor_labels: Iterable of labels (tensor elements or plain
                numbers), one per anchor.
            positive: If true, each sampled EEG has the same label as its
                anchor; otherwise it has a different label.

        Returns:
            ``(eegs, labels)`` where ``eegs`` stacks the sampled EEG tensors
            along a new first dimension and ``labels`` is a float32 tensor
            with the label of each sampled EEG.

        Raises:
            ValueError: If the split holds no sample satisfying the request
                for one of the anchors.
        """
        picked_eegs = []
        picked_labels = []
        for anchor_label in anchor_labels:
            target = anchor_label.item() if hasattr(anchor_label, "item") else anchor_label
            same_class = self.all_labels == target
            candidates = np.flatnonzero(same_class if positive else ~same_class)
            if candidates.size == 0:
                kind = "positive" if positive else "negative"
                raise ValueError(f"no {kind} sample available for label {target}")

            # `candidates` holds positions *within this split*, so the sample
            # must come from the split-aligned caches. Indexing the wrapped
            # dataset with such a position would return an unrelated recording.
            position = int(np.random.choice(candidates))
            picked_eegs.append(self.all_eegs[position])
            picked_labels.append(self.all_labels[position])

        if not picked_eegs:
            # Preserve the empty-batch shape (0, rows, cols) for empty input.
            eeg_shape = self[0][0].size()
            return torch.empty(0, eeg_shape[0], eeg_shape[1]), torch.empty(0)

        eegs = torch.stack(picked_eegs)
        labels = torch.tensor(np.asarray(picked_labels), dtype=torch.float32)
        return eegs, labels

    def get_data(self, anchor_label, positive: bool = True):
        """Draw a random ``(eeg, label)`` sample relative to ``anchor_label``.

        Samples are drawn uniformly from the split until one has the same
        label as the anchor (``positive=True``) or a different one
        (``positive=False``).

        Raises:
            Exception: If no suitable sample is found within the retry budget.
        """
        max_tries = 2000
        for _ in range(max_tries + 1):
            # Fetch the candidate once per attempt and reuse it.
            sample = self.dataset[random.choice(self.target_data_indices)]
            if bool(sample[1] == anchor_label) == positive:
                return sample
        raise Exception(f"get_data failed after {max_tries} tries")

In [27]:
# Load the recordings once; the three split views below share this object.
dataset = EEGDataset(eeg_dataset_file_name="eeg_5_95_std.pth")

# One DataLoader per split: batches of 16, reshuffled every epoch, with the
# incomplete final batch dropped.
loaders = dict(
    (
        split,
        DataLoader(
            Splitter(dataset, split_name=split),
            batch_size=16,
            shuffle=True,
            drop_last=True,
        ),
    )
    for split in ("train", "val", "test")
)

Epoch 0:   0%|          | 0/497 [03:40<?, ?it/s]


# Model


In [7]:
# Use the MPS backend when PyTorch reports it as available; otherwise use CPU.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"

In [31]:
class FeatureExtractor_ContrastiveLearning_NN(L.LightningModule):
    """LSTM feature extractor trained with a triplet margin loss.

    Each batch supplies the anchors. For every anchor, one positive (same
    label) and one negative (different label) EEG are sampled from the
    matching split in the notebook-level ``loaders`` dict. All three are
    encoded by the same network and compared with squared Euclidean distance.
    """

    def __init__(self):
        super().__init__()

        # Triplet loss over per-sample squared Euclidean distances.
        self.loss_fn = nn.TripletMarginWithDistanceLoss(
            distance_function=self._squared_euclidean, margin=1.5
        )

        # model
        self.input_size = 128
        self.hidden_size = 128
        self.lstm_layers = 1
        self.out_size = 128

        self.lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            num_layers=self.lstm_layers,
            batch_first=True,
        )
        self.output = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=self.out_size),
            nn.ReLU(),
        )

    @staticmethod
    def _squared_euclidean(x1, x2):
        """Return the squared Euclidean distance between paired embeddings.

        The sum runs over the last (feature) dimension, so a pair of
        ``(batch, features)`` inputs yields one distance per sample, which is
        what ``TripletMarginWithDistanceLoss`` expects. Summing over
        dimension 0 would mix different samples of the batch together.
        """
        return torch.sum(torch.pow(torch.subtract(x1, x2), 2), dim=-1)

    def configure_optimizers(self):
        """Return Adam (lr=1e-4) with a scheduler scaling the lr by 0.95**epoch."""
        optimizer = optim.Adam(self.parameters(), lr=1e-4)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.95**epoch)
        return [optimizer], [scheduler]

    def forward(self, input):
        """Encode a batch of sequences into ``out_size``-dimensional features.

        Args:
            input: Tensor indexed as ``(batch, step, input_size)`` (the LSTM is
                built with ``batch_first=True``).

        Returns:
            The output head applied to the LSTM output at the last step.
        """
        # Sampled positives/negatives are created outside the DataLoader, so
        # move every input to wherever the module's parameters live.
        input = input.to(self.device)

        lstm_out, _ = self.lstm(input)
        return self.output(lstm_out[:, -1, :])

    def _triplet_loss(self, batch, split):
        """Compute the triplet loss for ``batch`` using samples from ``split``."""
        anchor_eeg, anchor_label = batch
        split_data = loaders[split].dataset

        positive_eeg, _ = split_data.generate_data_points(anchor_label, positive=True)
        negative_eeg, _ = split_data.generate_data_points(anchor_label, positive=False)

        anchor_feature = self(anchor_eeg)
        positive_feature = self(positive_eeg)
        negative_feature = self(negative_eeg)

        return self.loss_fn(anchor_feature, positive_feature, negative_feature)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the triplet loss for one training batch."""
        loss = self._triplet_loss(batch, "train")

        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the triplet loss for one validation batch."""
        loss = self._triplet_loss(batch, "val")

        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

# Training


In [32]:
model = FeatureExtractor_ContrastiveLearning_NN()

# TensorBoard runs go under <root_path>/lightning_logs/ContrastiveLoss/test.
# With version=None the logger assigns the next available version itself.
logger = TensorBoardLogger(
    save_dir=os.path.join(root_path, "lightning_logs", "ContrastiveLoss"),
    name="test",
    version=None,
)

# Bind only `trainer`; the previous chained assignment also created a stray
# `Trainer` variable pointing at this instance.
trainer = L.Trainer(max_epochs=200, logger=logger)
trainer.fit(model, train_dataloaders=loaders["train"], val_dataloaders=loaders["val"])

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name    | Type                          | Params
----------------------------------------------------------
0 | loss_fn | TripletMarginWithDistanceLoss | 0     
1 | lstm    | LSTM                          | 132 K 
2 | output  | Sequential                    | 16.5 K
----------------------------------------------------------
148 K     Trainable params
0         Non-trainable params
148 K     Total params
0.594     Total estimated model params size (MB)


Epoch 0:   8%|▊         | 41/497 [00:03<00:36, 12.42it/s, v_num=69, train_loss_step=1.490]

# Classifier Model


# Train Classifier


# Test Classifier
