In [3]:
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from datasets import load_dataset
from PyEMD import EMD
from scipy.signal import hilbert

In [9]:
class EEG_Dataset(Dataset):
    def __init__(self, dataset):
        data_df = dataset.to_pandas()
        features_names = list(data_df.columns)[1:]
        self.electrodes = list(set([f.split("-")[0] for f in features_names]))
        self.nb_channel = len(self.electrodes)
        self.nb_samples = (max([int(f.split("-")[1]) for f in features_names]))+1
        self.datas = self._norm(data_df)
        self.features = self.datas[:,1:]
        self.labels = list(self.datas[:,0].astype(int))


    def _norm(self, data_df):
        electrodes_min = {}
        electrodes_max = {}
        for electrode in self.electrodes :
            electrodes_min[electrode] = data_df[[f'{electrode}-{i}'for i in range(self.nb_samples)]].to_numpy().reshape(-1,).max()
            electrodes_max[electrode] = data_df[[f'{electrode}-{i}'for i in range(self.nb_samples)]].to_numpy().reshape(-1,).min()

        for electrode in self.electrodes:
            data_df[[f'{electrode}-{i}'for i in range(self.nb_samples)]] = (data_df[[f'{electrode}-{i}'for i in range(self.nb_samples)]] - electrodes_min[electrode])/(electrodes_max[electrode] -electrodes_min[electrode])
        return data_df.to_numpy()


    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.tensor(self.features[idx]).view(self.nb_channel,-1).float(), self.labels[idx]

In [3]:
class UNetAutoencoder(pl.LightningModule):
    def __init__(self, input_channels=5, sequence_length=256, latent_dim=64, dropout_rate=0.2, mask_ratio=0.15):
        super(UNetAutoencoder, self).__init__()

        self.input_channels = input_channels
        self.sequence_length = sequence_length
        self.mask_ratio = mask_ratio

        # Encoder
        self.enc1 = self.conv_block(input_channels, 32, dropout_rate)
        self.enc2 = self.conv_block(32, 64, dropout_rate)
        self.enc3 = self.conv_block(64, 128, dropout_rate)
        self.enc4 = self.conv_block(128, 256, dropout_rate)

        # Calculate the size of the encoder output
        self.encoder_output_size = sequence_length // 16 * 256

        # Latent space
        self.fc1 = nn.Sequential(
            nn.Linear(self.encoder_output_size, latent_dim),
            nn.Dropout(dropout_rate)
        )
        self.fc2 = nn.Linear(latent_dim, self.encoder_output_size)

        # Decoder
        self.dec4 = self.conv_block(256, 128, dropout_rate, transpose=True)
        self.dec3 = self.conv_block(256, 64, dropout_rate, transpose=True)
        self.dec2 = self.conv_block(128, 32, dropout_rate, transpose=True)
        self.dec1 = nn.ConvTranspose1d(64, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.criterion = nn.MSELoss()

    def conv_block(self, in_channels, out_channels, dropout_rate, transpose=False):
        if not transpose:
            return nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        # Flatten
        x = e4.view(e4.size(0), -1)

        # Latent space
        x = self.fc1(x)
        x = self.fc2(x)

        # Reshape for decoder
        x = x.view(x.size(0), 256, -1)

        # Decoder with skip connections
        x = self.dec4(x)
        x = torch.cat([x, e3], dim=1)
        x = self.dec3(x)
        x = torch.cat([x, e2], dim=1)
        x = self.dec2(x)
        x = torch.cat([x, e1], dim=1)
        x = self.dec1(x)

        return x

    def get_encoder(self):
        return nn.Sequential(
            self.enc1, self.enc2, self.enc3, self.enc4,
            nn.Flatten(),
            self.fc1
        )

    def apply_mask(self, x):
        batch_size, _, _ = x.shape
        mask = torch.rand(batch_size, self.input_channels, self.sequence_length, device=x.device) > self.mask_ratio
        return x * mask, mask

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x_masked, mask = self.apply_mask(x)
        x_hat = self.forward(x_masked)
        loss = self.criterion(x_hat * mask, x * mask)  # Compute loss only on unmasked parts
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x_masked, mask = self.apply_mask(x)
        x_hat = self.forward(x_masked)
        loss = self.criterion(x_hat * mask, x * mask)  # Compute loss only on unmasked parts
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=False):
        super(ResidualBlock, self).__init__()
        stride = 2 if downsample else 1
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.norm1 = nn.BatchNorm1d(out_channels)
        self.mish = nn.Mish()
        
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm1d(out_channels)
        
        # Skip connection
        self.downsample = downsample
        if downsample or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels)
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        identity = self.skip(x)
        out = self.mish(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        out += identity
        return self.mish(out)

class ResNetClassifier(pl.LightningModule):
    def __init__(self, input_channels=5, sequence_length=256, num_classes=10, dropout_rate=0.5):  # Higher dropout
        super(ResNetClassifier, self).__init__()
        
        # Initial layer
        self.initial_conv = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(16),
            nn.Mish(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        )
        
        # Reduced Residual Blocks
        self.layer1 = ResidualBlock(16, 32, downsample=True)
        self.layer2 = ResidualBlock(32, 64, downsample=True)
        
        # Calculate size before FC layers
        reduced_length = sequence_length // 4
        self.fc_input_size = 1024
        
        # Fully connected layers with residual connection
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.fc_input_size, 256),  # Reduced size
            nn.Mish(),
            nn.Dropout(dropout_rate)
        )
        
        self.fc2 = nn.Sequential(
            nn.Linear(256, 256),  # Reduced size
            nn.Mish(),
            nn.Dropout(dropout_rate)
        )
        
        # Output layer
        self.classifier = nn.Linear(256, num_classes)
        
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.initial_conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        
        # Residual connection in fully connected layers
        x = self.fc1(x)
        x = x + self.fc2(x)
        logits = self.classifier(x)
        
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.01, weight_decay=1e-3)
        
        # Cosine Annealing Scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
        
        return [optimizer], [scheduler]


