In [2]:
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class EEG_Dataset(Dataset):
    """Map-style dataset of raw EEG windows, min-max scaled per electrode.

    ``dataset`` must expose ``to_pandas()`` returning a frame whose first
    column holds the label and whose remaining columns are named
    ``"<electrode>-<sample index>"``.

    Attributes:
        electrodes: electrode names in order of first appearance in the columns.
        nb_channel: number of distinct electrodes.
        nb_samples: samples per electrode (largest sample index + 1).
        datas: normalised frame as a NumPy array (label column included).
        features: ``datas`` without the label column.
        labels: list of integer labels, one per row.
    """

    def __init__(self, dataset):
        data_df = dataset.to_pandas()
        features_names = list(data_df.columns)[1:]
        # dict.fromkeys de-duplicates while keeping column order, so the
        # electrode list is identical on every run (a set's order is not).
        self.electrodes = list(dict.fromkeys(f.split("-")[0] for f in features_names))
        self.nb_channel = len(self.electrodes)
        self.nb_samples = max(int(f.split("-")[1]) for f in features_names) + 1
        self.datas = self._norm(data_df)
        self.features = self.datas[:, 1:]
        self.labels = list(self.datas[:, 0].astype(int))

    def _norm(self, data_df):
        """Scale each electrode to [0, 1] with its own global min and max.

        The extrema are taken over every row and every sample of the
        electrode. ``data_df`` is modified in place; the label column is left
        untouched. Returns the frame converted with ``to_numpy()``.
        """
        for electrode in self.electrodes:
            columns = [f"{electrode}-{i}" for i in range(self.nb_samples)]
            values = data_df[columns].to_numpy()
            lowest = values.min()
            highest = values.max()
            span = highest - lowest
            if span == 0:
                # Constant signal: dividing by 0 would fill the electrode
                # with NaN; a unit span maps every value to 0 instead.
                span = 1
            data_df[columns] = (data_df[columns] - lowest) / span
        return data_df.to_numpy()

    def __len__(self):
        """Number of rows (one EEG window per row)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(signal, label)``; signal is float32 of shape (nb_channel, -1)."""
        signal = torch.tensor(self.features[idx]).view(self.nb_channel, -1).float()
        return signal, self.labels[idx]

