# Local File

In [None]:
!pip install pytorch_lightning --quiet
!pip install wandb --quiet
!pip install pandas --quiet
!pip install numpy --quiet
!pip install scikit-learn --quiet
!pip install matplotlib --quiet
!pip install seaborn --quiet

In [1]:
# Import
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, random_split
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import wandb
import pprint
import os
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import random
import torch.nn.functional as F

In [2]:
# Reproducibility
generator = torch.Generator()
generator.manual_seed(0)

<torch._C.Generator at 0x18f39d3f0>

In [3]:
# Constants
file_1_dataset  = "/content/mydrive/MyDrive/ColabNotebooks/EAI_Napoli/Dataset/Copia di Copy of fused_transposed_GG1.csv"
dataset_path_list = [
    'Copia di S1E1.csv', # Task 4 --> 0
    'Copia di S1E2.csv', # Baseline 1 --> 1
    'Copia di S1E3.csv', # Baseline 2 --> 2
    'Copia di S1E4.csv', # Task 1 --> 3
    'Copia di S2E1.csv', # Baseline 1 --> 4
    'Copia di S2E2.csv', # Baseline 2 --> 5
    'Copia di S2E3.csv', # Task 2 --> 6
    'Copia di S2E4.csv', # Task 1 --> 7
    'Copia di S3E1.csv', # Baseline 1 --> 8
    'Copia di S3E2.csv', # Task 1 --> 9
    'Copia di S3E3.csv', # Task 3 --> 10
    'Copia di S3E4.csv', # Task 2 --> 11
    'Copia di S4E1.csv', # Task 1 --> 12
    'Copia di S4E2.csv', # Task 1 --> 13
    'Copia di S4E3.csv', # Task 3  --> 14
    'Copia di S4E4.csv', # Task 3 --> 15
    'Copia di S5E1.csv', # Task 1 --> 16
    'Copia di S5E2.csv', # Baseline 1 --> 17
    'Copia di S5E3.csv', # Task 1 --> 18
    'Copia di S5E4.csv', # Task 4 --> 19
    'Copia di S6E1.csv', # Task 1 --> 20
    'Copia di S6E2.csv', # Task 3 --> 21
    'Copia di S6E3.csv', # Baseline 2 --> 22
    'Copia di S6E4.csv', # Task 3 --> 23
    'Copia di S7E1.csv', # Task 2 --> 24
    'Copia di S7E2.csv', # Task 1 --> 25
    'Copia di S7E3.csv', # Task 4 --> 26
    'Copia di S7E4.csv', # Task 1 --> 27
]

batch_size = 32
window_size = 1 # size of the window to consider when selecting a sample, then a sample will be composed of window_size rows
enable_wandb = False

dataset_base_path = "Dataset"
project_base_path = "/"


# Definition of an utity object that divides the dataset files according to the task they belong to
dataset_task_mapping = {}
dataset_task_mapping['task_1'] = []

# Task 1
task_1_file_index = [3, 7, 9, 12, 13, 16, 18, 20, 25, 27]
for index in task_1_file_index:
    dict_file = {}
    dict_file['file_path'] = os.path.join(dataset_base_path, dataset_path_list[index])
    #file_name
    dict_file['file_name'] = dataset_path_list[index].split("/")[-1]
    #remove the extension
    dict_file['file_name'] = dict_file['file_name'].split(".")[0]
    dataset_task_mapping['task_1'].append(dict_file)

# Task 2
dataset_task_mapping['task_2'] = []
task_2_file_index = [6, 11, 24]
for index in task_2_file_index:
    dict_file = {}
    dict_file['file_path'] = os.path.join(dataset_base_path, dataset_path_list[index])
    dict_file['file_name'] = dataset_path_list[index].split("/")[-1]
    dict_file['file_name'] = dict_file['file_name'].split(".")[0]
    dataset_task_mapping['task_2'].append(dict_file)

# Task 3
dataset_task_mapping['task_3'] = []
task_3_file_index = [10, 14, 21, 23, 26]
for index in task_3_file_index:
    dict_file = {}
    dict_file['file_path'] = os.path.join(dataset_base_path, dataset_path_list[index])
    dict_file['file_name'] = dataset_path_list[index].split("/")[-1]
    dict_file['file_name'] = dict_file['file_name'].split(".")[0]
    dataset_task_mapping['task_3'].append(dict_file)

# Task 4
dataset_task_mapping['task_4'] = []
task_4_file_index = [0, 19]
for index in task_4_file_index:
    dict_file = {}
    dict_file['file_path'] = os.path.join(dataset_base_path, dataset_path_list[index])
    dict_file['file_name'] = dataset_path_list[index].split("/")[-1]
    dict_file['file_name'] = dict_file['file_name'].split(".")[0]
    dataset_task_mapping['task_4'].append(dict_file)

# Baseline 1
dataset_task_mapping['baseline_1'] = []
baseline_1_file_index = [1, 4, 8, 17]
for index in baseline_1_file_index:
    dict_file = {}
    dict_file['file_path'] = os.path.join(dataset_base_path, dataset_path_list[index])
    dict_file['file_name'] = dataset_path_list[index].split("/")[-1]
    dict_file['file_name'] = dict_file['file_name'].split(".")[0]
    dataset_task_mapping['baseline_1'].append(dict_file)

# Baseline 2
dataset_task_mapping['baseline_2'] = []
baseline_2_file_index = [2, 5, 22]
for index in baseline_2_file_index:
    dict_file = {}
    dict_file['file_path'] = os.path.join(dataset_base_path, dataset_path_list[index])
    dict_file['file_name'] = dataset_path_list[index].split("/")[-1]
    dict_file['file_name'] = dict_file['file_name'].split(".")[0]
    dataset_task_mapping['baseline_2'].append(dict_file)

# Approach 1

## Dataset

In [None]:
class DatasetApproach1(Dataset):
  def __init__(self, task_name, window_size=1):
    assert task_name in ['task_1','task_2','task_3','task_4'] , "The task_name is not valid (must be one of ['task_1','task_2','task_3','task_4'])"

    self.task_name = task_name
    self.dataset = []

    # Create a unique dataframe that is composed by the concatenation of all the files that belong to the task + the baselines files
    for file in dataset_task_mapping[task_name]:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Baseline 1
    for file in dataset_task_mapping['baseline_1']:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Baseline 2
    for file in dataset_task_mapping['baseline_2']:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Concatenate the dataframes
    print(f"Concatenating the dataframes ({len(self.dataset)})")
    self.dataset = pd.concat(self.dataset)
    # Create a dataframe
    print(f"Dataset shape: {self.dataset.shape}")


    # Windowing
    self.window_size = window_size

  def __len__(self):
    return len(self.dataset) - self.window_size

  def __getitem__(self, idx):
    # return as a tensor
    print(f"Index: {idx}")
    return torch.tensor(self.dataset.iloc[idx].values)

  def get_dataframe(self):
    return self.dataset

## Task 1

In [None]:
# Create the dataset
selected_task = 'task_1'
dataset = DatasetApproach1(selected_task, window_size=window_size)
print(f"Dataset length: {len(dataset)}, Number of files used ({len(dataset_task_mapping[selected_task])} + {len(dataset_task_mapping['baseline_1'])} + {len(dataset_task_mapping['baseline_2'])})")

## Task 2

In [None]:
# Create the dataset
selected_task = 'task_2'
dataset = DatasetApproach1(selected_task, window_size=window_size)
print(f"Dataset length: {len(dataset)}, Number of files used ({len(dataset_task_mapping[selected_task])} + {len(dataset_task_mapping['baseline_1'])} + {len(dataset_task_mapping['baseline_2'])})")

## Task 3

In [None]:
# Create the dataset
selected_task = 'task_3'
dataset = DatasetApproach1(selected_task, window_size=window_size)
print(f"Dataset length: {len(dataset)}, Number of files used ({len(dataset_task_mapping[selected_task])} + {len(dataset_task_mapping['baseline_1'])} + {len(dataset_task_mapping['baseline_2'])})")

## Task 4

In [None]:
# Create the dataset
selected_task = 'task_4'
dataset = DatasetApproach1(selected_task, window_size=window_size)
print(f"Dataset length: {len(dataset)}, Number of files used ({len(dataset_task_mapping[selected_task])} + {len(dataset_task_mapping['baseline_1'])} + {len(dataset_task_mapping['baseline_2'])})")

## Dataframe 

In [None]:
# check unique values of column "labels"
print(f"Unique values of the labels: {dataset.get_dataframe()['labels'].unique()}")

# Shuffle the rows of the dataset using sklearn (making sure the shuffle is reproducible)
from sklearn.utils import shuffle
data = shuffle(dataset.get_dataframe(), random_state=0)
# Remove the index column
data = data.reset_index(drop=True)