In [40]:
class CNNClassifier(pl.LightningModule):
    def __init__(self, input_channels=5, sequence_length=256, num_classes=10, dropout_rate=0.2):
        super(CNNClassifier, self).__init__()
        
        # Layer 1: Conv + LeakyReLU
        self.conv1 = nn.Conv1d(input_channels, input_channels, kernel_size=5, stride=1, padding=2)
        self.leaky1 = nn.LeakyReLU(negative_slope=0.1)
        self.norm1 = nn.BatchNorm1d(input_channels)
        
        # Layer 3: Conv + LeakyReLU
        self.conv2 = nn.Conv1d(input_channels, 3, kernel_size=5, stride=1, padding=2)
        self.leaky2 = nn.LeakyReLU(negative_slope=0.1)
        self.norm2 = nn.BatchNorm1d(3)
        
        # Layer 5: Conv + LeakyReLU
        self.conv3 = nn.Conv1d(3, 3, kernel_size=5, stride=1, padding=2)
        self.leaky3 = nn.LeakyReLU(negative_slope=0.1)
        self.norm3 = nn.BatchNorm1d(3)
        
        # Calculate size before FC layers
        self.fc_input_size = 3 * sequence_length  # 10 channels * sequence_length
        
        # Layer 7: Full-connected
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.fc_input_size, self.fc_input_size),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Dropout(dropout_rate)
        )
        
        # Layer 8: Full-connected
        self.fc2 = nn.Sequential(
            nn.Linear(self.fc_input_size, self.fc_input_size//4),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Dropout(dropout_rate)
        )
        
        # Output layer
        self.classifier = nn.Linear(self.fc_input_size//4, num_classes)
        
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # Convolutional layers
        x = self.conv1(x)
        x = self.leaky1(x)
        x = self.norm1(x)
        
        x = self.conv2(x)
        x = self.leaky2(x)
        x = self.norm2(x)
        
        x = self.conv3(x)
        x = self.leaky3(x)
        x = self.norm3(x)
        
        # Fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)
        
        # Classification head
        logits = self.classifier(x)
        
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-2)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=2, verbose=True, min_lr=0.000001,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [6]:
class MNISTClassifier(pl.LightningModule):
    def __init__(self, latent_dim=64, num_classes=10):
        super(MNISTClassifier, self).__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, latent_dim)
        )
        
        self.classifier = nn.Linear(latent_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features), features

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits, _ = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits, _ = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [7]:
class EncodedDataset:
    def __init__(self, dataloader, encoder):
        super().__init__()
        self.encoded_datas, self.labels = self._encode_dataset(dataloader, encoder)

    def _encode_dataset(self, dataloader, encoder):
        encoded_datas = []
        labels = []
        for batch in dataloader:
            encoded_data = encoder(batch[0]).detach()
            encoded_datas.append(encoded_data)
            labels.append(batch[1])
        encoded_datas = torch.cat(encoded_datas)
        labels = torch.cat(labels)
        return encoded_datas, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.encoded_datas[idx], self.labels[idx].item()


