# Imports


In [63]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader


import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

import sys
import os
import numpy as np
import cv2
from datetime import datetime

sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "code"))
import dataset as D

# Dataset


In [37]:
# EEG recordings plus the train/val/test split definitions from the project-local
# `dataset` module (imported above as D).
dataset = D.EEGDataset(eeg_dataset_file_name="eeg_5_95_std.pth")

BATCH_SIZE = 16

# Only the training loader shuffles and drops the ragged last batch (the generator's
# BatchNorm1d layers cannot train on a batch of one). Evaluation loaders keep a fixed
# order and every sample so val/test results are reproducible and complete.
loaders = {
    split: DataLoader(
        D.Splitter(dataset, split_name=split),
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        drop_last=(split == "train"),
    )
    for split in ["train", "val", "test"]
}

In [32]:
# Sanity check: load one ImageNet stimulus image and inspect its shape.
SAMPLE_IMAGE_PATH = (
    "/Users/ms/cs/ML/NeuroImagen/dataset/imageNet_images/n07753592/n07753592_847.JPEG"
)
img = cv2.imread(SAMPLE_IMAGE_PATH)
# cv2.imread reports a missing/unreadable file by returning None rather than raising,
# which would otherwise surface as a confusing AttributeError on `.shape`.
if img is None:
    raise FileNotFoundError(f"Could not read image: {SAMPLE_IMAGE_PATH}")
# OpenCV returns a (height, width, channels) array in BGR channel order.
print(img.shape)

(500, 375, 3)


# Model


In [53]:
# Prefer Apple's Metal (MPS) backend when present, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

# Placeholder; the configuration cell in the "Training" section fills this in before
# Generator / Discriminator / saliency_map_GAN read it at construction time.
config = dict()

mps


In [50]:
class FeatureExtractor_ContrastiveLearning_NN(L.LightningModule):
    """LSTM encoder mapping an EEG sequence to a 128-dimensional feature vector.

    Only the architecture is defined here; in this notebook the weights are
    restored with ``load_from_checkpoint`` from a run presumably trained with a
    contrastive/triplet objective (per the class and checkpoint names). The
    submodule names ``lstm`` and ``output`` must therefore stay in sync with
    that checkpoint's state dict.
    """

    def __init__(self):
        super().__init__()
        self.save_hyperparameters()

        # Architecture sizes; they must match the checkpoint being loaded.
        self.input_size = 128
        self.hidden_size = 128
        self.lstm_layers = 1
        self.out_size = 128

        self.lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            num_layers=self.lstm_layers,
            batch_first=True,
        )
        self.output = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=self.out_size),
            nn.ReLU(),
        )

    def forward(self, input):
        """Encode a batch of EEG sequences.

        Args:
            input: batch-first LSTM input of shape (batch, seq_len, 128). The
                (batch, time, channels) interpretation is an assumption — TODO
                confirm against the dataset.

        Returns:
            Tensor of shape (batch, 128): the ReLU-projected LSTM output at the
            last time step.
        """
        # Follow this module's own device rather than the notebook-global `device`,
        # so the encoder works wherever Lightning or `.to()` placed it.
        input = input.to(self.device)

        lstm_out, _ = self.lstm(input)
        # The final time step's output serves as the sequence summary.
        return self.output(lstm_out[:, -1, :])

