In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import sys
import os
sys.path.append(os.path.abspath(''))

In [None]:
from torch.utils.data import DataLoader, random_split
import torch
from pytorch_lightning import Trainer
# from pytorch_lightning.loggers import TensorBoardLogger
from torchvision import datasets, transforms
from datasets import load_dataset

In [17]:
from data.dataloader import EEG_Dataset, EEG_Spectogram_Dataset, EncodedDataset, LatentDataset

from models.cnn import CNNClassifier
from models.latent_dim import LatentProjection
from models.mnist_basic import MNISTClassifier
from models.resnet import ResNetClassifier
from models.unet_autoencoder import UNetAutoencoder
from models.vgg import VGGish

from utils.utils import LossTracker, plot_losses, plot_spectrogram, create_dataloaders

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [22]:
# logger = TensorBoardLogger("lightning_logs", name="CNN_model")

In [None]:
POSTFIX = "MNIST_IN" # or "MNIST_EP" 
train_dataset = load_dataset(f"DavidVivancos/MindBigData2022_{POSTFIX}", split="train")

## Raw EEG

In [34]:
full_eeg_dataset = EEG_Dataset(train_dataset)
eeg_train_dataloader, eeg_val_dataloader = create_dataloaders(full_eeg_dataset, batch_size=128, train_split=0.8, seed=42)

In [None]:
loss_tracker = LossTracker()
cnn_classifier = CNNClassifier(input_channels=5, sequence_length=256, dropout_rate=.2)
trainer = Trainer(max_epochs=50, callbacks=[loss_tracker]) # add logger=logger to log to tensorboard
trainer.fit(cnn_classifier, eeg_train_dataloader, eeg_val_dataloader)

NameError: name 'logger' is not defined

## Spectogram EEG

In [20]:
full_eeg_spectogram_dataset = EEG_Spectogram_Dataset(train_dataset)
eeg_spectogram_train_dataloader, eeg_spectogram_val_dataloader = create_dataloaders(full_eeg_spectogram_dataset, batch_size=128, train_split=0.8, seed=42)

In [None]:
loss_tracker = LossTracker()
vggish = VGGish()
trainer = Trainer(max_epochs=50, callbacks=[loss_tracker])
trainer.fit(vggish, eeg_spectogram_train_dataloader, eeg_spectogram_val_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## UNET

In [None]:
# checkpoint_path = "lightning_logs/version_16/checkpoints/epoch=9-step=5220.ckpt"
# autoencoder = CNNAutoencoder.load_from_checkpoint(checkpoint_path)
latent_dim = 64
loss_tracker = LossTracker()
autoencoder = UNetAutoencoder(latent_dim=latent_dim)
trainer = Trainer(max_epochs=10, callbacks=[loss_tracker])
trainer.fit(autoencoder, eeg_train_dataloader, eeg_val_dataloader)

## MNIST

In [328]:
transform = transforms.Compose([transforms.ToTensor()])
mnist_full = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
mnist_train_dataset, mnist_val_dataset = random_split(mnist_full, [55000, 5000])
mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=64, shuffle=True)
mnist_val_loader = DataLoader(mnist_val_dataset, batch_size=64)

In [329]:
checkpoint_path = "lightning_logs/version_14/checkpoints/epoch=9-step=8600.ckpt"
mnist_classifier = MNISTClassifier.load_from_checkpoint(checkpoint_path)
# mnist_classifier = MNISTClassifier(latent_dim=latent_dim)
# mnist_trainer = pl.Trainer(max_epochs=3)
# mnist_trainer.fit(mnist_classifier, mnist_train_loader, mnist_val_loader)

## Model Alignement

In [355]:
eeg_encoder = autoencoder.get_encoder()
mnist_encoder = mnist_classifier.feature_extractor

In [356]:
eeg_train_encoded_dataset = EncodedDataset(eeg_train_dataloader, eeg_encoder)
eeg_val_encoded_dataset = EncodedDataset(eeg_val_dataloader, eeg_encoder)

In [357]:
mnist_train_encoded_data = EncodedDataset(mnist_train_loader, mnist_encoder)
mnist_val_encoded_data = EncodedDataset(mnist_val_loader, mnist_encoder)

In [358]:
latent_train_dataset = LatentDataset(
    eeg_train_encoded_dataset, mnist_train_encoded_data
)
latent_train_dataloader = DataLoader(latent_train_dataset, batch_size=32, shuffle=True)

latent_val_dataset = LatentDataset(eeg_val_encoded_dataset, mnist_train_encoded_data)
latent_val_dataloader = DataLoader(latent_val_dataset, batch_size=32, shuffle=False)

In [359]:
loss_tracker = LossTracker()
latent_projection = LatentProjection(latent_dim,latent_dim)
latent_trainer = Trainer(max_epochs=10)
latent_trainer.fit(latent_projection, latent_train_dataloader, latent_val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name       | Type       | Params
------------------------------------------
0 | projection | Sequential | 58.6 K
1 | criterion  | MSELoss    | 0     
------------------------------------------
58.6 K    Trainable params
0         Non-trainable params
58.6 K    Total params
0.234     Total estimated model params size (MB)


Epoch 3:  63%|██████▎   | 1090/1719 [00:11<00:06, 96.98it/s, v_num=2, train_loss_step=11.70, val_loss_step=8.510, val_loss_epoch=11.00, train_loss_epoch=11.00]