In [8]:
class LatentDataset(torch.utils.data.Dataset):
    def __init__(self, eeg_data, mnist_data):
        self.paired_data = self._pair_data(eeg_data, mnist_data)

    def _pair_data(self, eeg_data, mnist_data):
        eeg_by_label = {i: [] for i in range(10)}
        mnist_by_label = {i: [] for i in range(10)}

        # Group EEG data by label
        for eeg, label in eeg_data: 
            if 0 <= label < 10:
                eeg_by_label[label].append(eeg)

        # Group MNIST data by label
        for img, label in mnist_data:
            mnist_by_label[label].append(img)

        # Pair EEG and MNIST data
        paired_data = []
        for label in eeg_by_label.keys():
            eeg_samples = eeg_by_label[label]
            mnist_samples = mnist_by_label[label]

            # Use the maximum number of samples available
            n_samples = max(len(eeg_samples), len(mnist_samples))

            # Replicate samples if necessary
            if len(eeg_samples) < n_samples:
                eeg_samples = self._replicate_samples(eeg_samples, n_samples)
            if len(mnist_samples) < n_samples:
                mnist_samples = self._replicate_samples(mnist_samples, n_samples)

            for eeg, mnist in zip(eeg_samples, mnist_samples):
                paired_data.append((eeg, mnist, label))

        return paired_data

    def _replicate_samples(self, samples, target_size):
        """Replicate samples randomly to reach the target size."""
        while len(samples) < target_size:
            samples.append(random.choice(samples))
        return samples

    def __len__(self):
        return len(self.paired_data)

    def __getitem__(self, idx):
        return self.paired_data[idx]

In [9]:
class LatentProjection(pl.LightningModule):
    def __init__(self, eeg_latent_dim, mnist_latent_dim, hidden_dims=[256, 128]):
        super(LatentProjection, self).__init__()
        
        layers = []
        input_dim = eeg_latent_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            input_dim = hidden_dim
        
        layers.append(nn.Linear(input_dim, mnist_latent_dim))
        
        self.projection = nn.Sequential(*layers)
        self.criterion = nn.MSELoss()

    def forward(self, x):
        return self.projection(x)

    def training_step(self, batch, batch_idx):
        eeg_latent, mnist_latent, _ = batch
        projected = self(eeg_latent)
        loss = self.criterion(projected, mnist_latent)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)        
        return loss

    def validation_step(self, batch, batch_idx):
        eeg_latent, mnist_latent, _ = batch
        projected = self(eeg_latent)
        val_loss = self.criterion(projected, mnist_latent)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)        
        return val_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [15]:
class LossTracker(Callback):
    def __init__(self):
        self.train_losses = []
        self.val_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        self.train_losses.append(trainer.callback_metrics["train_loss"].item())

    def on_validation_epoch_end(self, trainer, pl_module):
        self.val_losses.append(trainer.callback_metrics["val_loss"].item())

In [11]:
def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Print final losses
    print(f"Final training loss: {train_losses[-1]}")
    print(f"Final validation loss: {val_losses[-1]}")

## EEG

In [10]:
train_dataset = load_dataset("DavidVivancos/MindBigData2022_MNIST_IN", split="train")

In [29]:
full_eeg_dataset = EEG_Dataset(train_dataset)