# Splitting into train and test sets (80% training data, 20% testing data)
train_df, test_df = train_test_split(data, test_size=0.15, random_state=42)

# Splitting the train_df further into train and validation sets (70% training data, 30% validation data)
train_df, val_df = train_test_split(train_df, test_size=0.15, random_state=42)

print(f"Data: {len(data)} ,Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")

In [None]:
#Create the Dataframe classe Approach 1
class DataFrameApproach1(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.iloc[:, :-1].values
        self.targets = dataframe['labels'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx])
        y = self.targets[idx]
        return x, y

def collate_fn(batch):
    data = [item[0] for item in batch]
    targets = [item[1] for item in batch]

    # Apply min-max normalization to each column
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data)

    return torch.tensor(normalized_data), targets

# Creating datasets and data loaders for each split
train_dataset = DataFrameApproach1(train_df)
val_dataset = DataFrameApproach1(val_df)
test_dataset = DataFrameApproach1(test_df)

# NOTE -> Dataloader are created in the HPO sweeps

## AutoEncoder Model

### Encoder

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_dim, window_size = 1, enable_sparsity_loss=False):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(128),
        )

        # Apply He initialization to the linear layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        batch_size, input_dim = x.size()  # Obtain the shape of the input [bs, input_dim]
        input = x
        x = self.encoder(input)
        return x

### Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, input_dim, window_size, enable_sparsity_loss=False):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, window_size * input_dim),
            nn.Sigmoid()
        )

        # Apply He initialization to the linear layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        # x = x.to(torch.float32)
        x = self.decoder(x)
        return x

### AE Lightning Module ⚡️⚡️