In [61]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise, condition) -> image with values in [-1, 1].

    Args:
        latent_dim: size of the noise vector.
        img_shape: per-sample output shape; defaults to ``config["img-size"]``
            read once at construction time.
        condition_dim: size of the conditioning vector (the EEG feature
            extractor in this notebook produces 128 features).
    """

    def __init__(self, latent_dim=100, img_shape=None, condition_dim=128):
        super().__init__()
        # Freeze the target shape now so forward() always agrees with the size of
        # the final Linear layer, even if the global config is edited later.
        self.img_shape = tuple(config["img-size"] if img_shape is None else img_shape)

        def block(input_features, output_features, normalize=True):
            layers = [nn.Linear(input_features, output_features)]
            if normalize:
                # The second positional argument of BatchNorm1d is `eps`; passing 0.8
                # there (as before) made eps=0.8 instead of a tiny stability term.
                layers.append(nn.BatchNorm1d(output_features))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + condition_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, noise, condition):
        """Generate a batch of images.

        Args:
            noise: tensor of shape (batch, latent_dim).
            condition: tensor of shape (batch, condition_dim).

        Returns:
            Tensor of shape (batch, *img_shape) with values in [-1, 1] (Tanh output).
        """
        # Condition first, then noise — the order the first Linear layer was built for.
        gen_input = torch.cat((condition, noise), -1)
        flat_imgs = self.model(gen_input)
        return flat_imgs.view(flat_imgs.size(0), *self.img_shape)

In [59]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, condition) pairs.

    Args:
        img_shape: per-sample image shape; defaults to ``config["img-size"]``
            read once at construction time.
        condition_dim: size of the conditioning vector.
    """

    def __init__(self, img_shape=None, condition_dim=128):
        super().__init__()
        self.img_shape = tuple(config["img-size"] if img_shape is None else img_shape)

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)) + condition_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, condition):
        """Return the probability that each (img, condition) pair is real.

        Args:
            img: tensor of shape (batch, *img_shape).
            condition: tensor of shape (batch, condition_dim).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        # torch.cat expects a *sequence* of tensors; passing them as separate
        # positional arguments (as before) raises a TypeError. `reshape` also
        # accepts non-contiguous image batches, which `view` rejects.
        d_input = torch.cat((img.reshape(img.size(0), -1), condition), -1)
        return self.model(d_input)

In [45]:
class saliency_map_GAN(L.LightningModule):
    """Conditional GAN that generates images from EEG features.

    A frozen, pretrained ``FeatureExtractor_ContrastiveLearning_NN`` (loaded from
    ``config["checkpoint"]``) turns each EEG sample into a condition vector. The
    generator maps (noise, condition) to an image of shape ``config["img-size"]``,
    and the discriminator scores (image, condition) pairs.

    Lightning 2.x only supports multiple optimizers under manual optimization, so
    both the generator and the discriminator updates happen explicitly inside
    ``training_step``.
    """

    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        # Two optimizers (G and D) require manual optimization in Lightning 2.x.
        self.automatic_optimization = False

        # Image shape produced by the generator and consumed by the discriminator.
        self.data_shape = tuple(config["img-size"])
        # Noise vector size fed to the generator.
        self.latent_dim = 100

        self.generator = Generator(latent_dim=self.latent_dim)
        self.discriminator = Discriminator()

        self.feature_extractor = (
            FeatureExtractor_ContrastiveLearning_NN.load_from_checkpoint(
                config["checkpoint"]
            )
        )
        # The pretrained encoder stays fixed; only G and D are optimized.
        self.feature_extractor.requires_grad_(False)
        self.loss_fn = self.adversarial_loss

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator probabilities and targets."""
        return F.binary_cross_entropy(y_hat, y)

    def forward(self, noise, condition):
        """Generate images from noise and EEG-feature conditions."""
        return self.generator(noise, condition)

    def training_step(self, batch, batch_idx, optim_idx=None):
        """Run one generator update followed by one discriminator update.

        Args:
            batch: ``(eegs, real_imgs)`` pair from the loader. This assumes the
                dataset yields EEG tensors with their stimulus images, already
                float, scaled to [-1, 1] and shaped ``config["img-size"]`` —
                TODO confirm against ``dataset.Splitter``.
            batch_idx: index of the batch (unused).
            optim_idx: ignored. It is kept only so older three-argument calls still
                work; Lightning 2.x does not pass an optimizer index.

        Returns:
            Dict with the detached generator and discriminator losses.
        """
        g_optim, d_optim = self.optimizers()

        eegs, real_imgs = batch
        eegs = eegs.to(self.device)
        real_imgs = real_imgs.to(self.device)
        batch_size = real_imgs.size(0)

        # The encoder is frozen, so its output needs no autograd graph.
        with torch.no_grad():
            eeg_features = self.feature_extractor(eegs).to(self.device)

        y_real = torch.ones(batch_size, 1, device=self.device)
        y_fake = torch.zeros(batch_size, 1, device=self.device)
        noise = torch.randn(batch_size, self.latent_dim, device=self.device)

        # Generator step: push D(G(z, c), c) towards "real".
        gen_imgs = self.generator(noise, eeg_features)
        g_loss = self.loss_fn(self.discriminator(gen_imgs, eeg_features), y_real)
        g_optim.zero_grad()
        self.manual_backward(g_loss)
        g_optim.step()

        # Discriminator step: real pairs -> 1, generated pairs -> 0. Detach the
        # fakes so this update does not backpropagate into the generator.
        d_loss_real = self.loss_fn(self.discriminator(real_imgs, eeg_features), y_real)
        d_loss_fake = self.loss_fn(
            self.discriminator(gen_imgs.detach(), eeg_features), y_fake
        )
        d_loss = (d_loss_real + d_loss_fake) / 2
        # zero_grad also discards the gradients D accumulated during the G step.
        d_optim.zero_grad()
        self.manual_backward(d_loss)
        d_optim.step()

        self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True)
        return {"g_loss": g_loss.detach(), "d_loss": d_loss.detach()}

    def on_validation_epoch_end(self):
        """Placeholder hook. TODO: add a validation_step and report metrics here."""

    def configure_optimizers(self):
        """Return [generator optimizer, discriminator optimizer], in that order."""
        g_optim = optim.Adam(self.generator.parameters(), lr=1e-3)
        d_optim = optim.Adam(self.discriminator.parameters(), lr=1e-3)
        return [g_optim, d_optim]