In [160]:
class EEG_Spectogram_Dataset(Dataset):
    """Map-style dataset of per-electrode STFT magnitude spectrograms.

    ``dataset`` must expose ``to_pandas()`` returning a frame whose first
    column holds the label and whose remaining columns are named
    ``"<electrode>-<sample index>"``. All spectrograms are computed and
    min-max scaled once, at construction time.
    """

    def __init__(self, dataset):
        frame = dataset.to_pandas()
        column_names = list(frame.columns)[1:]
        self.electrodes = list(set(name.split("-")[0] for name in column_names))
        self.nb_channel = len(self.electrodes)
        self.nb_samples = 1 + max(int(name.split("-")[1]) for name in column_names)
        raw = frame.to_numpy()
        self.features = raw[:, 1:]
        self.labels = list(raw[:, 0].astype(int))
        magnitudes = self._spectogram(features=self.features)
        self.spectrograms = self._norm(magnitudes).float()
        del raw

    def _spectogram(self, features, n_fft=256, hop_length=10, fs=128, lowcut=0.5):
        """Return STFT magnitudes shaped (rows, nb_channel, freq bins, frames).

        Each row of ``features`` is split into ``nb_channel`` equal-length
        signals (this assumes one electrode's samples are contiguous in the
        row - TODO confirm against the column layout). Frequency bins at or
        below ``lowcut`` Hz are dropped, with the bins taken to span
        0..``fs // 2`` Hz linearly.
        """
        samples_per_channel = features.shape[-1] // self.nb_channel
        signals = torch.tensor(features.reshape(-1, samples_per_channel))
        magnitude = torch.stft(
            signals,
            n_fft=n_fft,
            hop_length=hop_length,
            window=torch.hann_window(n_fft),
            return_complex=True,
        ).abs()
        n_bins, n_frames = magnitude.shape[-2], magnitude.shape[-1]
        magnitude = magnitude.reshape(-1, self.nb_channel, n_bins, n_frames)
        bin_freqs = np.linspace(0, fs // 2, n_bins)
        first_kept = np.argwhere(bin_freqs > lowcut)[0].item()
        return magnitude[:, :, first_kept:, :]

    def _norm(self, spectrogram):
        """Min-max scale each channel over all rows, bins and frames."""
        reduce_dims = (0, 2, 3)
        lowest = torch.amin(spectrogram, dim=reduce_dims, keepdim=True)
        highest = torch.amax(spectrogram, dim=reduce_dims, keepdim=True)
        span = highest - lowest
        return (spectrogram - lowest) / span

    def __len__(self):
        """Number of rows in the wrapped dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` for row ``idx``."""
        return self.spectrograms[idx], self.labels[idx]

In [3]:
class UNetAutoencoder(pl.LightningModule):
    """1-D U-Net style autoencoder trained on randomly masked inputs.

    Four stride-2 encoder stages feed a linear bottleneck of size
    ``latent_dim``. Three transposed-convolution stages with skip connections
    and a final transposed convolution map back to ``input_channels``.

    Args:
        input_channels: channels of the input signal.
        sequence_length: samples per channel; the bottleneck size is derived
            from ``sequence_length // 16``.
        latent_dim: width of the bottleneck produced by ``fc1``.
        dropout_rate: dropout probability used in every block and after fc1.
        mask_ratio: threshold for the uniform noise in ``apply_mask``; entries
            whose noise is not above it are zeroed.
    """

    def __init__(self, input_channels=5, sequence_length=256, latent_dim=64, dropout_rate=0.2, mask_ratio=0.15):
        super().__init__()

        self.input_channels = input_channels
        self.sequence_length = sequence_length
        self.mask_ratio = mask_ratio

        # Encoder: each stage halves the length and widens the channels.
        self.enc1 = self.conv_block(input_channels, 32, dropout_rate)
        self.enc2 = self.conv_block(32, 64, dropout_rate)
        self.enc3 = self.conv_block(64, 128, dropout_rate)
        self.enc4 = self.conv_block(128, 256, dropout_rate)

        # Flattened encoder output: 256 channels x (length / 2**4).
        self.encoder_output_size = sequence_length // 16 * 256

        # Bottleneck: down to latent_dim and back up.
        self.fc1 = nn.Sequential(
            nn.Linear(self.encoder_output_size, latent_dim),
            nn.Dropout(dropout_rate),
        )
        self.fc2 = nn.Linear(latent_dim, self.encoder_output_size)

        # Decoder: dec3/dec2/dec1 take twice the channels of the previous
        # stage's output because a skip tensor is concatenated first.
        self.dec4 = self.conv_block(256, 128, dropout_rate, transpose=True)
        self.dec3 = self.conv_block(256, 64, dropout_rate, transpose=True)
        self.dec2 = self.conv_block(128, 32, dropout_rate, transpose=True)
        self.dec1 = nn.ConvTranspose1d(64, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.criterion = nn.MSELoss()

    def conv_block(self, in_channels, out_channels, dropout_rate, transpose=False):
        """Build one resampling block.

        The first layer is a stride-2 ``Conv1d`` (``transpose=False``) or a
        stride-2 ``ConvTranspose1d`` (``transpose=True``); it is followed by
        BatchNorm/ReLU/Dropout, a stride-1 ``Conv1d`` and a second
        BatchNorm/ReLU/Dropout.
        """
        if transpose:
            resample = nn.ConvTranspose1d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1
            )
        else:
            resample = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        return nn.Sequential(
            resample,
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """Encode ``x``, pass it through the bottleneck and decode it."""
        # Run the encoder, remembering every stage output for the skips.
        stage_outputs = []
        hidden = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            hidden = stage(hidden)
            stage_outputs.append(hidden)
        e1, e2, e3, deepest = stage_outputs

        # Bottleneck on the flattened deepest features.
        flat = deepest.view(deepest.size(0), -1)
        hidden = self.fc2(self.fc1(flat))

        # Back to (batch, 256, length) for the decoder.
        hidden = hidden.view(hidden.size(0), 256, -1)

        # Each decoder output is concatenated with the encoder output of
        # matching length before entering the next stage.
        for decoder, skip in ((self.dec4, e3), (self.dec3, e2), (self.dec2, e1)):
            hidden = torch.cat([decoder(hidden), skip], dim=1)
        return self.dec1(hidden)

    def get_encoder(self):
        """Return a Sequential of the encoder stages, a Flatten and ``fc1``.

        The returned container wraps the very same sub-modules as this model
        (no copy), so it shares their weights and train/eval state.
        """
        stages = [self.enc1, self.enc2, self.enc3, self.enc4, nn.Flatten(), self.fc1]
        return nn.Sequential(*stages)

    def apply_mask(self, x):
        """Zero out random entries of ``x``.

        Returns ``(x * mask, mask)`` where ``mask`` is a boolean tensor of
        shape (batch, input_channels, sequence_length) that is True for the
        entries that are kept.
        """
        batch_size, _channels, _length = x.shape
        noise = torch.rand(batch_size, self.input_channels, self.sequence_length, device=x.device)
        keep = noise > self.mask_ratio
        return x * keep, keep

    def _masked_reconstruction_loss(self, batch):
        """MSE between reconstruction and input, both multiplied by the keep-mask."""
        x, _ = batch
        x_masked, mask = self.apply_mask(x)
        x_hat = self.forward(x_masked)
        # Multiplying by the mask zeroes the hidden entries on both sides, so
        # only the kept entries contribute a non-zero error.
        return self.criterion(x_hat * mask, x * mask)

    def training_step(self, batch, batch_idx):
        """Compute and log the masked reconstruction loss for one training batch."""
        loss = self._masked_reconstruction_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (per epoch) the same loss for one validation batch."""
        loss = self._masked_reconstruction_loss(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam (lr=1e-3) with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )
        scheduler_config = {"scheduler": plateau, "monitor": "val_loss"}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [4]:
class ResidualBlock(nn.Module):
    """Two-convolution 1-D residual block with BatchNorm and Mish.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        downsample: when truthy the first convolution and the skip path use
            stride 2.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        if downsample:
            stride = 2
        else:
            stride = 1

        # Main path: conv -> norm -> Mish -> conv -> norm.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.norm1 = nn.BatchNorm1d(out_channels)
        self.mish = nn.Mish()

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.BatchNorm1d(out_channels)

        # Skip path: identity only when the main path changes neither the
        # length nor the channel count; otherwise a 1x1 convolution matches
        # the shapes.
        self.downsample = downsample
        shapes_already_match = (not downsample) and in_channels == out_channels
        if shapes_already_match:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        """Return ``mish(main_path(x) + skip(x))``."""
        shortcut = self.skip(x)
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.mish(hidden)
        hidden = self.conv2(hidden)
        hidden = self.norm2(hidden)
        hidden.add_(shortcut)
        return self.mish(hidden)

class ResNetClassifier(pl.LightningModule):
    """Small 1-D ResNet classifier for multi-channel sequences.

    A stride-2 stem with max-pooling is followed by two down-sampling
    ``ResidualBlock`` stages and a two-layer fully connected head with a
    residual connection.

    Args:
        input_channels: channels of the input signal.
        sequence_length: samples per channel; determines the width of the
            first fully connected layer.
        num_classes: number of output logits.
        dropout_rate: dropout probability in the fully connected head.
    """

    def __init__(self, input_channels=5, sequence_length=256, num_classes=10, dropout_rate=0.5):  # Higher dropout
        super(ResNetClassifier, self).__init__()

        # Initial layer
        self.initial_conv = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(16),
            nn.Mish(),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        )

        # Reduced Residual Blocks
        self.layer1 = ResidualBlock(16, 32, downsample=True)
        self.layer2 = ResidualBlock(32, 64, downsample=True)

        # Size of the flattened feature map: 64 channels times the length left
        # after the four stride-2 stages (stem conv, max-pool, layer1, layer2).
        # This was previously hard-coded to 1024, which is only correct for
        # sequence_length=256.
        reduced_length = self._downsampled_length(sequence_length)
        self.fc_input_size = 64 * reduced_length

        # Fully connected layers with residual connection
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.fc_input_size, 256),  # Reduced size
            nn.Mish(),
            nn.Dropout(dropout_rate)
        )

        self.fc2 = nn.Sequential(
            nn.Linear(256, 256),  # Reduced size
            nn.Mish(),
            nn.Dropout(dropout_rate)
        )

        # Output layer
        self.classifier = nn.Linear(256, num_classes)

        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _downsampled_length(sequence_length, num_stages=4):
        """Length left after ``num_stages`` stride-2 stages.

        Every stride-2 layer used here (kernel 7/padding 3, kernel 3/padding 1
        and kernel 1/padding 0) maps a length ``L`` to ``(L - 1) // 2 + 1``.
        """
        length = sequence_length
        for _ in range(num_stages):
            length = (length - 1) // 2 + 1
        return length

    def forward(self, x):
        """Return class logits for a batch shaped (batch, channels, length)."""
        x = self.initial_conv(x)
        x = self.layer1(x)
        x = self.layer2(x)

        # Residual connection in fully connected layers
        x = self.fc1(x)
        x = x + self.fc2(x)
        logits = self.classifier(x)

        return logits

    def _loss_and_accuracy(self, batch):
        """Cross-entropy loss and mean top-1 accuracy for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch training loss/accuracy; return the loss."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log per-epoch validation loss/accuracy; return the loss."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW (lr=0.01, weight_decay=1e-3) with cosine annealing, T_max=50."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.01, weight_decay=1e-3)

        # Cosine Annealing Scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

        return [optimizer], [scheduler]


In [40]:
class CNNClassifier(pl.LightningModule):
    """Three length-preserving 1-D convolutions followed by a dense head.

    Each convolution (kernel 5, stride 1, padding 2) is followed by LeakyReLU
    and BatchNorm. The last two convolutions output 3 channels, so the
    flattened feature vector has ``3 * sequence_length`` entries.

    Args:
        input_channels: channels of the input signal.
        sequence_length: samples per channel.
        num_classes: number of output logits.
        dropout_rate: dropout probability in the dense layers.
    """

    def __init__(self, input_channels=5, sequence_length=256, num_classes=10, dropout_rate=0.2):
        super().__init__()
        slope = 0.1

        # Stage 1: keeps the channel count.
        self.conv1 = nn.Conv1d(input_channels, input_channels, kernel_size=5, stride=1, padding=2)
        self.leaky1 = nn.LeakyReLU(negative_slope=slope)
        self.norm1 = nn.BatchNorm1d(input_channels)

        # Stage 2: reduces to 3 channels.
        self.conv2 = nn.Conv1d(input_channels, 3, kernel_size=5, stride=1, padding=2)
        self.leaky2 = nn.LeakyReLU(negative_slope=slope)
        self.norm2 = nn.BatchNorm1d(3)

        # Stage 3: 3 -> 3 channels.
        self.conv3 = nn.Conv1d(3, 3, kernel_size=5, stride=1, padding=2)
        self.leaky3 = nn.LeakyReLU(negative_slope=slope)
        self.norm3 = nn.BatchNorm1d(3)

        # Flattened size: 3 channels * sequence_length samples.
        flat_size = 3 * sequence_length
        hidden_size = flat_size // 4
        self.fc_input_size = flat_size

        # Dense layer 1: flatten, then keep the width.
        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_size, flat_size),
            nn.LeakyReLU(negative_slope=slope),
            nn.Dropout(dropout_rate),
        )

        # Dense layer 2: shrink to a quarter of the width.
        self.fc2 = nn.Sequential(
            nn.Linear(flat_size, hidden_size),
            nn.LeakyReLU(negative_slope=slope),
            nn.Dropout(dropout_rate),
        )

        # Output layer
        self.classifier = nn.Linear(hidden_size, num_classes)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch shaped (batch, channels, length)."""
        conv_stages = (
            (self.conv1, self.leaky1, self.norm1),
            (self.conv2, self.leaky2, self.norm2),
            (self.conv3, self.leaky3, self.norm3),
        )
        hidden = x
        for conv, activation, norm in conv_stages:
            # Normalisation is applied after the activation in every stage.
            hidden = norm(activation(conv(hidden)))

        hidden = self.fc2(self.fc1(hidden))
        return self.classifier(hidden)

    def _loss_and_accuracy(self, batch):
        """Cross-entropy loss and mean top-1 accuracy for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        hits = logits.argmax(dim=1) == y
        return loss, hits.float().mean()

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch training loss/accuracy; return the loss."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log per-epoch validation loss/accuracy; return the loss."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW (lr=1e-3, weight_decay=1e-2) with ReduceLROnPlateau on ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-2)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=2, verbose=True, min_lr=0.000001,
        )
        scheduler_config = {"scheduler": plateau, "monitor": "val_loss"}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [177]:
class ConvBlock(nn.Module):
    """``num_convs`` 3x3 convolutions with ReLU, then a stride-2 max-pool.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by every convolution in the block.
        num_convs: how many Conv2d+ReLU pairs to stack.
        pool_kernel: kernel size of the final ``MaxPool2d``.
    """

    def __init__(self, in_channels, out_channels, num_convs=2, pool_kernel=2):
        super().__init__()
        # Channel widths seen along the stack: only the first convolution
        # reads in_channels, all later ones read out_channels.
        widths = [in_channels] + [out_channels] * num_convs
        layers = []
        for width_in, width_out in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(width_in, width_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ]
        # Downsampling
        layers.append(nn.MaxPool2d(kernel_size=pool_kernel, stride=2))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.block(x)

class VGGish(pl.LightningModule):
    """VGG-style 2-D CNN classifier built from four ``ConvBlock`` stages.

    Args:
        num_classes: number of output logits.
        input_channels: channels of the input (defaults to 5, the value that
            used to be hard-coded).

    The first dense layer expects ``512 * 7 * 1`` features, so the input's
    two spatial dimensions must shrink to 7 x 1 after the four blocks. With
    each block halving both dimensions (rounding down), a 127 x 26 input does:
    127 -> 63 -> 31 -> 15 -> 7 and 26 -> 13 -> 6 -> 3 -> 1.
    """

    def __init__(self, num_classes=10, input_channels=5):
        super(VGGish, self).__init__()

        # Define VGG-like convolutional blocks using ConvBlock
        self.features = nn.Sequential(
            ConvBlock(input_channels, 64, num_convs=2),  # Block 1
            ConvBlock(64, 128, num_convs=2),              # Block 2
            ConvBlock(128, 256, num_convs=3),             # Block 3
            ConvBlock(256, 512, num_convs=3),             # Block 4
        )

        # Fully connected layers
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 7 * 1, 4096),  # 512 channels x 7 x 1 spatial positions
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),  # One logit per class
        )

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch shaped (batch, channels, H, W)."""
        x = self.features(x)
        x = self.classifier(x)
        return x

    def _loss_and_accuracy(self, batch):
        """Cross-entropy loss and mean top-1 accuracy for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        accuracy = (logits.argmax(dim=1) == y).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch training loss/accuracy; return the loss."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", accuracy, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log per-epoch validation loss/accuracy; return the loss."""
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_accuracy", accuracy, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW (lr=1e-3, weight_decay=1e-2) with ReduceLROnPlateau on ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-2)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.2,
            patience=2,
            verbose=True,
            min_lr=0.000001,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [6]:
