# Imports


In [7]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

import os
import numpy as np
import random

# Dataset


In [2]:
# Directory layout built below: <root_path>/dataset/{imageNet_images, eeg}
root_path = ""
dataset_path = os.path.join(root_path, "dataset")
images_dataset_path, eeg_dataset_path = (
    os.path.join(dataset_path, sub_dir) for sub_dir in ("imageNet_images", "eeg")
)

In [3]:
class EEGDataset(Dataset):
    """EEG samples loaded from a single ``.pth`` file under ``eeg_dataset_path``.

    The loaded object is indexed as a dict with ``"dataset"``, ``"labels"`` and
    ``"images"`` entries; each element of ``"dataset"`` is indexed with
    ``"eeg"`` and ``"label"`` keys.

    Args:
        eeg_dataset_file_name: File name inside ``eeg_dataset_path``.
        time_low: First row (inclusive) kept after transposing each ``eeg``
            tensor. The default reproduces the original ``[20:460]`` crop.
        time_high: Last row (exclusive) kept after transposing.
    """

    def __init__(
        self, eeg_dataset_file_name="eeg_5_95_std.pth", time_low=20, time_high=460
    ) -> None:
        super().__init__()
        loaded = torch.load(os.path.join(eeg_dataset_path, eeg_dataset_file_name))
        self.data = loaded["dataset"]
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        self.size = len(self.data)
        # Crop window applied in __getitem__ (previously hard-coded as 20:460).
        self.time_low = time_low
        self.time_high = time_high

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for sample ``idx``.

        ``eeg`` is the stored tensor transposed with ``.t()`` and cropped to rows
        ``[time_low, time_high)``. NOTE(review): this presumably means the stored
        tensor is (channels, time), so the crop runs along time; confirm.
        A sample with fewer than ``time_high`` rows after transposing yields a
        shorter tensor, and default collation would then fail on mismatched shapes.
        """
        eeg = self.data[idx]["eeg"].t()
        eeg = eeg[self.time_low:self.time_high, :]
        label = self.data[idx]["label"]
        return eeg, label

    def __len__(self):
        return self.size

In [4]:
class Splitter(Dataset):
    """View of an EEG dataset restricted to one split from the block-split file.

    Split indices are read from ``block_splits_by_image_all.pth`` (first split
    set, ``loaded["splits"][0]``). They are then filtered to samples whose
    ``eeg`` tensor has ``size(1)`` within ``[min_length, max_length]``.

    Args:
        dataset: Underlying dataset. It is read through ``.data[i]["eeg"]`` and
            returns ``(eeg, label)`` from ``__getitem__``.
        split_name: Key into the split dict, e.g. ``"train"``, ``"val"`` or ``"test"``.
        min_length: Inclusive lower bound on ``eeg.size(1)``.
        max_length: Inclusive upper bound on ``eeg.size(1)``.
    """

    def __init__(
        self, dataset, split_name="train", min_length=450, max_length=600
    ) -> None:
        super().__init__()
        self.dataset = dataset

        loaded = torch.load(
            os.path.join(eeg_dataset_path, "block_splits_by_image_all.pth")
        )
        split_indices = loaded["splits"][0][split_name]
        # Drop samples whose length falls outside the accepted range.
        self.target_data_indices = [
            i
            for i in split_indices
            if min_length <= self.dataset.data[i]["eeg"].size(1) <= max_length
        ]

        self.size = len(self.target_data_indices)

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for the ``idx``-th sample of this split."""
        eeg, label = self.dataset[self.target_data_indices[idx]]
        return eeg, label

    def __len__(self):
        return self.size

    def get_item_eeg(self, anchor_label, positive: bool, max_tries: int = 100):
        """Draw a random EEG whose label matches or differs from ``anchor_label``.

        Rejection-samples indices from ``target_data_indices``. When ``positive``
        is true, the first draw whose label equals ``anchor_label`` is returned;
        otherwise the first draw whose label differs. The anchor's own sample is
        not excluded.

        Args:
            anchor_label: Label to compare against.
            positive: True for a same-label sample, False for a different-label one.
            max_tries: Maximum number of random draws before giving up.

        Returns:
            The ``eeg`` element of the chosen sample.

        Raises:
            RuntimeError: If the split is empty, or no suitable sample was drawn
                within ``max_tries`` draws.
        """
        if not self.target_data_indices:
            raise RuntimeError("get_item_eeg called on an empty split")

        want_same_label = bool(positive)
        for _ in range(max_tries):
            # Index the dataset once per draw; it transposes/crops on each access.
            eeg, label = self.dataset[random.choice(self.target_data_indices)]
            if bool(label == anchor_label) == want_same_label:
                return eeg

        raise RuntimeError(f"get_item_eeg failed after {max_tries} tries")