In [None]:
class Autoencoder(LightningModule):
    def __init__(self, input_dim, batch_size, sparsity_factor=0.1, sparsity_loss_coef = 1e-3, weight_decay=0.001, window_size=window_size, enable_sparsity_loss=False, enable_weight_decay_loss=False ,enable_non_negativity_constraint=False,enable_wandb = False):
        super(Autoencoder, self).__init__()

        if( enable_sparsity_loss == True and enable_non_negativity_constraint== True):
          print("The combination of constraints enable_sparsity_loss and enable_non_negativity_constraint both true leads to error in to the model matrix multiplication. This will be solved by setting enable_non_negativity_constraint to False.")

        self.save_hyperparameters()
        self.encoder = Encoder(input_dim=input_dim, window_size=window_size, enable_sparsity_loss = enable_sparsity_loss)
        self.decoder = Decoder(input_dim=input_dim, window_size=window_size, enable_sparsity_loss = enable_sparsity_loss)
        self.train_loss_memory = []
        self.train_rec_loss_memory = []

        self.val_loss_memory = []
        self.val_rec_loss_memory = []

        self.test_loss_memory = []
        self.test_rec_loss_memory = []

        self.batch_size = batch_size



        # --- Loss Settings
        self.enable_sparsity_loss = enable_sparsity_loss
        if enable_sparsity_loss:
          self.sparsity_loss_coef = sparsity_loss_coef
          self.sparsity_factor = sparsity_factor
          print(f"Enabled Sparsity term in the loss with sparsity loss coeff => {self.sparsity_loss_coef} and sparsity factor=>{self.sparsity_factor}")

          # self.sparsity_loss = nn.KLDivLoss(reduction='batchmean')
          # Memory logs for sparsity
          self.train_sparsity_loss_memory = []
          self.val_sparsity_loss_memory = []
          self.test_sparsity_loss_memory = []

          self.enable_non_negativity_constraint = False
        else:
          self.enable_non_negativity_constraint = enable_non_negativity_constraint
          if enable_non_negativity_constraint:
            print("Enabled non negativity constraint")


        self.enable_weight_decay_loss = enable_weight_decay_loss
        if enable_weight_decay_loss:
          print("Enabled weight decay")
          self.weight_decay = weight_decay

        self.wandb_log = enable_wandb

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                device = torch.device('cuda:0')
                print('Using device:', device)
            else:
                device = torch.device('cuda')
                print('Using device:', device)
        else:
            device = torch.device('cpu')
            print('Using device:', device)


        print('Using device:', device)

        self.to(device)
        print(f"Initialized Model on {self.device}")

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

    def kl_div(self, p, p_hat):
      funcs = nn.Sigmoid()
      p_hat = torch.mean(funcs(p_hat), 1)
      p_tensor = torch.Tensor([p] * p_hat.shape[0]).to(self.device)


      return torch.sum(p_tensor * torch.log(p_tensor) - p_tensor * torch.log(p_hat) + (1 - p_tensor) * torch.log(1 - p_tensor) - (1 - p_tensor) * torch.log(1 - p_hat))

    def sparse_loss(self, values):
      loss = 0
      values = values.view(self.batch_size, -1)

      # Encoder sparsity
      lyrs_encoder = list(self.encoder.encoder.children())
      for i, lyr in enumerate(lyrs_encoder):
          if isinstance(lyr, nn.Linear):
            values = lyr(values)
            # loss += self.sparsity_loss(torch.tensor([self.sparsity_factor]).to(self.device), values.to(self.device))
            loss += self.kl_div(self.sparsity_factor, values.to(self.device))

      # Decoder sparsity
      lyrs_decoder = list(self.decoder.decoder.children())
      for i, lyr in enumerate(lyrs_decoder):
          if isinstance(lyr, nn.Linear):
              values = lyr(values)
              # loss += self.sparsity_loss(torch.tensor([self.sparsity_factor]).to(self.device), values.to(self.device))
              loss += self.kl_div(self.sparsity_factor, values.to(self.device))

      return loss

    def calculate_weight_decay_loss(self):
        weight_decay_loss = 0.0
        for param in self.parameters():
            weight_decay_loss += 0.5 * self.weight_decay * torch.norm(param, p=2) ** 2
        return weight_decay_loss

    def enforce_non_negativity(self):
      for param in self.parameters():
        param.data.clamp_(min=0, max=None)

    def training_step(self, batch, batch_idx):
        x = batch[0].to(torch.float32) #[bs, input_dim]
        _, reconstructions = self(x)

        x = x.view(-1) # [bs * input_dim]
        reconstructions = reconstructions.view(-1)

        loss_mse = nn.MSELoss()(reconstructions, x)
        loss = loss_mse

        if self.enable_sparsity_loss:
          # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
          sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
          loss += sparsity_loss
          self.train_sparsity_loss_memory.append(sparsity_loss)

        if self.enable_weight_decay_loss:
          weight_decay_loss = self.calculate_weight_decay_loss()
          loss += weight_decay_loss

        self.train_loss_memory.append(loss)
        self.train_rec_loss_memory.append(loss_mse)

        if self.wandb_log:
          wandb.log({"train_total_loss": loss})
          wandb.log({"train_reconstruction_loss": loss_mse})

        return loss

    def validation_step(self, batch, batch_idx):
      x = batch[0].to(torch.float32)
      _, reconstructions = self(x)

      x = x.view(-1) #[]
      reconstructions = reconstructions.view(-1)

      loss_mse = nn.MSELoss()(reconstructions, x)
      loss = loss_mse

      if self.enable_sparsity_loss:
        # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
        sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
        loss += sparsity_loss
        self.val_sparsity_loss_memory.append(sparsity_loss)

      if self.enable_weight_decay_loss:
        weight_decay_loss = self.calculate_weight_decay_loss()
        loss += weight_decay_loss

      if self.enable_non_negativity_constraint:
        self.enforce_non_negativity()

      self.val_loss_memory.append(loss)
      self.val_rec_loss_memory.append(loss_mse)

      if self.wandb_log:
        wandb.log({"val_total_loss": loss})
        wandb.log({"val_reconstruction_loss": loss_mse})


      return loss

    def test_step(self, batch, batch_idx):
      x = batch[0].to(torch.float32)
      _, reconstructions = self(x)

      x = x.view(-1) #[]
      reconstructions = reconstructions.view(-1)

      loss_mse = nn.MSELoss()(reconstructions, x)
      loss = loss_mse

      if self.enable_sparsity_loss:
        # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
        sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
        loss += sparsity_loss
        self.test_sparsity_loss_memory.append(sparsity_loss)

      if self.enable_weight_decay_loss:
        weight_decay_loss = self.calculate_weight_decay_loss()
        loss += weight_decay_loss

      if self.enable_non_negativity_constraint:
        self.enforce_non_negativity()

      self.test_loss_memory.append(loss)
      self.test_rec_loss_memory.append(loss_mse)


      return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = CosineAnnealingLR(optimizer, T_max=10)  # Adjust T_max as needed

        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch'}}

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        # step
        optimizer.step(closure=optimizer_closure)

        if self.enable_non_negativity_constraint:
          self.enforce_non_negativity()

    def on_epoch_end(self):
        if self.wandb_log:
            wandb.log({'epoch': self.current_epoch})

    def on_train_epoch_end(self):
        # Access the training loss from the outputs
        train_loss = torch.stack([x for x in self.train_loss_memory]).mean()
        train_rec_loss = torch.stack([x for x in self.train_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Training Loss - Epoch {self.current_epoch}: Total Loss => {train_loss.item()} MSE => {train_rec_loss}'

        self.train_loss_memory.clear()
        self.train_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          train_sparsity_loss = torch.stack([x for x in self.train_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {train_sparsity_loss}'
          self.train_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"train_total_loss": train_loss})
          wandb.log({"train_reconstruction_loss": train_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"train_sparse_loss": train_sparsity_loss})

        print(print_log)

    def on_validation_epoch_end(self):
        # Access the training loss from the outputs
        val_loss = torch.stack([x for x in self.val_loss_memory]).mean()
        val_rec_loss = torch.stack([x for x in self.val_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Validation Loss - Epoch {self.current_epoch}: Total Loss => {val_loss.item()} MSE => {val_rec_loss}'

        # For early stop and Model checkpoint callbacks
        self.log('val_reconstruction_loss', val_rec_loss.item())


        self.val_loss_memory.clear()
        self.val_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          val_sparsity_loss = torch.stack([x for x in self.val_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {val_sparsity_loss}'
          self.val_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"val_total_loss": val_loss})
          wandb.log({"val_reconstruction_loss": val_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"val_sparse_loss": val_sparsity_loss})


        print(print_log)

    def on_test_epoch_end(self):
        # Access the training loss from the outputs
        test_loss = torch.stack([x for x in self.test_loss_memory]).mean()
        test_rec_loss = torch.stack([x for x in self.test_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Test Loss - Epoch {self.current_epoch}: Total Loss => {test_loss.item()} MSE => {test_rec_loss}'

        self.test_loss_memory.clear()
        self.test_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          test_sparsity_loss = torch.stack([x for x in self.test_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {test_sparsity_loss}'
          self.test_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"test_total_loss": test_loss})
          wandb.log({"test_reconstruction_loss": test_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"test_sparse_loss": test_sparsity_loss})

        self.test_rec_loss = test_rec_loss

        print(print_log)

## TRAINING: Hyper Parameter Optimization with Weights and Biases Sweeps 🔎🔎

In [None]:
# WANDB Sweep for HPO
sweep_config = {
    'method': 'bayes'
}
metric = {
  'name': 'val_reconstruction_loss',
  'goal': 'minimize'
}
sweep_config['metric'] = metric
parameters_dict = {
    'batch_size': {
          'values': [64]
        },
    'epochs': {
          'values': [1000]
        },
    'sparsity_factor': {
        'values': [0.1, 0.05, 0.005]
      },
    'wdecay_loss':{
        'values': [True,False]
      },
    'sparsity_loss':{
        'values': [True,False]
      },
    'non_negative_constraint':{
        'values': [True,False]
      }
}

sweep_config['parameters'] = parameters_dict

#Create the sweep
sweep_id = wandb.sweep(sweep_config,entity="rucci-2053183", project="Project_EAI_BrainComputerInterface")

In [None]:
def train(config=None):
  i=0
  with wandb.init(config=config):
    i = i + 1
    config = wandb.config
    if config.sparsity_loss == True and config.non_negative_constraint == True:
      print(f"Skipping following config becouse not supported combination sparsity_loss =>{config.sparsity_loss}, non_negative_constraint =>{config.non_negative_constraint}")
      print(f"Config ==>{config}")
    else:
      # bs given by the agent
      train_dataset = DataFrameApproach1(train_df)
      val_dataset = DataFrameApproach1(val_df)
      test_dataset = DataFrameApproach1(test_df)
      train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
      val_loader = DataLoader(val_dataset, batch_size=config.batch_size,shuffle=False, collate_fn=collate_fn, drop_last=True)
      test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn, drop_last=True)

      batch = next(iter(train_loader))
      input_dim = batch[0].shape[-1]

      # Model
      model = Autoencoder(input_dim=input_dim, batch_size = config.batch_size,sparsity_factor=config.sparsity_factor ,enable_sparsity_loss=config.sparsity_loss, enable_weight_decay_loss=config.wdecay_loss, enable_non_negativity_constraint=config.non_negative_constraint, enable_wandb = True)
      early_stop = EarlyStopping(monitor="val_reconstruction_loss", mode="min", patience=30, min_delta=0.001)


      # Define the ModelCheckpoint callback to save the best model
      checkpoint_callback = ModelCheckpoint(
          dirpath="saved_models/Approach_1/"+selected_task+"/",
          filename="best_model",
          monitor="val_reconstruction_loss",
          mode="min",
          save_top_k=1,
          save_last=True
      )

      trainer = Trainer(max_epochs=config.epochs, default_root_dir="saved_models/Approach_1/", callbacks=[early_stop, checkpoint_callback],fast_dev_run=False)
      trainer.fit(model, train_loader, val_loader)
      trainer.test(model, test_loader)


wandb.agent(sweep_id, train, count=10)

## Evaluation of the BEST AE on test set

In [None]:
batch_size = 64
train_dataset = DataFrameApproach1(train_df)
val_dataset = DataFrameApproach1(val_df)
test_dataset = DataFrameApproach1(test_df)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False, collate_fn=collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=True)

In [None]:
batch= next(iter(test_loader))

In [None]:
base_model_dir = "saved_models/Approach_1/"+selected_task+"/lightning_logs"

model_to_test_paths = [ ]

# Add to the list the models .ckpt from the directory /saved_models/Approach_1/selected_task/
for root, dirs, files in os.walk("saved_models/Approach_1/"+selected_task):
  for file in files:
    if file.endswith(".ckpt"):
      model_to_test_paths.append(os.path.join(root, file))

print(f"Models to test => {model_to_test_paths}")

best_metric = 1000000000
best_model = ""
for model_path_ in model_to_test_paths:
  model_path = model_path_
  input_dim = batch[0].shape[-1]
  version = model_path_.split('/')[-1]

  checkpoint_model = Autoencoder(input_dim=input_dim, batch_size = batch_size,sparsity_factor=0.005,enable_sparsity_loss=False, enable_weight_decay_loss=False, enable_non_negativity_constraint=False, enable_wandb = False)

  checkpoint_model.load_state_dict(torch.load(model_path, map_location=checkpoint_model.device)['state_dict']) # ------> PyTorch Lightning API

  trainer = Trainer(accelerator = 'auto', fast_dev_run=False)
  print(f"Evaluation => {version}")
  trainer.test(checkpoint_model, dataloaders=test_loader)

  if(checkpoint_model.test_rec_loss < best_metric):
    best_metric = checkpoint_model.test_rec_loss
    best_model = version

In [None]:
print(f"BEST MODEL => FILE = {best_model}, MSE = {best_metric}")

## Classifier Model
This is a model that make use of the z vector extracted from the autoencoder: z goes into an MLP and we discriminate between the classes.

This is done for each task to study how a specific model performs over a specific tasks.

The process of training the MLP allowa to backpropagate until the encoder of the AE and basically finetuning it for the classification task.

In [None]:
# First create a mapping utility to go from label to idx and vice versa
label2idx= {}
idx2label = {}
labels_task = dataset.get_dataframe()['labels'].unique()

for i in range(len(labels_task)):
  label2idx[labels_task[i]] = i
  idx2label[str(i)] = labels_task[i]

In [None]:
class ClassifierPerTask_Approach1(LightningModule):
    def __init__(self, encoder, text_labels, task_name, enable_wandb=False):
        super(ClassifierPerTask_Approach1, self).__init__()
        self.save_hyperparameters()

        self.encoder = encoder
        self.task_name = task_name
        self.text_labels = text_labels

        # HEAD 3
        self.classifier = nn.Sequential(
            nn.Linear(encoder.z_dim, 256),
            nn.BatchNorm1d(256),  # Batch normalization
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),  # Batch normalization
              nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, len(text_labels))
        )

        self.train_loss = []
        self.train_accuracy = []
        self.val_loss = []
        self.val_accuracy = []
        self.test_loss = []
        self.test_accuracy = []

        self.enable_wandb = enable_wandb

        if self.enable_wandb:
          wandb.init(project="Project_EAI_BrainComputerInterface", entity="rucci-2053183", group="approach1_classifier_"+task_name)

    def forward(self, z):
        return self.classifier(z)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(torch.float32)
        z = self.encoder(inputs)
        outputs = self(z)
        labels = self.labels2TargetTensor(labels).to(torch.long)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log('train_loss', loss)
        self.log('test_accuracy', acc)

        self.train_loss.append(loss)
        self.train_accuracy.append(acc)

        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(torch.float32)
        z = self.encoder(inputs)
        outputs = self(z)
        labels = self.labels2TargetTensor(labels).to(torch.long)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

        self.test_loss.append(loss)
        self.test_accuracy.append(acc)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(torch.float32)
        z = self.encoder(inputs)
        outputs = self(z)
        labels = self.labels2TargetTensor(labels).to(torch.long)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        self.val_loss.append(loss)
        self.val_accuracy.append(acc)

        return loss

    def configure_optimizers(self):
      optimizer = optim.Adam(self.parameters(), lr=0.001)
      scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
      return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def labels2TargetTensor(self, labels):
      target = []
      for item in labels:
        target.append(label2idx[item])

      return torch.Tensor(target)

    def on_train_epoch_end(self):
        train_loss = torch.stack([x for x in self.train_loss]).mean()
        train_acc = torch.stack([x for x in self.train_accuracy]).mean()

        # Print the training loss
        print_log = f'Training - Epoch {self.current_epoch}: Loss => {train_loss.item()} ACCURACY => {train_acc}'

        self.train_loss.clear()
        self.train_accuracy.clear()

        if self.enable_wandb:
            # Log mean training loss
            wandb.log({"epoch_train_loss": train_loss, "epoch_train_accuracy": train_acc})

        print(print_log)

    def on_test_epoch_end(self):
        test_loss = torch.stack([x for x in self.test_loss]).mean()
        test_acc = torch.stack([x for x in self.test_accuracy]).mean()

        # Print the training loss
        print_log = f'Test - Epoch {self.current_epoch}: Loss => {test_loss.item()} ACCURACY => {test_acc}'

        self.test_loss.clear()
        self.test_accuracy.clear()

        if self.enable_wandb:
            # Log mean test loss and accuracy
            wandb.log({"epoch_test_loss": test_loss, "epoch_test_accuracy": test_acc})

        print(print_log)

        self.test_acc = test_acc

    def on_validation_epoch_end(self):
        val_loss = torch.stack([x for x in self.val_loss]).mean()
        val_acc = torch.stack([x for x in self.val_accuracy]).mean()

        # Print the training loss
        print_log = f'Validation - Epoch {self.current_epoch}: Loss => {val_loss.item()} ACCURACY => {val_acc}'

        self.val_loss.clear()
        self.val_accuracy.clear()
        self.log("epoch_val_accuracy", val_acc)
        if self.enable_wandb:
            # Log mean validation loss and accuracy
            wandb.log({"epoch_val_loss": val_loss, "epoch_val_accuracy": val_acc})
            wandb.log({"epoch": self.current_epoch})

        print(print_log)

## Train the Classifier ⚡️⚡️

In [None]:
# Import the best AE
base_model_dir = "saved_models/Approach_1/"+selected_task
best_model_path = base_model_dir+"/"+best_model

batch = next(iter(train_loader))
input_dim = batch[0].shape[-1]
checkpoint_model = Autoencoder(input_dim=input_dim, batch_size = batch_size,sparsity_factor=0.005,enable_sparsity_loss=False, enable_weight_decay_loss=False, enable_non_negativity_constraint=False, enable_wandb = False)
checkpoint_model.load_state_dict(torch.load(best_model_path, map_location=checkpoint_model.device)['state_dict']) # ------> PyTorch Lightning API

In [None]:
# Initialize the Classifier Module for training
encoder = checkpoint_model.encoder
encoder.z_dim = 128
classifier = ClassifierPerTask_Approach1(encoder, labels_task, selected_task, enable_wandb=True)

early_stop = EarlyStopping(monitor="epoch_val_accuracy", min_delta=0.00, patience=30, verbose=True, mode="max")
checkpoint_callback = ModelCheckpoint(
     monitor='epoch_val_accuracy',
     dirpath="saved_models/Approach_1/"+selected_task+"/classifier/",
     filename='approach1-epoch{epoch:02d}-'+selected_task,
     auto_insert_metric_name=False,
     mode="max",
     save_top_k=2
)

trainer_classifier = Trainer(max_epochs=100, default_root_dir="saved_models/Approach_1/"+selected_task+"/classifier/", callbacks=[early_stop,checkpoint_callback],fast_dev_run=False)
trainer_classifier.fit(classifier, train_loader, val_loader)
# trainer.test(classifier, test_loader)

In [None]:
wandb.finish()

# Task 1 classifer wandb table 
# ancient paper task 1

## Evaluate the classifier

In [None]:
batch_size = 64
train_dataset = DataFrameApproach1(train_df)
val_dataset = DataFrameApproach1(val_df)
test_dataset = DataFrameApproach1(test_df)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False, collate_fn=collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=True)

In [None]:
base_model_dir = "saved_models/Approach_1/"+selected_task+"/classifier/"

model_to_test_paths = [
    # "approach1-epoch36-task_1.ckpt",
    # "approach1-epoch57-task_1.ckpt",
]

# Add to the list the models .ckpt from the directory /saved_models/Approach_1/selected_task/classifier/
for root, dirs, files in os.walk(base_model_dir):
  for file in files:
    if file.endswith(".ckpt"):
      model_to_test_paths.append(os.path.join(root, file))

print(f"Models to test => {len(model_to_test_paths)}")

best_metric = 0
best_model = ""
for model_path_ in model_to_test_paths:
  model_path = model_path_
  input_dim = batch[0].shape[-1]
  version = model_path

  checkpoint_model = ClassifierPerTask_Approach1.load_from_checkpoint(model_path, enable_wandb=False)
  

  trainer = Trainer(accelerator = 'auto', fast_dev_run=False)
  print(f"Evaluation => {version}")
  trainer.test(checkpoint_model, dataloaders=test_loader)

  if(checkpoint_model.test_acc > best_metric):
    best_metric = checkpoint_model.test_acc
    best_model = version

In [None]:
print(f"BEST MODEL => FILE = {best_model}, MSE = {best_metric}")

### Confusion Matrix for the best model

In [None]:
# Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, labels, title):
    cm = confusion_matrix(y_true, y_pred)
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.figure(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt=".2f", cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title(title)
    plt.show()

# Import the best Classifier model
model_path = best_model
input_dim = batch[0].shape[-1]
checkpoint_model = ClassifierPerTask_Approach1.load_from_checkpoint(model_path, enable_wandb=False)
checkpoint_model.eval()

# Get the predictions
y_true = []
y_pred = []
for batch in test_loader:
    inputs, labels = batch
    inputs = inputs.to(torch.float32)
    z = checkpoint_model.encoder(inputs)
    outputs = checkpoint_model(z)
    preds = torch.argmax(outputs, dim=1)
    y_true.extend(labels)
    y_pred.extend(preds)

y_true = [label2idx[item] for item in y_true]
y_pred = [item.item() for item in y_pred]

plot_confusion_matrix(y_true, y_pred, labels_task, "Confusion Matrix - "+selected_task)

# Approach 2
**NOTE: You will find some section identical to approach 1, this is to make the running of each approach independent one from the other.**

This is to address the study: "**Observe the performance of a general model that can be used to identify states across all the tasks.**"



## Dataset

In [None]:
class DatasetApproach2(Dataset):
  def __init__(self, window_size=1):
    self.dataset = []

    # Create a unique dataframe that is composed by the concatenation of all the files that belong to the task + the baselines files
    tasks_name = ['task_1','task_2','task_3','task_4']
    for task_name in tasks_name:
      for file in dataset_task_mapping[task_name]:
        df = pd.read_csv(file['file_path'])
        self.dataset.append(df)

    # Baseline 1
    for file in dataset_task_mapping['baseline_1']:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Baseline 2
    for file in dataset_task_mapping['baseline_2']:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Concatenate the dataframes
    print(f"Concatenating the dataframes ({len(self.dataset)})")
    self.dataset = pd.concat(self.dataset)
    # Create a dataframe
    print(f"Dataset shape: {self.dataset.shape}")


    # Windowing
    self.window_size = window_size

  def __len__(self):
    return len(self.dataset) - self.window_size

  def __getitem__(self, idx):
    # return as a tensor
    print(f"Index: {idx}")
    return torch.tensor(self.dataset.iloc[idx].values)

  def get_dataframe(self):
    return self.dataset

In [None]:
dataset = DatasetApproach2()
print(f"Dataset length: {len(dataset)}, Number of files used (Task 1: {len(dataset_task_mapping['task_1'])} + Task 2: {len(dataset_task_mapping['task_2'])} + Task 3: {len(dataset_task_mapping['task_3'])} + Task 4: {len(dataset_task_mapping['task_4'])} + Baseline 1 {len(dataset_task_mapping['baseline_1'])} + Baseline 2: {len(dataset_task_mapping['baseline_2'])})")

print(f"Unique values of the labels: {dataset.get_dataframe()['labels'].unique()}")

# Shuffle the rows of the dataset using sklearn (making sure the shuffle is reproducible)
from sklearn.utils import shuffle
data = shuffle(dataset.get_dataframe(), random_state=0)
# Remove the index column
data = data.reset_index(drop=True)

# Splitting into train and test sets (80% training data, 20% testing data)
train_df, test_df = train_test_split(data, test_size=0.15, random_state=42)

# Splitting the train_df further into train and validation sets (70% training data, 30% validation data)
train_df, val_df = train_test_split(train_df, test_size=0.15, random_state=42)

print(f"Data: {len(data)} ,Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")


#Create the Dataframe classe Approach 2
class DataFrameApproach2(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.iloc[:, :-1].values
        self.targets = dataframe['labels'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx])
        y = self.targets[idx]
        return x, y

def collate_fn(batch):
    data = [item[0] for item in batch]
    targets = [item[1] for item in batch]

    # Apply min-max normalization to each column
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data)

    return torch.tensor(normalized_data), targets

# Creating datasets and data loaders for each split
train_dataset = DataFrameApproach2(train_df)
val_dataset = DataFrameApproach2(val_df)
test_dataset = DataFrameApproach2(test_df)

## AutoEncoder Model

### Encoder

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_dim, window_size = 1, enable_sparsity_loss=False):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(128),
        )

        # Apply He initialization to the linear layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        batch_size, input_dim = x.size()  # Obtain the shape of the input [bs, input_dim]
        input = x
        x = self.encoder(input)
        return x

### Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, input_dim, window_size, enable_sparsity_loss=False):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, window_size * input_dim),
            nn.Sigmoid()
        )

        # Apply He initialization to the linear layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        # x = x.to(torch.float32)
        x = self.decoder(x)
        return x

### AE Lightning Module ⚡️⚡️

In [None]:
class Autoencoder(LightningModule):
    def __init__(self, input_dim, batch_size, sparsity_factor=0.1, sparsity_loss_coef = 1e-3, weight_decay=0.001, window_size=window_size, enable_sparsity_loss=False, enable_weight_decay_loss=False ,enable_non_negativity_constraint=False,enable_wandb = False):
        super(Autoencoder, self).__init__()

        if( enable_sparsity_loss == True and enable_non_negativity_constraint== True):
          print("The combination of constraints enable_sparsity_loss and enable_non_negativity_constraint both true leads to error in to the model matrix multiplication. This will be solved by setting enable_non_negativity_constraint to False.")

        self.save_hyperparameters()
        self.encoder = Encoder(input_dim=input_dim, window_size=window_size, enable_sparsity_loss = enable_sparsity_loss)
        self.decoder = Decoder(input_dim=input_dim, window_size=window_size, enable_sparsity_loss = enable_sparsity_loss)
        self.train_loss_memory = []
        self.train_rec_loss_memory = []

        self.val_loss_memory = []
        self.val_rec_loss_memory = []

        self.test_loss_memory = []
        self.test_rec_loss_memory = []

        self.batch_size = batch_size



        # --- Loss Settings
        self.enable_sparsity_loss = enable_sparsity_loss
        if enable_sparsity_loss:
          self.sparsity_loss_coef = sparsity_loss_coef
          self.sparsity_factor = sparsity_factor
          print(f"Enabled Sparsity term in the loss with sparsity loss coeff => {self.sparsity_loss_coef} and sparsity factor=>{self.sparsity_factor}")

          # self.sparsity_loss = nn.KLDivLoss(reduction='batchmean')
          # Memory logs for sparsity
          self.train_sparsity_loss_memory = []
          self.val_sparsity_loss_memory = []
          self.test_sparsity_loss_memory = []

          self.enable_non_negativity_constraint = False
        else:
          self.enable_non_negativity_constraint = enable_non_negativity_constraint
          if enable_non_negativity_constraint:
            print("Enabled non negativity constraint")


        self.enable_weight_decay_loss = enable_weight_decay_loss
        if enable_weight_decay_loss:
          print("Enabled weight decay")
          self.weight_decay = weight_decay

        self.wandb_log = enable_wandb

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                device = torch.device('cuda:0')
                print('Using device:', device)
            else:
                device = torch.device('cuda')
                print('Using device:', device)
        else:
            device = torch.device('cpu')
            print('Using device:', device)


        print('Using device:', device)

        self.to(device)
        print(f"Initialized Model on {self.device}")

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

    def kl_div(self, p, p_hat):
      funcs = nn.Sigmoid()
      p_hat = torch.mean(funcs(p_hat), 1)
      p_tensor = torch.Tensor([p] * p_hat.shape[0]).to(self.device)


      return torch.sum(p_tensor * torch.log(p_tensor) - p_tensor * torch.log(p_hat) + (1 - p_tensor) * torch.log(1 - p_tensor) - (1 - p_tensor) * torch.log(1 - p_hat))

    def sparse_loss(self, values):
      loss = 0
      values = values.view(self.batch_size, -1)

      # Encoder sparsity
      lyrs_encoder = list(self.encoder.encoder.children())
      for i, lyr in enumerate(lyrs_encoder):
          if isinstance(lyr, nn.Linear):
            values = lyr(values)
            # loss += self.sparsity_loss(torch.tensor([self.sparsity_factor]).to(self.device), values.to(self.device))
            loss += self.kl_div(self.sparsity_factor, values.to(self.device))

      # Decoder sparsity
      lyrs_decoder = list(self.decoder.decoder.children())
      for i, lyr in enumerate(lyrs_decoder):
          if isinstance(lyr, nn.Linear):
              values = lyr(values)
              # loss += self.sparsity_loss(torch.tensor([self.sparsity_factor]).to(self.device), values.to(self.device))
              loss += self.kl_div(self.sparsity_factor, values.to(self.device))

      return loss

    def calculate_weight_decay_loss(self):
        weight_decay_loss = 0.0
        for param in self.parameters():
            weight_decay_loss += 0.5 * self.weight_decay * torch.norm(param, p=2) ** 2
        return weight_decay_loss

    def enforce_non_negativity(self):
      for param in self.parameters():
        param.data.clamp_(min=0, max=None)

    def training_step(self, batch, batch_idx):
        x = batch[0].to(torch.float32) #[bs, input_dim]
        _, reconstructions = self(x)

        x = x.view(-1) # [bs * input_dim]
        reconstructions = reconstructions.view(-1)

        loss_mse = nn.MSELoss()(reconstructions, x)
        loss = loss_mse

        if self.enable_sparsity_loss:
          # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
          sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
          loss += sparsity_loss
          self.train_sparsity_loss_memory.append(sparsity_loss)

        if self.enable_weight_decay_loss:
          weight_decay_loss = self.calculate_weight_decay_loss()
          loss += weight_decay_loss

        self.train_loss_memory.append(loss)
        self.train_rec_loss_memory.append(loss_mse)

        if self.wandb_log:
          wandb.log({"train_total_loss": loss})
          wandb.log({"train_reconstruction_loss": loss_mse})

        return loss

    def validation_step(self, batch, batch_idx):
      x = batch[0].to(torch.float32)
      _, reconstructions = self(x)

      x = x.view(-1) #[]
      reconstructions = reconstructions.view(-1)

      loss_mse = nn.MSELoss()(reconstructions, x)
      loss = loss_mse

      if self.enable_sparsity_loss:
        # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
        sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
        loss += sparsity_loss
        self.val_sparsity_loss_memory.append(sparsity_loss)

      if self.enable_weight_decay_loss:
        weight_decay_loss = self.calculate_weight_decay_loss()
        loss += weight_decay_loss

      if self.enable_non_negativity_constraint:
        self.enforce_non_negativity()

      self.val_loss_memory.append(loss)
      self.val_rec_loss_memory.append(loss_mse)

      if self.wandb_log:
        wandb.log({"val_total_loss": loss})
        wandb.log({"val_reconstruction_loss": loss_mse})

      # For early stop and Model checkpoint callbacks
      self.log("val_reconstruction_loss",loss_mse)

      return loss

    def test_step(self, batch, batch_idx):
      x = batch[0].to(torch.float32)
      _, reconstructions = self(x)

      x = x.view(-1) #[]
      reconstructions = reconstructions.view(-1)

      loss_mse = nn.MSELoss()(reconstructions, x)
      loss = loss_mse

      if self.enable_sparsity_loss:
        # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
        sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
        loss += sparsity_loss
        self.test_sparsity_loss_memory.append(sparsity_loss)

      if self.enable_weight_decay_loss:
        weight_decay_loss = self.calculate_weight_decay_loss()
        loss += weight_decay_loss

      if self.enable_non_negativity_constraint:
        self.enforce_non_negativity()

      self.test_loss_memory.append(loss)
      self.test_rec_loss_memory.append(loss_mse)


      return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = CosineAnnealingLR(optimizer, T_max=10)  # Adjust T_max as needed

        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch'}}

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        # step
        optimizer.step(closure=optimizer_closure)

        if self.enable_non_negativity_constraint:
          self.enforce_non_negativity()

    def on_epoch_end(self):
        if self.wandb_log:
            wandb.log({'epoch': self.current_epoch})

    def on_train_epoch_end(self):
        # Access the training loss from the outputs
        train_loss = torch.stack([x for x in self.train_loss_memory]).mean()
        train_rec_loss = torch.stack([x for x in self.train_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Training Loss - Epoch {self.current_epoch}: Total Loss => {train_loss.item()} MSE => {train_rec_loss}'

        self.train_loss_memory.clear()
        self.train_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          train_sparsity_loss = torch.stack([x for x in self.train_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {train_sparsity_loss}'
          self.train_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"train_total_loss": train_loss})
          wandb.log({"train_reconstruction_loss": train_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"train_sparse_loss": train_sparsity_loss})

        print(print_log)

    def on_validation_epoch_end(self):
        # Access the training loss from the outputs
        val_loss = torch.stack([x for x in self.val_loss_memory]).mean()
        val_rec_loss = torch.stack([x for x in self.val_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Validation Loss - Epoch {self.current_epoch}: Total Loss => {val_loss.item()} MSE => {val_rec_loss}'

        self.val_loss_memory.clear()
        self.val_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          val_sparsity_loss = torch.stack([x for x in self.val_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {val_sparsity_loss}'
          self.val_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"val_total_loss": val_loss})
          wandb.log({"val_reconstruction_loss": val_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"val_sparse_loss": val_sparsity_loss})

        print(print_log)

    def on_test_epoch_end(self):
        # Access the training loss from the outputs
        test_loss = torch.stack([x for x in self.test_loss_memory]).mean()
        test_rec_loss = torch.stack([x for x in self.test_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Test Loss - Epoch {self.current_epoch}: Total Loss => {test_loss.item()} MSE => {test_rec_loss}'

        self.test_loss_memory.clear()
        self.test_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          test_sparsity_loss = torch.stack([x for x in self.test_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {test_sparsity_loss}'
          self.test_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"test_total_loss": test_loss})
          wandb.log({"test_reconstruction_loss": test_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"test_sparse_loss": test_sparsity_loss})

        self.test_rec_loss = test_rec_loss

        print(print_log)

## Classifier Model
This is a model that make use of the z vector extracted from the autoencoder: z goes into an MLP and we discriminate between the classes now across all the tasks.

The process of training the MLP allows to backpropagate until the encoder of the AE and basilly finetuning it for the classification task.

In [None]:
# First create a mapping utility to go from label to idx and vice versa
label2idx= {}
idx2label = {}
labels_task = dataset.get_dataframe()['labels'].unique()

for i in range(len(labels_task)):
  label2idx[labels_task[i]] = i
  idx2label[str(i)] = labels_task[i]

In [None]:
class ClassifierPerTask_Approach2(LightningModule):
    def __init__(self, encoder, text_labels, head_type=1, enable_wandb=False):
        super(ClassifierPerTask_Approach2, self).__init__()
        self.save_hyperparameters()

        self.encoder = encoder
        self.text_labels = text_labels
        if(head_type==1):
          # HEAD 1
          self.classifier = nn.Sequential(
              nn.Linear(encoder.z_dim, 128),
              nn.ReLU(),
              nn.Linear(128, len(text_labels))
          )
        elif (head_type ==2):
          # HEAD 2
          self.classifier = nn.Sequential(
              nn.Linear(encoder.z_dim, 256),
              nn.ReLU(),
               nn.Dropout(0.2),
              nn.Linear(256, 128),
              nn.ReLU(),
              nn.Dropout(0.2),
              nn.Linear(128, len(text_labels))
          )
        elif (head_type ==3):
          # HEAD 3
          self.classifier = nn.Sequential(
              nn.Linear(encoder.z_dim, 256),
              nn.BatchNorm1d(256),  # Batch normalization
              nn.ReLU(),
              nn.Dropout(0.2),
              nn.Linear(256, 128),
              nn.BatchNorm1d(128),  # Batch normalization
               nn.ReLU(),
              nn.Dropout(0.2),
              nn.Linear(128, len(text_labels))
          )
        else:
          # HEAD 4
          self.classifier = nn.Sequential(
              nn.Linear(encoder.z_dim, 256),
              nn.LayerNorm(256),  # Apply layer normalization
              nn.ReLU(),
              nn.Dropout(0.2),
              nn.Linear(256, 128),
              nn.LayerNorm(128),  # Apply layer normalization
              nn.ReLU(),
              nn.Dropout(0.2),
              nn.Linear(128, len(text_labels))
          )

        self.train_loss = []
        self.train_accuracy = []
        self.val_loss = []
        self.val_accuracy = []
        self.test_loss = []
        self.test_accuracy = []

        self.enable_wandb = enable_wandb

        if self.enable_wandb:
          wandb.init(project="Project_EAI_BrainComputerInterface", entity="rucci-2053183", group="approach2_classifier")

    def forward(self, z):
        return self.classifier(z)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(torch.float32)
        z = self.encoder(inputs)
        outputs = self(z)
        labels = self.labels2TargetTensor(labels).to(torch.long).to(outputs.device)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log('train_loss', loss)
        self.log('test_accuracy', acc)

        self.train_loss.append(loss)
        self.train_accuracy.append(acc)

        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(torch.float32)
        z = self.encoder(inputs)
        outputs = self(z)
        labels = self.labels2TargetTensor(labels).to(torch.long).to(outputs.device)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)

        self.test_loss.append(loss)
        self.test_accuracy.append(acc)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(torch.float32)
        z = self.encoder(inputs)
        outputs = self(z)
        labels = self.labels2TargetTensor(labels).to(torch.long).to(outputs.device)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)

        self.val_loss.append(loss)
        self.val_accuracy.append(acc)

        return loss

    def configure_optimizers(self):
      optimizer = optim.Adam(self.parameters(), lr=0.001)
      scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
      return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def labels2TargetTensor(self, labels):
      target = []
      for item in labels:
        target.append(label2idx[item])

      return torch.Tensor(target)

    def on_train_epoch_end(self):
        train_loss = torch.stack([x for x in self.train_loss]).mean()
        train_acc = torch.stack([x for x in self.train_accuracy]).mean()

        # Print the training loss
        print_log = f'Training - Epoch {self.current_epoch}: Loss => {train_loss.item()} ACCURACY => {train_acc}'

        self.train_loss.clear()
        self.train_accuracy.clear()

        if self.enable_wandb:
            # Log mean training loss
            wandb.log({"epoch_train_loss": train_loss, "epoch_train_accuracy": train_acc})

        print(print_log)

    def on_test_epoch_end(self):
        test_loss = torch.stack([x for x in self.test_loss]).mean()
        test_acc = torch.stack([x for x in self.test_accuracy]).mean()

        # Print the training loss
        print_log = f'Test - Epoch {self.current_epoch}: Loss => {test_loss.item()} ACCURACY => {test_acc}'

        self.test_loss.clear()
        self.test_accuracy.clear()

        if self.enable_wandb:
            # Log mean test loss and accuracy
            wandb.log({"epoch_test_loss": test_loss, "epoch_test_accuracy": test_acc})

        print(print_log)

        self.test_acc = test_acc

    def on_validation_epoch_end(self):
        val_loss = torch.stack([x for x in self.val_loss]).mean()
        val_acc = torch.stack([x for x in self.val_accuracy]).mean()

        # Print the training loss
        print_log = f'Validation - Epoch {self.current_epoch}: Loss => {val_loss.item()} ACCURACY => {val_acc}'

        self.val_loss.clear()
        self.val_accuracy.clear()
        self.log("epoch_val_accuracy", val_acc)
        if self.enable_wandb:
            # Log mean validation loss and accuracy
            wandb.log({"epoch_val_loss": val_loss, "epoch_val_accuracy": val_acc})
            wandb.log({"epoch": self.current_epoch})

        print(print_log)

## Evaluate the Classifier

In [None]:
batch_size = 64
train_dataset = DataFrameApproach2(train_df)
val_dataset = DataFrameApproach2(val_df)
test_dataset = DataFrameApproach2(test_df)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,shuffle=False, collate_fn=collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=True)

In [None]:
# Import The best classifier model
base_model_dir = "saved_models/Approach_2/classifier/head_3"

model_to_test_paths = [
    # "approach1-epoch36-task_1.ckpt",
    # "approach1-epoch57-task_1.ckpt",
]

# Add to the list the models .ckpt from the directory /saved_models/Approach_1/selected_task/classifier/
for root, dirs, files in os.walk(base_model_dir):
  for file in files:
    if file.endswith(".ckpt"):
      model_to_test_paths.append(os.path.join(root, file))

print(f"Models to test => {len(model_to_test_paths)}")

best_metric = 0
best_model = ""
for model_path_ in model_to_test_paths:
  model_path = model_path_
  input_dim = batch[0].shape[-1]
  version = model_path

  checkpoint_model = ClassifierPerTask_Approach2.load_from_checkpoint(model_path, head_type=3, enable_wandb=False)
  

  trainer = Trainer(accelerator = 'auto', fast_dev_run=False)
  print(f"Evaluation => {version}")
  trainer.test(checkpoint_model, dataloaders=test_loader)

  if(checkpoint_model.test_acc > best_metric):
    best_metric = checkpoint_model.test_acc
    best_model = version

### Confusion Matrix for the best model

In [None]:
# Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, labels, title):
    cm = confusion_matrix(y_true, y_pred)
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plt.figure(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt=".2f", cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title(title)
    plt.show()

# Import the best Classifier model
model_path = best_model
input_dim = batch[0].shape[-1]
checkpoint_model = ClassifierPerTask_Approach2.load_from_checkpoint(model_path, head_type=3, enable_wandb=False)
checkpoint_model.eval()

# Get the predictions
y_true = []
y_pred = []
for batch in test_loader:
    inputs, labels = batch
    inputs = inputs.to(torch.float32)
    z = checkpoint_model.encoder(inputs)
    outputs = checkpoint_model(z)
    preds = torch.argmax(outputs, dim=1)
    y_true.extend(labels)
    y_pred.extend(preds)

y_true = [label2idx[item] for item in y_true]
y_pred = [item.item() for item in y_pred]

plot_confusion_matrix(y_true, y_pred, labels_task, "Confusion Matrix - General Model")

# Approach 4
Might be interesting to investigate how much a contrastive finetuning/training of the autoencoder might affect the performance at approach 3.
This comes from the fact that the general autoencoder reconstruction task, doesn't take into account the objective of discriminate zones of the latent space according to the classes.

This also allows to make a study in which an encoder is trained from scratch in the same settings, with the same idea of use the Z representation to downstream classification tasks.

In [4]:
class DatasetApproach4(Dataset):
  def __init__(self, window_size=1):
    self.dataset = []

    # Create a unique dataframe that is composed by the concatenation of all the files that belong to the task + the baselines files
    tasks_name = ['task_1','task_2','task_3','task_4']
    for task_name in tasks_name:
      for file in dataset_task_mapping[task_name]:
        df = pd.read_csv(file['file_path'])
        self.dataset.append(df)

    # Baseline 1
    for file in dataset_task_mapping['baseline_1']:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Baseline 2
    for file in dataset_task_mapping['baseline_2']:
      df = pd.read_csv(file['file_path'])
      self.dataset.append(df)

    # Concatenate the dataframes
    print(f"Concatenating the dataframes ({len(self.dataset)})")
    self.dataset = pd.concat(self.dataset)
    # Create a dataframe
    print(f"Dataset shape: {self.dataset.shape}")


    # Windowing
    self.window_size = window_size

  def __len__(self):
    return len(self.dataset) - self.window_size

  def __getitem__(self, idx):
    # return as a tensor
    print(f"Index: {idx}")
    return torch.tensor(self.dataset.iloc[idx].values)

  def get_dataframe(self):
    return self.dataset

In [5]:
dataset = DatasetApproach4()
print(f"Dataset length: {len(dataset)}, Number of files used (Task 1: {len(dataset_task_mapping['task_1'])} + Task 2: {len(dataset_task_mapping['task_2'])} + Task 3: {len(dataset_task_mapping['task_3'])} + Task 4: {len(dataset_task_mapping['task_4'])} + Baseline 1 {len(dataset_task_mapping['baseline_1'])} + Baseline 2: {len(dataset_task_mapping['baseline_2'])})")
print(f"Unique values of the labels: {dataset.get_dataframe()['labels'].unique()}")
dataset_labels = dataset.get_dataframe()['labels'].unique().tolist()

# Shuffle the rows of the dataset using sklearn (making sure the shuffle is reproducible)
from sklearn.utils import shuffle
data = shuffle(dataset.get_dataframe(), random_state=0)
# Remove the index column
data = data.reset_index(drop=True)

# Splitting into train and test sets (80% training data, 20% testing data)
train_df, test_df = train_test_split(data, test_size=0.15, random_state=42)

# Splitting the train_df further into train and validation sets (70% training data, 30% validation data)
train_df, val_df = train_test_split(train_df, test_size=0.15, random_state=42)

print(f"Data: {len(data)} ,Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")

#Create the Dataframe class Approach 3
class DataFrameApproach4(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.iloc[:, :-1].values
        self.targets = dataframe['labels'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx])
        y = self.targets[idx]

        return x, y

def collate_fn(batch):
    data = [item[0] for item in batch]
    targets = [item[1] for item in batch]

    # Apply min-max normalization to each column
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data)

    return torch.tensor(normalized_data), targets

# Creating datasets and data loaders for each split
train_dataset = DataFrameApproach4(train_df)
val_dataset = DataFrameApproach4(val_df)
test_dataset = DataFrameApproach4(test_df)

Concatenating the dataframes (27)
Dataset shape: (12477, 65)
Dataset length: 12476, Number of files used (Task 1: 10 + Task 2: 3 + Task 3: 5 + Task 4: 2 + Baseline 1 4 + Baseline 2: 3)
Unique values of the labels: ['TASK1T0' 'TASK1T2' 'TASK1T1' 'TASK2T0' 'TASK2T2' 'TASK2T1' 'TASK3T0'
 'TASK3T2' 'TASK3T1' 'TASK4T0' 'TASK4T1' 'TASK4T2' 'BASE1T0' 'BASE2T0']
Data: 12477 ,Train size: 9014, Val size: 1591, Test size: 1872


## AutoEncoder Model

### Encoder

In [6]:
class Encoder(nn.Module):
    def __init__(self, input_dim, window_size = 1, enable_sparsity_loss=False):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(128),
        )

        # Apply He initialization to the linear layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        batch_size, input_dim = x.size()  # Obtain the shape of the input [bs, input_dim]
        input = x
        x = self.encoder(input)
        return x

### Decoder

In [7]:
class Decoder(nn.Module):
    def __init__(self, input_dim, window_size, enable_sparsity_loss=False):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, window_size * input_dim),
            nn.Sigmoid()
        )

        # Apply He initialization to the linear layers
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        # x = x.to(torch.float32)
        x = self.decoder(x)
        return x