# Training


In [57]:
# Run configuration read by Generator / Discriminator / saliency_map_GAN at construction.
# Set the NEUROIMAGEN_ROOT environment variable to relocate the project without editing
# this cell; the default reproduces the original absolute location.
NEUROIMAGEN_ROOT = os.environ.get("NEUROIMAGEN_ROOT", "/Users/ms/cs/ML/NeuroImagen")

config = {
    # Per-sample generated image shape — presumably (channels, height, width).
    "img-size": (3, 256, 256),
    # Pretrained EEG feature-extractor checkpoint.
    "checkpoint": os.path.join(
        NEUROIMAGEN_ROOT,
        "lightning_logs",
        "ContrastiveLossFeatureLearning",
        "Adam_0.0001_LambdaLR_margin_1.5",
        "weight-decay_0_lambda-factor_0.95",
        "checkpoints",
        "epoch=2-step=1491.ckpt",
    ),
}

In [62]:
# Build the GAN; this also loads the frozen EEG feature extractor from config["checkpoint"].
model = saliency_map_GAN()
model.to(device)

# Give every run its own timestamped version directory. With name="" and version="",
# TensorBoard event files and checkpoints from all runs would land in the same folder.
run_version = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logger = TensorBoardLogger(
    save_dir="/Users/ms/cs/ML/NeuroImagen/lightning_logs/SaliencyMapGAN/",
    name="",
    version=run_version,
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")

trainer = L.Trainer(max_epochs=200, callbacks=[lr_monitor], logger=logger)

# Training is long-running; set this to True to actually fit the model.
RUN_TRAINING = False
if RUN_TRAINING:
    trainer.fit(
        model,
        train_dataloaders=loaders["train"],
        val_dataloaders=loaders["val"],
    )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Testing


In [47]:
# TODO: generate one image from a randomly chosen EEG sample and display it next to the original stimulus image.