In [5]:
BATCH_SIZE = 16

dataset = EEGDataset(eeg_dataset_file_name="eeg_5_95_std.pth")
# Shuffle and drop the ragged last batch only for training. The evaluation
# splits keep their order and every sample, so metrics cover the full split.
loaders = {
    split: DataLoader(
        Splitter(dataset, split_name=split),
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        drop_last=(split == "train"),
    )
    for split in ["train", "val", "test"]
}

# Model


In [6]:
# Use Apple's Metal (MPS) backend when torch reports it available, else CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
class FeatureExtractor_ContrastiveLearning_NN(L.LightningModule):
    """EEG feature extractor trained with a triplet-margin loss (work in progress).

    Each training/validation batch is unpacked as an ``(anchor, positive,
    negative)`` triple of inputs. All three are embedded with ``forward`` and
    compared with squared Euclidean distance.

    Args:
        encoder: Module mapping an input batch to embeddings. TODO: the
            architecture is not decided yet; ``forward`` raises until one is given.
        margin: Margin of ``nn.TripletMarginWithDistanceLoss``.
        lr: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, encoder=None, margin=1.5, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.lr = lr
        self.loss_fn = nn.TripletMarginWithDistanceLoss(
            distance_function=self._squared_euclidean, margin=margin
        )

    @staticmethod
    def _squared_euclidean(x1, x2):
        """Squared L2 distance between ``x1`` and ``x2``, reduced over the last dim."""
        # Reducing over the last (feature) dim keeps one distance per sample,
        # which is what the triplet loss expects. Summing over dim=0 would
        # instead mix samples across the batch.
        return torch.sum(torch.pow(torch.subtract(x1, x2), 2), dim=-1)

    def forward(self, input):
        """Embed ``input`` with the configured encoder."""
        if self.encoder is None:
            raise NotImplementedError(
                "FeatureExtractor_ContrastiveLearning_NN has no encoder yet"
            )
        return self.encoder(input)

    def _triplet_loss(self, batch):
        """Embed an ``(anchor, positive, negative)`` batch and return the triplet loss."""
        # TODO: the current loaders yield (eeg, label) pairs. They must be
        # adapted to yield triplets, e.g. by using Splitter.get_item_eeg.
        anchor_input, positive_input, negative_input = batch
        anchor = self(anchor_input)
        positive = self(positive_input)
        negative = self(negative_input)
        return self.loss_fn(anchor, positive, negative)

    def training_step(self, batch, batch_idx):
        loss = self._triplet_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._triplet_loss(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

# Training


In [None]:
model = FeatureExtractor_ContrastiveLearning_NN()

# Write logs under the project root rather than a machine-specific absolute path.
logger = TensorBoardLogger(
    save_dir=os.path.join(root_path, "lightning_logs"),
    name="test",
    version=None,
)

trainer = L.Trainer(max_epochs=200, logger=logger)

In [None]:
# Pass the actual validation DataLoader (previously a list containing the string "val").
trainer.fit(model, train_dataloaders=loaders["train"], val_dataloaders=loaders["val"])

# Classifier Model


# Train Classifier


# Test Classifier