class MNISTClassifier(pl.LightningModule):
    """Small CNN that exposes both class logits and its latent features.

    Args:
        latent_dim: width of the feature vector produced by
            ``feature_extractor``.
        num_classes: number of output logits.
    """

    def __init__(self, latent_dim=64, num_classes=10):
        super().__init__()

        # Two stride-2 convolutions; the linear layer expects 64 x 7 x 7
        # features, which a 1-channel 28 x 28 image produces (28 -> 14 -> 7).
        conv_stack = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        to_latent = [nn.Flatten(), nn.Linear(64 * 7 * 7, latent_dim)]
        self.feature_extractor = nn.Sequential(*conv_stack, *to_latent)

        self.classifier = nn.Linear(latent_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of images."""
        latent = self.feature_extractor(x)
        logits = self.classifier(latent)
        return logits, latent

    def _logits_targets_loss(self, batch):
        """Run the model on ``(x, y)`` and return ``(logits, y, loss)``."""
        x, y = batch
        logits, _ = self(x)
        return logits, y, self.criterion(logits, y)

    def training_step(self, batch, batch_idx):
        """Log and return the cross-entropy loss of one training batch."""
        _, _, loss = self._logits_targets_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and top-1 accuracy; returns nothing."""
        logits, y, loss = self._logits_targets_loss(batch)
        hits = logits.argmax(dim=1) == y
        acc = hits.float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        """Adam with a learning rate of 1e-3 and no scheduler."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [7]:
class EncodedDataset(torch.utils.data.Dataset):
    """Dataset of pre-computed encoder outputs and their labels.

    Every batch of ``dataloader`` is passed once through ``encoder`` at
    construction time; the results are concatenated and kept in memory.

    Attributes:
        encoded_datas: tensor of encoder outputs, one row per sample.
        labels: tensor of the matching labels.
    """

    def __init__(self, dataloader, encoder):
        super().__init__()
        self.encoded_datas, self.labels = self._encode_dataset(dataloader, encoder)

    def _encode_dataset(self, dataloader, encoder):
        """Encode every batch in inference mode; return ``(encodings, labels)``.

        When ``encoder`` is an ``nn.Module`` it is switched to ``eval()`` for
        the duration of the pass so Dropout is disabled and BatchNorm uses
        (and does not update) its running statistics. The train/eval flag of
        every sub-module is restored afterwards, because the encoder may share
        its layers with a model that is still being trained.
        """
        if isinstance(encoder, torch.nn.Module):
            previous_modes = [(module, module.training) for module in encoder.modules()]
            encoder.eval()
        else:
            previous_modes = []

        encoded_datas = []
        labels = []
        try:
            # No gradients are needed for a fixed feature extraction pass.
            with torch.no_grad():
                for batch in dataloader:
                    encoded_datas.append(encoder(batch[0]).detach())
                    labels.append(batch[1])
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

        return torch.cat(encoded_datas), torch.cat(labels)

    def __len__(self):
        """Number of encoded samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(encoding, label)`` with the label as a Python scalar."""
        return self.encoded_datas[idx], self.labels[idx].item()


In [8]:
class LatentDataset(torch.utils.data.Dataset):
    """Pairs EEG and MNIST samples that share the same digit label.

    Both inputs are iterated as sequences of ``(sample, label)``. For every
    label 0-9 the shorter side is padded with random repeats of its own
    samples so that each sample of the longer side gets a partner.

    Attributes:
        paired_data: list of ``(eeg_sample, mnist_sample, label)`` tuples.
    """

    def __init__(self, eeg_data, mnist_data):
        self.paired_data = self._pair_data(eeg_data, mnist_data)

    def _pair_data(self, eeg_data, mnist_data):
        """Group both inputs by label and zip them label by label."""
        eeg_by_label = {i: [] for i in range(10)}
        mnist_by_label = {i: [] for i in range(10)}

        # Group EEG data by label; labels outside 0-9 are ignored.
        for eeg, label in eeg_data:
            label = int(label)
            if 0 <= label < 10:
                eeg_by_label[label].append(eeg)

        # Group MNIST data by label, with the same range check.
        for img, label in mnist_data:
            label = int(label)
            if 0 <= label < 10:
                mnist_by_label[label].append(img)

        # Pair EEG and MNIST data
        paired_data = []
        for label, eeg_samples in eeg_by_label.items():
            mnist_samples = mnist_by_label[label]

            # A label present on one side only has nothing to be paired
            # with (and an empty list cannot be replicated), so skip it.
            if not eeg_samples or not mnist_samples:
                continue

            # Use the maximum number of samples available
            n_samples = max(len(eeg_samples), len(mnist_samples))
            eeg_samples = self._replicate_samples(eeg_samples, n_samples)
            mnist_samples = self._replicate_samples(mnist_samples, n_samples)

            paired_data.extend(
                (eeg, mnist, label) for eeg, mnist in zip(eeg_samples, mnist_samples)
            )

        return paired_data

    def _replicate_samples(self, samples, target_size):
        """Return a new list of at least ``target_size`` items.

        The list starts with ``samples`` in their original order and is padded
        with random picks from those original samples. The argument itself is
        not modified. Raises ``IndexError`` if padding is needed but
        ``samples`` is empty.
        """
        originals = list(samples)
        padded = list(originals)
        while len(padded) < target_size:
            padded.append(random.choice(originals))
        return padded

    def __len__(self):
        """Number of pairs."""
        return len(self.paired_data)

    def __getitem__(self, idx):
        """Return the ``(eeg_sample, mnist_sample, label)`` tuple at ``idx``."""
        return self.paired_data[idx]

In [9]:
class LatentProjection(pl.LightningModule):
    """MLP that maps EEG latent vectors onto MNIST latent vectors.

    Args:
        eeg_latent_dim: size of the input vectors.
        mnist_latent_dim: size of the output vectors.
        hidden_dims: widths of the hidden layers, in order. The default list
            is only read, never modified.
    """

    def __init__(self, eeg_latent_dim, mnist_latent_dim, hidden_dims=[256, 128]):
        super().__init__()

        # Layer widths from input to the last hidden layer.
        widths = [eeg_latent_dim, *hidden_dims]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]
        # Final linear map to the target latent size, without activation.
        layers.append(nn.Linear(widths[-1], mnist_latent_dim))

        self.projection = nn.Sequential(*layers)
        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Project a batch of EEG latents."""
        return self.projection(x)

    def _projection_loss(self, batch):
        """MSE between the projected EEG latents and the paired MNIST latents."""
        eeg_latent, mnist_latent, _ = batch
        return self.criterion(self(eeg_latent), mnist_latent)

    def training_step(self, batch, batch_idx):
        """Log (per step and per epoch) and return the training loss."""
        loss = self._projection_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log (per step and per epoch) and return the validation loss."""
        val_loss = self._projection_loss(batch)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Adam (lr=1e-3) with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )
        scheduler_config = {"scheduler": plateau, "monitor": "val_loss"}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [181]:
class LossTracker(Callback):
    """Lightning callback that records the epoch-level losses.

    Attributes:
        train_losses: value of the ``"train_loss"`` metric after each
            training epoch.
        val_losses: value of the ``"val_loss"`` metric after each validation
            epoch run as part of training.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Append the current ``train_loss`` from ``trainer.callback_metrics``."""
        self.train_losses.append(trainer.callback_metrics["train_loss"].item())

    def on_validation_epoch_end(self, trainer, pl_module):
        """Append the current ``val_loss``, ignoring the pre-training sanity check."""
        # The sanity-check validation pass happens before any training epoch;
        # recording it would shift val_losses by one relative to train_losses.
        if getattr(trainer, "sanity_checking", False):
            return
        self.val_losses.append(trainer.callback_metrics["val_loss"].item())

In [11]:
def plot_losses(train_losses, val_losses):
    """Plot training and validation loss per epoch and print the last values.

    Args:
        train_losses: sequence of training losses, one per epoch.
        val_losses: sequence of validation losses, one per epoch.

    Either sequence may be empty (for example after an interrupted run); the
    corresponding "final" line is then not printed.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Print final losses; indexing [-1] on an empty history would raise.
    if len(train_losses) > 0:
        print(f"Final training loss: {train_losses[-1]}")
    if len(val_losses) > 0:
        print(f"Final validation loss: {val_losses[-1]}")

In [None]:
def plot_spectrogram(spectrogram, batch_idx, channel_idx, fs=128, hop_length=10, freq_limit=(0.5, 30)):
    """
    Plots the spectrogram for a given sample and channel.

    Parameters:
    - spectrogram: Spectrogram data indexable as [sample, channel], giving a
      (freq_bins, time_steps) array. May be a torch tensor or a NumPy array.
    - batch_idx: Index of the batch/sample to plot.
    - channel_idx: Index of the channel to plot.
    - fs: Sampling rate (default is 128).
    - hop_length: Hop length used for the spectrogram (default is 10).
    - freq_limit: (low, high) frequency range to display in Hz
      (default is (0.5, 30)). The low bound is exclusive, the high inclusive.

    The frequency axis is built on the assumption that the rows span
    0 .. fs // 2 Hz linearly; a spectrogram whose low-frequency rows were
    already removed will therefore be drawn with shifted frequency labels.
    """
    # Extract the spectrogram for the selected sample and channel
    sample_spectrogram = spectrogram[batch_idx, channel_idx]  # Shape: (freq_bins, time_steps)
    if hasattr(sample_spectrogram, "detach"):
        # torch tensor: drop autograd tracking and move to host memory first
        sample_spectrogram = sample_spectrogram.detach().cpu().numpy()
    else:
        sample_spectrogram = np.asarray(sample_spectrogram)

    # Generate frequency and time axes to match the shape of `sample_spectrogram`
    freqs = np.linspace(0, fs // 2, sample_spectrogram.shape[0])  # Frequency axis
    times = np.arange(sample_spectrogram.shape[1]) * hop_length / fs  # Time axis

    # Keep only the rows whose frequency lies inside freq_limit
    freq_mask = (freqs > freq_limit[0]) & (freqs <= freq_limit[1])
    filtered_spectrogram = sample_spectrogram[freq_mask, :]
    filtered_freqs = freqs[freq_mask]

    # Plot the spectrogram for visualization
    plt.figure(figsize=(10, 5))
    plt.pcolormesh(times, filtered_freqs, filtered_spectrogram, shading='gouraud')
    plt.colorbar(label='Magnitude')
    plt.xlabel('Time [sec]')
    plt.ylabel('Frequency [Hz]')
    plt.title(f'Spectrogram - Sample {batch_idx}, Channel {channel_idx} ({freq_limit[0]} Hz and above)')
    plt.ylim(freq_limit)  # Limit the y axis to the requested frequency range
    plt.show()

In [167]:
def create_dataloaders(dataset, batch_size=128, train_split=0.8, seed=42):
    """Split ``dataset`` at random and wrap both parts in DataLoaders.

    Args:
        dataset: any sized, indexable dataset.
        batch_size: batch size of both loaders.
        train_split: fraction of the samples assigned to the training part
            (rounded down); the rest goes to validation.
        seed: seed of the generator that draws the split, so the same seed
            gives the same split.

    Returns:
        ``(train_dataloader, val_dataloader)``; only the training loader
        shuffles.
    """
    n_total = len(dataset)
    n_train = int(train_split * n_total)
    split_lengths = [n_train, n_total - n_train]

    # A dedicated generator keeps the split independent of the global RNG.
    split_rng = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(dataset, split_lengths, generator=split_rng)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

In [4]:
# Load the "train" split of the MindBigData2022_MNIST_IN dataset through
# `datasets.load_dataset`; both EEG dataset wrappers below are built from it.
train_dataset = load_dataset("DavidVivancos/MindBigData2022_MNIST_IN", split="train")

## Raw EEG

In [168]:
# Wrap the raw signals (normalised per electrode) and split them 80/20 into
# training and validation loaders with a fixed seed.
full_eeg_dataset = EEG_Dataset(train_dataset)
eeg_train_dataloader, eeg_val_dataloader = create_dataloaders(full_eeg_dataset, batch_size=128, train_split=0.8, seed=42)

In [None]:
# Train the 1-D CNN classifier on the raw EEG loaders for up to 50 epochs,
# recording the epoch losses in `loss_tracker`.
# NOTE: `loss_tracker` and `trainer` are re-bound by later cells.
loss_tracker = LossTracker()
cnn_classifier = CNNClassifier(input_channels=5, sequence_length=256, dropout_rate=.2)
trainer = Trainer(max_epochs=50, callbacks=[loss_tracker])
trainer.fit(cnn_classifier, eeg_train_dataloader, eeg_val_dataloader)

## Spectrogram EEG

In [179]:
# Pre-compute the normalised spectrograms for the whole dataset, then split
# them 80/20 with the same seed as the raw-EEG loaders.
full_eeg_spectogram_dataset = EEG_Spectogram_Dataset(train_dataset)
eeg_spectogram_train_dataloader, eeg_spectogram_val_dataloader = create_dataloaders(full_eeg_spectogram_dataset, batch_size=128, train_split=0.8, seed=42)

In [182]:
# Train the VGG-style 2-D CNN on the spectrogram loaders for up to 50 epochs.
# A fresh LossTracker replaces the one used for the previous model.
loss_tracker = LossTracker()
vggish = VGGish()
trainer = Trainer(max_epochs=50, callbacks=[loss_tracker])
trainer.fit(vggish, eeg_spectogram_train_dataloader, eeg_spectogram_val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/camilziane/My-projects/EEG/eeg-mnist/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name       | Type             | Params
------------------------------------------------
0 | features   | Sequential       | 7.6 M 
1 | classifier | Sequential       | 31.5 M
2 | criterion  | CrossEntropyLoss | 0     
------------------------------------------------
39.1 M    Trainable

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/camilziane/My-projects/EEG/eeg-mnist/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/camilziane/My-projects/EEG/eeg-mnist/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:   3%|▎         | 2/66 [00:22<12:12,  0.09it/s, v_num=31, train_loss_step=2.300, train_accuracy_step=0.102] 

/Users/camilziane/My-projects/EEG/eeg-mnist/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


## UNET

In [None]:
# Alternative to the training below: restore a previously saved autoencoder, e.g.
# checkpoint_path = "lightning_logs/version_16/checkpoints/epoch=9-step=5220.ckpt"
# autoencoder = UNetAutoencoder.load_from_checkpoint(checkpoint_path)
# NOTE(review): this snippet used to name `CNNAutoencoder`, which is not defined
# in this notebook - confirm the checkpoint was produced by UNetAutoencoder.
latent_dim = 64  # bottleneck width; reused further down for LatentProjection
loss_tracker = LossTracker()
autoencoder = UNetAutoencoder(latent_dim=latent_dim)
# Train on the raw (non-spectrogram) EEG loaders for up to 10 epochs.
trainer = Trainer(max_epochs=10, callbacks=[loss_tracker])
trainer.fit(autoencoder, eeg_train_dataloader, eeg_val_dataloader)

## MNIST

In [328]:
# MNIST training images as tensors, downloaded to ./data on first use.
transform = transforms.Compose([transforms.ToTensor()])
mnist_full = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
# Seed the 55000/5000 split so it is reproducible across kernel restarts,
# as create_dataloaders already does for the EEG data.
mnist_train_dataset, mnist_val_dataset = random_split(
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=64, shuffle=True)
mnist_val_loader = DataLoader(mnist_val_dataset, batch_size=64)

In [329]:
# Reuse an already trained MNIST classifier instead of retraining it.
# NOTE(review): the relative path only resolves where that run's
# lightning_logs directory exists; on a fresh checkout use the training
# lines kept below instead.
checkpoint_path = "lightning_logs/version_14/checkpoints/epoch=9-step=8600.ckpt"
mnist_classifier = MNISTClassifier.load_from_checkpoint(checkpoint_path)
# To train from scratch instead:
# mnist_classifier = MNISTClassifier(latent_dim=latent_dim)
# mnist_trainer = pl.Trainer(max_epochs=3)
# mnist_trainer.fit(mnist_classifier, mnist_train_loader, mnist_val_loader)

## Model Alignment

In [355]:
# Encoders that turn a raw sample into a latent vector: the autoencoder's
# encoder stack up to fc1, and the MNIST classifier's feature extractor.
eeg_encoder = autoencoder.get_encoder()
mnist_encoder = mnist_classifier.feature_extractor

In [356]:
# Encode every EEG batch once and keep the latents in memory.
eeg_train_encoded_dataset = EncodedDataset(eeg_train_dataloader, eeg_encoder)
eeg_val_encoded_dataset = EncodedDataset(eeg_val_dataloader, eeg_encoder)

In [357]:
# Same for MNIST: latents of the training and validation images.
mnist_train_encoded_data = EncodedDataset(mnist_train_loader, mnist_encoder)
mnist_val_encoded_data = EncodedDataset(mnist_val_loader, mnist_encoder)

In [358]:
# Pair EEG and MNIST latents that share a digit label, separately for the
# training and the validation split.
latent_train_dataset = LatentDataset(
    eeg_train_encoded_dataset, mnist_train_encoded_data
)
latent_train_dataloader = DataLoader(latent_train_dataset, batch_size=32, shuffle=True)

# Validation pairs must use the held-out MNIST latents; pairing them with the
# training latents would leak training targets into the validation loss.
latent_val_dataset = LatentDataset(eeg_val_encoded_dataset, mnist_val_encoded_data)
latent_val_dataloader = DataLoader(latent_val_dataset, batch_size=32, shuffle=False)

In [359]:
# Train the EEG-latent -> MNIST-latent projection for up to 10 epochs.
loss_tracker = LossTracker()
latent_projection = LatentProjection(latent_dim,latent_dim)
# The tracker has to be registered as a callback, otherwise it records nothing.
latent_trainer = Trainer(max_epochs=10, callbacks=[loss_tracker])
latent_trainer.fit(latent_projection, latent_train_dataloader, latent_val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name       | Type       | Params
------------------------------------------
0 | projection | Sequential | 58.6 K
1 | criterion  | MSELoss    | 0     
------------------------------------------
58.6 K    Trainable params
0         Non-trainable params
58.6 K    Total params
0.234     Total estimated model params size (MB)


Epoch 3:  63%|██████▎   | 1090/1719 [00:11<00:06, 96.98it/s, v_num=2, train_loss_step=11.70, val_loss_step=8.510, val_loss_epoch=11.00, train_loss_epoch=11.00]