### AE Lightning Module ⚡️⚡️

In [8]:
class Autoencoder(LightningModule):
    def __init__(self, input_dim, batch_size, sparsity_factor=0.1, sparsity_loss_coef = 1e-3, weight_decay=0.001, window_size=window_size, enable_sparsity_loss=False, enable_weight_decay_loss=False ,enable_non_negativity_constraint=False,enable_wandb = False, decoder_none=False):
        super(Autoencoder, self).__init__()

        if( enable_sparsity_loss == True and enable_non_negativity_constraint== True):
          print("The combination of constraints enable_sparsity_loss and enable_non_negativity_constraint both true leads to error in to the model matrix multiplication. This will be solved by setting enable_non_negativity_constraint to False.")

        self.save_hyperparameters()
        self.encoder = Encoder(input_dim=input_dim, window_size=window_size, enable_sparsity_loss = enable_sparsity_loss)
        if not decoder_none:
          self.decoder = Decoder(input_dim=input_dim, window_size=window_size, enable_sparsity_loss = enable_sparsity_loss)
        self.train_loss_memory = []
        self.train_rec_loss_memory = []

        self.val_loss_memory = []
        self.val_rec_loss_memory = []

        self.test_loss_memory = []
        self.test_rec_loss_memory = []

        self.batch_size = batch_size



        # --- Loss Settings
        self.enable_sparsity_loss = enable_sparsity_loss
        if enable_sparsity_loss:
          self.sparsity_loss_coef = sparsity_loss_coef
          self.sparsity_factor = sparsity_factor
          print(f"Enabled Sparsity term in the loss with sparsity loss coeff => {self.sparsity_loss_coef} and sparsity factor=>{self.sparsity_factor}")

          # self.sparsity_loss = nn.KLDivLoss(reduction='batchmean')
          # Memory logs for sparsity
          self.train_sparsity_loss_memory = []
          self.val_sparsity_loss_memory = []
          self.test_sparsity_loss_memory = []

          self.enable_non_negativity_constraint = False
        else:
          self.enable_non_negativity_constraint = enable_non_negativity_constraint
          if enable_non_negativity_constraint:
            print("Enabled non negativity constraint")


        self.enable_weight_decay_loss = enable_weight_decay_loss
        if enable_weight_decay_loss:
          print("Enabled weight decay")
          self.weight_decay = weight_decay

        self.wandb_log = enable_wandb

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 1:
                device = torch.device('cuda:0')
                print('Using device:', device)
            else:
                device = torch.device('cuda')
                print('Using device:', device)
        else:
            device = torch.device('cpu')
            print('Using device:', device)


        print('Using device:', device)

        self.to(device)
        print(f"Initialized Model on {self.device}")

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

    def kl_div(self, p, p_hat):
      funcs = nn.Sigmoid()
      p_hat = torch.mean(funcs(p_hat), 1)
      p_tensor = torch.Tensor([p] * p_hat.shape[0]).to(self.device)


      return torch.sum(p_tensor * torch.log(p_tensor) - p_tensor * torch.log(p_hat) + (1 - p_tensor) * torch.log(1 - p_tensor) - (1 - p_tensor) * torch.log(1 - p_hat))

    def sparse_loss(self, values):
      loss = 0
      values = values.view(self.batch_size, -1)

      # Encoder sparsity
      lyrs_encoder = list(self.encoder.encoder.children())
      for i, lyr in enumerate(lyrs_encoder):
          if isinstance(lyr, nn.Linear):
            values = lyr(values)
            # loss += self.sparsity_loss(torch.tensor([self.sparsity_factor]).to(self.device), values.to(self.device))
            loss += self.kl_div(self.sparsity_factor, values.to(self.device))

      # Decoder sparsity
      lyrs_decoder = list(self.decoder.decoder.children())
      for i, lyr in enumerate(lyrs_decoder):
          if isinstance(lyr, nn.Linear):
              values = lyr(values)
              # loss += self.sparsity_loss(torch.tensor([self.sparsity_factor]).to(self.device), values.to(self.device))
              loss += self.kl_div(self.sparsity_factor, values.to(self.device))

      return loss

    def calculate_weight_decay_loss(self):
        weight_decay_loss = 0.0
        for param in self.parameters():
            weight_decay_loss += 0.5 * self.weight_decay * torch.norm(param, p=2) ** 2
        return weight_decay_loss

    def enforce_non_negativity(self):
      for param in self.parameters():
        param.data.clamp_(min=0, max=None)

    def training_step(self, batch, batch_idx):
        x = batch[0].to(torch.float32) #[bs, input_dim]
        _, reconstructions = self(x)

        x = x.view(-1) # [bs * input_dim]
        reconstructions = reconstructions.view(-1)

        loss_mse = nn.MSELoss()(reconstructions, x)
        loss = loss_mse

        if self.enable_sparsity_loss:
          # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
          sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
          loss += sparsity_loss
          self.train_sparsity_loss_memory.append(sparsity_loss)

        if self.enable_weight_decay_loss:
          weight_decay_loss = self.calculate_weight_decay_loss()
          loss += weight_decay_loss

        self.train_loss_memory.append(loss)
        self.train_rec_loss_memory.append(loss_mse)

        if self.wandb_log:
          wandb.log({"train_total_loss": loss})
          wandb.log({"train_reconstruction_loss": loss_mse})

        return loss

    def validation_step(self, batch, batch_idx):
      x = batch[0].to(torch.float32)
      _, reconstructions = self(x)

      x = x.view(-1) #[]
      reconstructions = reconstructions.view(-1)

      loss_mse = nn.MSELoss()(reconstructions, x)
      loss = loss_mse

      if self.enable_sparsity_loss:
        # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
        sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
        loss += sparsity_loss
        self.val_sparsity_loss_memory.append(sparsity_loss)

      if self.enable_weight_decay_loss:
        weight_decay_loss = self.calculate_weight_decay_loss()
        loss += weight_decay_loss

      if self.enable_non_negativity_constraint:
        self.enforce_non_negativity()

      self.val_loss_memory.append(loss)
      self.val_rec_loss_memory.append(loss_mse)

      if self.wandb_log:
        wandb.log({"val_total_loss": loss})
        wandb.log({"val_reconstruction_loss": loss_mse})

      # For early stop and Model checkpoint callbacks
      self.log("val_reconstruction_loss",loss_mse)

      return loss

    def test_step(self, batch, batch_idx):
      x = batch[0].to(torch.float32)
      _, reconstructions = self(x)

      x = x.view(-1) #[]
      reconstructions = reconstructions.view(-1)

      loss_mse = nn.MSELoss()(reconstructions, x)
      loss = loss_mse

      if self.enable_sparsity_loss:
        # sparsity_loss = self.sparsity_loss(torch.log(reconstructions).to(self.device), torch.tensor([self.sparsity_factor]).to(self.device))
        sparsity_loss = self.sparse_loss(x) * self.sparsity_loss_coef
        loss += sparsity_loss
        self.test_sparsity_loss_memory.append(sparsity_loss)

      if self.enable_weight_decay_loss:
        weight_decay_loss = self.calculate_weight_decay_loss()
        loss += weight_decay_loss

      if self.enable_non_negativity_constraint:
        self.enforce_non_negativity()

      self.test_loss_memory.append(loss)
      self.test_rec_loss_memory.append(loss_mse)


      return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = CosineAnnealingLR(optimizer, T_max=10)  # Adjust T_max as needed

        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch'}}

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        # step
        optimizer.step(closure=optimizer_closure)

        if self.enable_non_negativity_constraint:
          self.enforce_non_negativity()

    def on_epoch_end(self):
        if self.wandb_log:
            wandb.log({'epoch': self.current_epoch})

    def on_train_epoch_end(self):
        # Access the training loss from the outputs
        train_loss = torch.stack([x for x in self.train_loss_memory]).mean()
        train_rec_loss = torch.stack([x for x in self.train_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Training Loss - Epoch {self.current_epoch}: Total Loss => {train_loss.item()} MSE => {train_rec_loss}'

        self.train_loss_memory.clear()
        self.train_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          train_sparsity_loss = torch.stack([x for x in self.train_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {train_sparsity_loss}'
          self.train_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"train_total_loss": train_loss})
          wandb.log({"train_reconstruction_loss": train_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"train_sparse_loss": train_sparsity_loss})

        print(print_log)

    def on_validation_epoch_end(self):
        # Access the training loss from the outputs
        val_loss = torch.stack([x for x in self.val_loss_memory]).mean()
        val_rec_loss = torch.stack([x for x in self.val_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Validation Loss - Epoch {self.current_epoch}: Total Loss => {val_loss.item()} MSE => {val_rec_loss}'

        self.val_loss_memory.clear()
        self.val_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          val_sparsity_loss = torch.stack([x for x in self.val_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {val_sparsity_loss}'
          self.val_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"val_total_loss": val_loss})
          wandb.log({"val_reconstruction_loss": val_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"val_sparse_loss": val_sparsity_loss})

        print(print_log)

    def on_test_epoch_end(self):
        # Access the training loss from the outputs
        test_loss = torch.stack([x for x in self.test_loss_memory]).mean()
        test_rec_loss = torch.stack([x for x in self.test_rec_loss_memory]).mean()

        # Print the training loss
        print_log = f'Test Loss - Epoch {self.current_epoch}: Total Loss => {test_loss.item()} MSE => {test_rec_loss}'

        self.test_loss_memory.clear()
        self.test_rec_loss_memory.clear()

        if self.enable_sparsity_loss:
          test_sparsity_loss = torch.stack([x for x in self.test_sparsity_loss_memory]).mean()
          print_log += f' SPARSE => {test_sparsity_loss}'
          self.test_sparsity_loss_memory.clear()

        if self.wandb_log:
          wandb.log({"test_total_loss": test_loss})
          wandb.log({"test_reconstruction_loss": test_rec_loss})
          if self.enable_sparsity_loss:
            wandb.log({"test_sparse_loss": test_sparsity_loss})

        self.test_rec_loss = test_rec_loss

        print(print_log)