# Calculate sizes for train and validation sets
total_size = len(full_eeg_dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # Remaining 20% for validation

# Use random_split to create train and validation datasets
train_eeg_dataset, val_eeg_dataset = random_split(full_eeg_dataset, [train_size, val_size], 
                                          generator=torch.Generator().manual_seed(42))

In [12]:
full_eeg_dataset[0]

(tensor([[0.4339, 0.4346, 0.4309,  ..., 0.4258, 0.4293, 0.4287],
         [0.4346, 0.4343, 0.4306,  ..., 0.4300, 0.4323, 0.4299],
         [0.4309, 0.4309, 0.4280,  ..., 0.4322, 0.4350, 0.4333],
         [0.2718, 0.2676, 0.2668,  ..., 0.2569, 0.2674, 0.2687],
         [0.3430, 0.3411, 0.3390,  ..., 0.3431, 0.3451, 0.3459]]),
 6)

In [13]:
batch_size = 128
eeg_train_dataloader = DataLoader(
    dataset=train_eeg_dataset,
    batch_size=batch_size,
    shuffle=True,
)

eeg_val_dataloader = DataLoader(
    dataset=val_eeg_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [63]:
latent_dim = 64

In [None]:
# checkpoint_path = "lightning_logs/version_16/checkpoints/epoch=9-step=5220.ckpt"
# autoencoder = CNNAutoencoder.load_from_checkpoint(checkpoint_path)
loss_tracker = LossTracker()
autoencoder = UNetAutoencoder(latent_dim=latent_dim)
trainer = Trainer(max_epochs=10, callbacks=[loss_tracker])
trainer.fit(autoencoder, eeg_train_dataloader, eeg_val_dataloader)

In [41]:
# checkpoint_path = "lightning_logs/version_16/checkpoints/epoch=9-step=5220.ckpt"
# autoencoder = CNNAutoencoder.load_from_checkpoint(checkpoint_path)
loss_tracker = LossTracker()
autoencoder = CNNClassifier(input_channels=5, sequence_length=256, dropout_rate=.2)
trainer = Trainer(max_epochs=50, callbacks=[loss_tracker])
trainer.fit(autoencoder, eeg_train_dataloader, eeg_val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params | Mode 
---------------------------------------------------------
0  | conv1      | Conv1d           | 416    | train
1  | leaky1     | LeakyReLU        | 0      | train
2  | norm1      | BatchNorm1d      | 32     | train
3  | conv2      | Conv1d           | 2.6 K  | train
4  | leaky2     | LeakyReLU        | 0      | train
5  | norm2      | BatchNorm1d      | 64     | train
6  | conv3      | Conv1d           | 2.6 K  | train
7  | leaky3     | LeakyReLU        | 0      | train
8  | norm3      | BatchNorm1d      | 32     | train
9  | fc1        | Sequential       | 16.8 M | train
10 | fc2        | Sequential       | 4.2 M  | train
11 | classifier | Linear           | 10.2 K | train
12 | criterion  | CrossEntropyLoss | 0      | train
---------------------------------------------------------
21.0 M    Trainable params
0         Non-tra

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [21]:
trainer = Trainer(max_epochs=100, callbacks=[loss_tracker])
trainer.fit(autoencoder, eeg_train_dataloader, eeg_val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

   | Name       | Type             | Params | Mode 
---------------------------------------------------------
0  | conv1      | Conv1d           | 130    | train
1  | leaky1     | LeakyReLU        | 0      | train
2  | norm1      | BatchNorm1d      | 10     | train
3  | conv2      | Conv1d           | 78     | train
4  | leaky2     | LeakyReLU        | 0      | train
5  | norm2      | BatchNorm1d      | 6      | train
6  | conv3      | Conv1d           | 48     | train
7  | leaky3     | LeakyReLU        | 0      | train
8  | norm3      | BatchNorm1d      | 6      | train
9  | fc1        | Sequential       | 590 K  | train
10 | fc2        | Sequential       | 147 K  | train
11 | classifier | Linear           | 1.9 K  | train
12 | criterion  | CrossEntropyLoss | 0      | train
---------------------------------------------------------
740 K     Trainable params
0         Non-tra

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:
autoencoder = CNNClassifier()
trainer = Trainer(max_epochs=50, callbacks=[loss_tracker])
trainer.fit(autoencoder, eeg_train_dataloader, eeg_val_dataloader)

## MNIST

In [328]:
transform = transforms.Compose([transforms.ToTensor()])
mnist_full = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
mnist_train_dataset, mnist_val_dataset = random_split(mnist_full, [55000, 5000])
mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=64, shuffle=True)
mnist_val_loader = DataLoader(mnist_val_dataset, batch_size=64)

In [329]:
checkpoint_path = "lightning_logs/version_14/checkpoints/epoch=9-step=8600.ckpt"
mnist_classifier = MNISTClassifier.load_from_checkpoint(checkpoint_path)
# mnist_classifier = MNISTClassifier(latent_dim=latent_dim)
# mnist_trainer = pl.Trainer(max_epochs=3)
# mnist_trainer.fit(mnist_classifier, mnist_train_loader, mnist_val_loader)

## Model Alignement

In [355]:
eeg_encoder = autoencoder.get_encoder()
mnist_encoder = mnist_classifier.feature_extractor

In [356]:
eeg_train_encoded_dataset = EncodedDataset(eeg_train_dataloader, eeg_encoder)
eeg_val_encoded_dataset = EncodedDataset(eeg_val_dataloader, eeg_encoder)

In [357]:
mnist_train_encoded_data = EncodedDataset(mnist_train_loader, mnist_encoder)
mnist_val_encoded_data = EncodedDataset(mnist_val_loader, mnist_encoder)

In [358]:
latent_train_dataset = LatentDataset(
    eeg_train_encoded_dataset, mnist_train_encoded_data
)
latent_train_dataloader = DataLoader(latent_train_dataset, batch_size=32, shuffle=True)

latent_val_dataset = LatentDataset(eeg_val_encoded_dataset, mnist_train_encoded_data)
latent_val_dataloader = DataLoader(latent_val_dataset, batch_size=32, shuffle=False)

In [359]:
loss_tracker = LossTracker()
latent_projection = LatentProjection(latent_dim,latent_dim)
latent_trainer = Trainer(max_epochs=10)
latent_trainer.fit(latent_projection, latent_train_dataloader, latent_val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name       | Type       | Params
------------------------------------------
0 | projection | Sequential | 58.6 K
1 | criterion  | MSELoss    | 0     
------------------------------------------
58.6 K    Trainable params
0         Non-trainable params
58.6 K    Total params
0.234     Total estimated model params size (MB)


Epoch 3:  63%|██████▎   | 1090/1719 [00:11<00:06, 96.98it/s, v_num=2, train_loss_step=11.70, val_loss_step=8.510, val_loss_epoch=11.00, train_loss_epoch=11.00]