In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..', 'src')))

from cnn_tnn895 import ConvTransformerModel
from cnn_lstm import ParallelCNNLSTMModel
from ann_baseline import Baseline_ANN
from cnn_lstm import  ParallelCNNLSTMModel
from cnn_model import CNN_Ieeg_Model
from lstm_model import LSTM_Ieeg_Model


from utils import get_loaders_no_sss, import_checkpoint, save_checkpoint
import torch
import multiprocessing
import mlflow
import mlflow.pytorch
import torch
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix,roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import psutil
from torch.cuda.amp import GradScaler, autocast
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math


# MLflow Setup

In [2]:
os.environ['AWS_ACCESS_KEY_ID'] = 'dIgexhE2iDrGls2qargL'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'IzEzgQpztotDnrIInJdUfUIYngpjJoT18d0FDZf7'
os.environ['MLFLOW_S3_ENDPOINT_URL'] = 'http://localhost:9000'
os.environ['MLFLOW_S3_IGNORE_TLS'] = 'true'
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true"
mlflow.set_tracking_uri("http://localhost:5000")

print('tracking uri:', mlflow.get_tracking_uri())

tracking uri: http://localhost:5000


In [3]:
# Configuration
DATA_DIR = '../data/data_normalized_exp2_splited'
DATA_DIR_EV = '../data/data_normalized_with_ev_split'
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
TEST_DIR = os.path.join(DATA_DIR, 'test')
TRAIN_DIR_EV = os.path.join(DATA_DIR_EV, 'train')
TEST_DIR_EV = os.path.join(DATA_DIR_EV, 'test')
class_mapping = {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
SEQ_LENGTH = 500
BATCH_SIZE = 256
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001
EXPERIMENT_NAME = "IEEG_MODELS_COMP_FINAL_2"
# RUN_NAME = "CNN"
PIN_MEMORY = True
LOAD_MODEL = False
NUM_WORKERS = multiprocessing.cpu_count()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_SIZE = SEQ_LENGTH
NUM_CLASSES = 4
CHECKPOINTS_PATH = '../models/checkpoints'

# Before Training

In [4]:
def get_model_size(model):
    param_size = 0
    buffer_size = 0
    for param in model.parameters():
        param_size += param.numel() * param.element_size()
    for buffer in model.buffers():
        buffer_size += buffer.numel() * buffer.element_size()
    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    return size_all_mb

In [5]:
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, optimizer: optim.Optimizer, 
                criterion: nn.Module, num_epochs: int, device: torch.device, save_checkpoint_interval: int = 10, 
                early_stopping_patience: int = 15, checkpoint_dir: str = '../models/checkpoints', 
                results_dir: str = '../models/results', accumulation_steps: int = 2,
                cnn=False, model_name='CNN'):
    """
    Train a deep learning model with the given parameters and log metrics to MLflow.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for the training data.
        val_loader (DataLoader): DataLoader for the validation data.
        optimizer (optim.Optimizer): Optimizer for updating model parameters.
        criterion (nn.Module): Loss function.
        num_epochs (int): Number of epochs to train.
        device (torch.device): Device to use for training (CPU or GPU).
        save_checkpoint_interval (int, optional): Interval for saving checkpoints. Default is 10.
        early_stopping_patience (int, optional): Patience for early stopping. Default is 15.
        checkpoint_dir (str, optional): Directory to save checkpoints. Default is 'checkpoints'.
        results_dir (str, optional): Directory to save results. Default is 'results'.
        accumulation_steps (int, optional): Number of steps to accumulate gradients before updating weights. Default is 2.
        cnn (bool, optional): If True, use CNN mode. Default is False.
        model_name (str, optional): Name of the model for saving checkpoints. Default is 'CNN'.
    """
    scaler = GradScaler()  # For mixed precision training
    best_val_loss = float('inf')  # Track the best validation loss for early stopping
    patience_counter = 0  # Counter for early stopping

    # Ensure results and checkpoint directories exist
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    train_metrics = []
    val_metrics = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        y_true_train = []
        y_pred_train = []

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        optimizer.zero_grad()  # Reset gradients at the start of each epoch

        for batch_idx, (inputs, labels) in enumerate(progress_bar):
            inputs, labels = inputs.to(device), labels.to(device)

            with autocast():  # Mixed precision training
                if not cnn:
                    outputs = model(inputs)
                else:
                    outputs, _ = model(inputs)
                    

                loss = criterion(outputs, labels.squeeze())

            scaler.scale(loss).backward()  # Backpropagation

            scaler.step(optimizer)  # Update weights
            scaler.update()
            optimizer.zero_grad()  # Reset gradients after updating weights

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            y_true_train.extend(labels.squeeze().cpu().numpy())
            y_pred_train.extend(predicted.cpu().numpy())

            avg_loss = running_loss / (batch_idx + 1)
            train_accuracy = accuracy_score(y_true_train, y_pred_train)
            precision, recall, f1, _ = precision_recall_fscore_support(y_true_train, y_pred_train, average='weighted', zero_division=0)

            progress_bar.set_postfix(train_loss=avg_loss, train_accuracy=train_accuracy, train_precision=precision, train_recall=recall, train_f1=f1)

        # Log training metrics to MLflow
        mlflow.log_metric("train_loss", avg_loss, step=epoch)
        mlflow.log_metric("train_accuracy", train_accuracy, step=epoch)
        mlflow.log_metric("train_precision", precision, step=epoch)
        mlflow.log_metric("train_recall", recall, step=epoch)
        mlflow.log_metric("train_f1_score", f1, step=epoch)

        # Store training metrics in DataFrame
        train_metrics.append({
            "epoch": epoch + 1,
            "model_name": model_name,
            "train_loss": avg_loss,
            "train_accuracy": train_accuracy,
            "train_precision": precision,
            "train_recall": recall,
            "train_f1": f1
        })

        # Validation step
        model.eval()
        val_loss = 0.0
        y_true_val = []
        y_pred_val = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                with autocast():  # Mixed precision inference
                    if not cnn:
                        outputs = model(inputs)
                    else:
                        outputs, _ = model(inputs)
                    loss = criterion(outputs, labels.squeeze())

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                y_true_val.extend(labels.squeeze().cpu().numpy())
                y_pred_val.extend(predicted.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = accuracy_score(y_true_val, y_pred_val)
        val_precision, val_recall, val_f1, _ = precision_recall_fscore_support(y_true_val, y_pred_val, average='weighted', zero_division=0)

        # Log validation metrics to MLflow
        mlflow.log_metric("val_loss", avg_val_loss, step=epoch)
        mlflow.log_metric("val_accuracy", val_accuracy, step=epoch)
        mlflow.log_metric("val_precision", val_precision, step=epoch)
        mlflow.log_metric("val_recall", val_recall, step=epoch)
        mlflow.log_metric("val_f1_score", val_f1, step=epoch)

        # Store validation metrics in DataFrame
        val_metrics.append({
            "epoch": epoch + 1,
            "model_name": model_name,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "val_precision": val_precision,
            "val_recall": val_recall,
            "val_f1": val_f1
        })

        # Update the progress bar with validation metrics
        progress_bar.set_postfix(train_loss=avg_loss, train_accuracy=train_accuracy, train_precision=precision, train_recall=recall, train_f1=f1, val_loss=avg_val_loss, val_accuracy=val_accuracy, val_precision=val_precision, val_recall=val_recall, val_f1=val_f1)

        # Save checkpoint every 'save_checkpoint_interval' epochs
        if (epoch + 1) % save_checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{model_name}.pth.tar')
            save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, checkpoint_path)
            mlflow.log_artifact(checkpoint_path, artifact_path="checkpoints")

        # Early stopping based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0  # Reset counter if we get a new best validation loss
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch + 1} due to no improvement in validation loss.")
            mlflow.log_param(f"{model_name}_epochs_actual", epoch + 1)
            break

        # Clear CUDA cache after each epoch
        torch.cuda.empty_cache()

    # Save training and validation metrics as CSV files
    train_metrics_df = pd.DataFrame(train_metrics)
    val_metrics_df = pd.DataFrame(val_metrics)
    train_metrics_df.to_csv(os.path.join(results_dir, f'train_metrics_{model_name}.csv'), index=False)
    val_metrics_df.to_csv(os.path.join(results_dir, f'val_metrics_{model_name}.csv'), index=False)

    # Clear CUDA cache at the end of training
    torch.cuda.empty_cache()

In [6]:


def evaluate_model(model: nn.Module, test_loader: DataLoader, class_mapping: dict, 
                   device: torch.device, img_path: str, results_dir: str,
                   run_name: str, cnn=False, save_fm=False):
    """
    Evaluate a deep learning model and log metrics to MLflow.

    Args:
        model (nn.Module): The model to evaluate.
        test_loader (DataLoader): DataLoader for the test data.
        class_mapping (dict): A dictionary mapping class names to class indices.
        device (torch.device): Device to use for evaluation (CPU or GPU).
        img_path (str): Path to save the confusion matrix image.
        run_name (str): Name of the MLflow run.
        batch_size (int, optional): Batch size for evaluation. Default is 16.
        cnn (bool, optional): If True, handle model output as CNN. Default is False.
    """
    model.eval()
    y_true_test = []
    y_pred_test = []
    feature_maps = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            if not cnn: 
                outputs = model(inputs)
            else:
                outputs, feature_map = model(inputs)
                feature_maps.append([fm.cpu() for fm in feature_map])  # Move feature maps to CPU to free GPU memory

            _, predicted = torch.max(outputs, 1)
            y_true_test.extend(labels.squeeze().cpu().numpy())
            y_pred_test.extend(predicted.cpu().numpy())

            # Clear cache to free up memory
            torch.cuda.empty_cache()

    test_accuracy = accuracy_score(y_true_test, y_pred_test)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true_test, y_pred_test, average='weighted', zero_division=0)

    print(f'Accuracy of the model on the test data: {test_accuracy:.2f}%')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

    mlflow.log_metric("test_accuracy", test_accuracy)
    mlflow.log_metric("test_precision", precision)
    mlflow.log_metric("test_recall", recall)
    mlflow.log_metric("test_f1", f1)

    # Confusion matrix
    class_names = {v: k for k, v in class_mapping.items()}
    cm = confusion_matrix(y_true_test, y_pred_test)
    cm_df = pd.DataFrame(cm, index=[class_names[i] for i in range(len(class_names))], 
                         columns=[class_names[i] for i in range(len(class_names))])
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    img_file = os.path.join(img_path, f"confusion_matrix_{run_name}.png")
    cm_file = os.path.join(results_dir, f"confusion_matrix_{run_name}.csv")

    plt.savefig(img_file)
    cm_df.to_csv(cm_file)

    mlflow.log_artifact(img_file)
    mlflow.log_artifact(cm_file)

    plt.close()

    # Save feature maps and labels to file
    if save_fm:
        feature_maps_file = os.path.join(img_path, f"feature_maps_{run_name}.pt")
        torch.save((feature_maps, y_true_test, y_pred_test), feature_maps_file)
        mlflow.log_artifact(feature_maps_file)

    return y_true_test, y_pred_test



# ANN with evoked response

In [17]:
train_loader, val_loader, test_loader, train_distribution, val_distribution, test_distribution = get_loaders_no_sss(train_dir=TRAIN_DIR_EV, test_dir=TEST_DIR_EV, with_val_loader=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                        pin_memory=PIN_MEMORY, seq_length=SEQ_LENGTH, model_type="mlp")

Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Train dataset size: 27165
Test dataset size: 9480
Train dataset length: 27165, Val dataset length: 948, Test dataset length: 8532


In [18]:
hidden_layers = [4096,4096,2048,2048,1024,1024,512]
model = Baseline_ANN(INPUT_SIZE, NUM_CLASSES, hidden_layers=hidden_layers, dropout=0.3).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
model

Baseline_ANN(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=2000, out_features=4096, bias=True)
    (1): Linear(in_features=4096, out_features=4096, bias=True)
    (2): Linear(in_features=4096, out_features=2048, bias=True)
    (3): Linear(in_features=2048, out_features=2048, bias=True)
    (4): Linear(in_features=2048, out_features=1024, bias=True)
    (5): Linear(in_features=1024, out_features=1024, bias=True)
    (6): Linear(in_features=1024, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (activation): ReLU()
  (output_layer): Linear(in_features=512, out_features=4, bias=True)
)

In [19]:
print(f'Model size: {get_model_size(model):.3f} MB')


Model size: 157.314 MB


In [20]:
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='mlflow-artifacts:/20', creation_time=1719540405626, experiment_id='20', last_update_time=1719540405626, lifecycle_stage='active', name='IEEG_MODELS_COMP_FINAL_2', tags={}>

In [21]:
# Train and Evaluate the Model with MLflow
run_name = "ANN_with_er_no_ss"
model_name = "ANN_baseline_wr_no_ss"
results_dir = "../models/results"
with mlflow.start_run(run_name=run_name) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", model_name)
    mlflow.log_param("input_size", SEQ_LENGTH)
    mlflow.log_param("num_classes", NUM_CLASSES)
    # mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    # Train and Evaluate the Model
    train_model(model, train_loader,val_loader, optimizer, criterion, NUM_EPOCHS, DEVICE, 
                save_checkpoint_interval=10, checkpoint_dir=CHECKPOINTS_PATH, 
                model_name=model_name, early_stopping_patience=20, cnn=False)
    _,_ = evaluate_model(model, test_loader, class_mapping, DEVICE, 
                            results_dir=results_dir,
                            img_path='../plots', 
                            run_name=run_name,
                            cnn=False)

    # Log the model
    mlflow.pytorch.log_model(model, model_name)

2024/06/27 22:28:21 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Epoch 1/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 2/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 3/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 4/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 5/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 6/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 7/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 8/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 9/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 10/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 11/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 12/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 13/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 14/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 15/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 16/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 17/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 18/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 19/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 20/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 21/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 22/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 23/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 24/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 25/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 26/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 27/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 28/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 29/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Epoch 30/50:   0%|          | 0/424 [00:00<?, ?batch/s]

Checkpoint saved successfully.
Early stopping at epoch 30 due to no improvement in validation loss.


Evaluating:   0%|          | 0/133 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.69%
Precision: 0.6803, Recall: 0.6949, F1 Score: 0.6778


2024/06/27 22:34:49 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/06/27 22:34:49 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


# ANN Baseline

In [7]:
train_loader, val_loader, test_loader, train_distribution, val_distribution, test_distribution = get_loaders_no_sss(train_dir=TRAIN_DIR, test_dir=TEST_DIR, with_val_loader=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                        pin_memory=PIN_MEMORY, seq_length=SEQ_LENGTH, model_type="mlp")

Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Train dataset size: 21732
Test dataset size: 8304
Train dataset length: 21732, Val dataset length: 830, Test dataset length: 7474


In [8]:
hidden_layers = [4096,4096,2048,2048,1024,1024,512]
model = Baseline_ANN(INPUT_SIZE, NUM_CLASSES, hidden_layers=hidden_layers, dropout=0.3).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
model

Baseline_ANN(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=2000, out_features=4096, bias=True)
    (1): Linear(in_features=4096, out_features=4096, bias=True)
    (2): Linear(in_features=4096, out_features=2048, bias=True)
    (3): Linear(in_features=2048, out_features=2048, bias=True)
    (4): Linear(in_features=2048, out_features=1024, bias=True)
    (5): Linear(in_features=1024, out_features=1024, bias=True)
    (6): Linear(in_features=1024, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (activation): ReLU()
  (output_layer): Linear(in_features=512, out_features=4, bias=True)
)

In [9]:
print(f'Model size: {get_model_size(model):.3f} MB')


Model size: 157.314 MB


In [10]:
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='mlflow-artifacts:/20', creation_time=1719540405626, experiment_id='20', last_update_time=1719540405626, lifecycle_stage='active', name='IEEG_MODELS_COMP_FINAL_2', tags={}>

In [11]:
# Train and Evaluate the Model with MLflow
run_name = "ANN_no_sss"
model_name = "ANN_baseline_no_sss"
results_dir = "../models/results"
with mlflow.start_run(run_name=run_name) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", model_name)
    mlflow.log_param("input_size", SEQ_LENGTH)
    mlflow.log_param("num_classes", NUM_CLASSES)
    # mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    # Train and Evaluate the Model
    train_model(model, train_loader,val_loader, optimizer, criterion, NUM_EPOCHS, DEVICE, 
                save_checkpoint_interval=10, checkpoint_dir=CHECKPOINTS_PATH, 
                model_name=model_name, early_stopping_patience=20, cnn=False)
    _,_ = evaluate_model(model, test_loader, class_mapping, DEVICE, 
                            results_dir=results_dir,
                            img_path='../plots', 
                            run_name=run_name,
                            cnn=False)

    # Log the model
    mlflow.pytorch.log_model(model, model_name)

2024/06/27 22:45:41 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Epoch 1/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 2/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 3/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 4/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 5/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 6/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 7/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 8/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 9/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 10/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 11/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 12/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 13/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 14/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 15/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 16/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 17/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 18/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 19/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 20/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 21/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 22/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 23/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 24/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 25/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 26/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 27/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 28/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 29/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 30/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 31/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 32/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 33/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 34/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 35/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Early stopping at epoch 35 due to no improvement in validation loss.


Evaluating:   0%|          | 0/116 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.71%
Precision: 0.6945, Recall: 0.7113, F1 Score: 0.6920


2024/06/27 22:51:27 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/06/27 22:51:27 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


# CNN

In [9]:
train_loader, val_loader, test_loader, train_distribution, val_distribution, test_distribution = get_loaders_no_sss(train_dir=TRAIN_DIR, test_dir=TEST_DIR, with_val_loader=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                        pin_memory=PIN_MEMORY, seq_length=SEQ_LENGTH, model_type="cnn")

Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Train dataset size: 21732
Test dataset size: 8304
Train dataset length: 21732, Val dataset length: 830, Test dataset length: 7474


In [10]:
train_distribution, val_distribution, test_distribution

({1: 0.09939260077305356,
  0: 0.36333517393705134,
  2: 0.06626173384870238,
  3: 0.4710104914411927},
 {3: 0.40843373493975904,
  0: 0.3686746987951807,
  2: 0.08674698795180723,
  1: 0.13614457831325302},
 {3: 0.4395236820979395,
  1: 0.12938185710462938,
  0: 0.3443938988493444,
  2: 0.0867005619480867})

In [11]:
model = CNN_Ieeg_Model(SEQ_LENGTH, NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(f'Model size: {get_model_size(model):.3f} MB')

model

Model size: 252.997 MB


CNN_Ieeg_Model(
  (conv_layers): ModuleList(
    (0): Conv1d(1, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (2): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (3): Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (bn_layers): ModuleList(
    (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (activation): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc_layers): ModuleList(
    (0): Linear(in_features=64000, out_features=1024, bias=True)
    (1): Linear(in_features=1024, o

In [12]:
mlflow.set_experiment(EXPERIMENT_NAME)

2024/06/27 21:06:45 INFO mlflow.tracking.fluent: Experiment with name 'IEEG_MODELS_COMP_FINAL_2' does not exist. Creating a new experiment.


<Experiment: artifact_location='mlflow-artifacts:/20', creation_time=1719540405626, experiment_id='20', last_update_time=1719540405626, lifecycle_stage='active', name='IEEG_MODELS_COMP_FINAL_2', tags={}>

In [13]:
# Train and Evaluate the Model with MLflow
run_name = "CNN_no_sss"
model_name = "CNN_no_sss"
results_dir = "../models/results"
with mlflow.start_run(run_name=run_name) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", model_name)
    mlflow.log_param("input_size", SEQ_LENGTH)
    mlflow.log_param("num_classes", NUM_CLASSES)
    # mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    # Train and Evaluate the Model
    train_model(model, train_loader,val_loader, optimizer, criterion, NUM_EPOCHS, DEVICE, 
                save_checkpoint_interval=10, checkpoint_dir=CHECKPOINTS_PATH, 
                model_name=model_name, early_stopping_patience=20, cnn=True)
    _,_ = evaluate_model(model, test_loader, class_mapping, DEVICE, 
                            results_dir=results_dir,
                            img_path='../plots', 
                            run_name=run_name,
                            cnn=True)

    # Log the model
    mlflow.pytorch.log_model(model, model_name)

2024/06/27 21:06:55 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Epoch 1/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 2/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 3/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 4/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 5/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 6/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 7/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 8/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 9/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 10/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 11/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 12/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 13/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 14/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 15/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 16/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 17/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 18/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 19/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 20/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 21/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 22/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 23/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 24/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 25/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 26/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 27/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 28/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 29/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 30/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 31/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 32/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Epoch 33/50:   0%|          | 0/169 [00:00<?, ?batch/s]

Early stopping at epoch 33 due to no improvement in validation loss.


Evaluating:   0%|          | 0/58 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.73%
Precision: 0.6499, Recall: 0.7256, F1 Score: 0.6486


2024/06/27 21:13:23 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/06/27 21:13:23 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


# CONV Transformer

In [8]:
train_loader, val_loader, test_loader, train_distribution, val_distribution, test_distribution = get_loaders_no_sss(train_dir=TRAIN_DIR, test_dir=TEST_DIR, with_val_loader=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                        pin_memory=PIN_MEMORY, seq_length=SEQ_LENGTH, model_type="cnn")

Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Train dataset size: 90550
Test dataset size: 34600
Train dataset length: 90550, Val dataset length: 3460, Test dataset length: 31140


In [9]:
test_distribution, train_distribution, val_distribution

({3: 0.4363840719332049,
  0: 0.34858702633269106,
  1: 0.12947976878612716,
  2: 0.08554913294797688},
 {1: 0.09939260077305356,
  0: 0.36333517393705134,
  2: 0.06626173384870238,
  3: 0.4710104914411927},
 {1: 0.1352601156069364,
  3: 0.4367052023121387,
  2: 0.09710982658959537,
  0: 0.3309248554913295})

In [25]:
for loader_name, loader in zip(['train_loader', 'test_loader','val_loader'], [train_loader, test_loader, val_loader]):
    for i, (inputs, labels) in enumerate(loader):
        print(f"{loader_name} - Batch {i}: inputs shape = {inputs.shape}, labels shape = {labels.shape}")
        if i == 0:  # Only print the first batch for brevity
            break

train_loader - Batch 0: inputs shape = torch.Size([256, 1, 2000]), labels shape = torch.Size([256])
test_loader - Batch 0: inputs shape = torch.Size([256, 1, 2000]), labels shape = torch.Size([256])
val_loader - Batch 0: inputs shape = torch.Size([256, 1, 2000]), labels shape = torch.Size([256])


In [17]:
# Model parameters
input_size = SEQ_LENGTH  # Use the sequence length provided by your dataset
num_classes = 4  # Number of classes for classification
transformer_dim = 256  # Smaller transformer dimension
num_heads = 4  # Fewer attention heads
transformer_depth = 2  # Fewer transformer layers
fc_neurons = [1024, 256]  # Reduced fully connected layer sizes
fc_transformer = 128
dropout = 0.3  # Dropout rate

In [15]:
del model, criterion, optimizer

In [18]:
model = ConvTransformerModel(
    input_size=input_size,
    num_classes=num_classes,
    transformer_dim=transformer_dim,
    num_heads=num_heads,
    transformer_depth=transformer_depth,
    fc_neurons=fc_neurons,
    fc_transformer=fc_transformer,
    dropout=dropout,
    activation=nn.GELU()
).to(DEVICE)

In [19]:
print(f'Model size: {get_model_size(model):.3f} MB')
model

Model size: 8.792 MB


ConvTransformerModel(
  (conv_embedding_stem): ConvEmbeddingStem(
    (conv1): Conv1d(1, 128, kernel_size=(10,), stride=(2,), padding=(4,), bias=False)
    (act1): GELU(approximate='none')
    (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout1): Dropout(p=0.3, inplace=False)
    (conv2): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (act2): GELU(approximate='none')
    (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout2): Dropout(p=0.3, inplace=False)
    (conv3): Conv1d(256, 256, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
    (act3): GELU(approximate='none')
    (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout3): Dropout(p=0.3, inplace=False)
  )
  (transformer_blocks): ModuleList(
    (0-1): 2 x MultiheadSelfAttentionBlock(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=Tr

In [20]:
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion =  nn.CrossEntropyLoss()  

In [21]:
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='mlflow-artifacts:/20', creation_time=1719540405626, experiment_id='20', last_update_time=1719540405626, lifecycle_stage='active', name='IEEG_MODELS_COMP_FINAL_2', tags={}>

In [22]:
# Train and Evaluate the Model with MLflow
run_name = "run_CNN_TNN_sl2000_2_no_sss"
model_name = "CNN_TNN_sl2000_no_sss"
results_dir = "../models/results"
with mlflow.start_run(run_name=run_name) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", model_name)
    mlflow.log_param("input_size", SEQ_LENGTH)
    mlflow.log_param("num_classes", NUM_CLASSES)
    # mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    # Train and Evaluate the Model
    train_model(model, train_loader,val_loader, optimizer, criterion, NUM_EPOCHS, DEVICE, 
                save_checkpoint_interval=10, checkpoint_dir=CHECKPOINTS_PATH, 
                model_name=model_name, early_stopping_patience=20, cnn=False)
    _,_ = evaluate_model(model, test_loader, class_mapping, DEVICE, 
                            results_dir=results_dir,
                            img_path='../plots', 
                            run_name=run_name,
                            cnn=False)

    # Log the model
    mlflow.pytorch.log_model(model, model_name)

2024/06/28 00:45:18 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Epoch 1/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 2/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 3/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 4/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 5/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 6/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 7/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 8/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 9/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 10/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 11/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 12/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 13/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 14/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 15/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 16/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 17/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 18/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 19/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 20/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 21/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 22/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 23/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 24/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 25/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 26/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 27/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 28/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 29/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 30/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 31/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 32/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 33/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 34/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 35/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Epoch 36/50:   0%|          | 0/353 [00:00<?, ?batch/s]

Early stopping at epoch 36 due to no improvement in validation loss.


Evaluating:   0%|          | 0/121 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.76%
Precision: 0.7041, Recall: 0.7581, F1 Score: 0.6739


2024/06/28 01:12:09 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/06/28 01:12:09 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


# CNN + LSTM Parallel

In [7]:
train_loader, val_loader, test_loader, train_distribution, val_distribution, test_distribution = get_loaders_no_sss(train_dir=TRAIN_DIR, test_dir=TEST_DIR, with_val_loader=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                        pin_memory=PIN_MEMORY, seq_length=SEQ_LENGTH, model_type="cnn")

Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Train dataset size: 90550
Test dataset size: 34600
Train dataset length: 90550, Val dataset length: 3460, Test dataset length: 31140


In [8]:
conv_filters = [64, 128, 256, 512]
lstm_hidden_size = 256
lstm_num_layers = 8
dropout = 0.5
batch_size = 256
fc_neurons1 = 1024
fc_neurons2 = 512
activation = nn.GELU()

In [18]:
model = ParallelCNNLSTMModel(
            input_size=SEQ_LENGTH,
            input_size_lstm=1,
            num_classes=4,
            conv_filters=conv_filters,
            lstm_hidden_size=lstm_hidden_size,
            lstm_num_layers=lstm_num_layers,
            fc_neurons=[fc_neurons1, fc_neurons2],
            dropout=dropout,
            activation=activation
        ).to(DEVICE)
print(f'Model size: {get_model_size(model):.3f} MB')
model

Model size: 81.693 MB


ParallelCNNLSTMModel(
  (cnn_head): CNN_Head(
    (conv_layers): ModuleList(
      (0): Conv1d(1, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (2): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (3): Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (bn_layers): ModuleList(
      (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (activation): GELU(approximate='none')
    (dropout): Dropout(p=0.5, inplace=False)
    (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(1, 256, num_layers=8, batch_fi

In [19]:
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion =  nn.CrossEntropyLoss() 

In [20]:
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
else:
    experiment_id = experiment.experiment_id
    if experiment.lifecycle_stage == 'deleted':
        mlflow.tracking.MlflowClient().restore_experiment(experiment_id)

In [21]:
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='mlflow-artifacts:/20', creation_time=1719540405626, experiment_id='20', last_update_time=1719540405626, lifecycle_stage='active', name='IEEG_MODELS_COMP_FINAL_2', tags={}>

In [22]:

# Train and Evaluate the Model with MLflow
run_name = "P_CNN_LSTM_no_sss_2000"
model_name = "P_CNN_LSTM_no_sss_2000"
results_dir = "../models/results"
with mlflow.start_run(run_name=run_name) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", model_name)
    mlflow.log_param("input_size", SEQ_LENGTH)
    mlflow.log_param("num_classes", NUM_CLASSES)
    # mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    # Train and Evaluate the Model
    train_model(model, train_loader,val_loader, optimizer, criterion, NUM_EPOCHS, DEVICE, 
                save_checkpoint_interval=10, checkpoint_dir=CHECKPOINTS_PATH, 
                model_name=model_name, early_stopping_patience=20, cnn=False)
    _,_ = evaluate_model(model, test_loader, class_mapping, DEVICE, 
                            results_dir=results_dir,
                            img_path='../plots', 
                            run_name=run_name,
                            cnn=False)

    # Log the model
    mlflow.pytorch.log_model(model, model_name)

2024/06/28 01:47:22 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Epoch 1/50:   0%|          | 0/353 [00:00<?, ?batch/s]

2024/06/28 01:47:45 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/06/28 01:47:45 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


KeyboardInterrupt: 

# LSTM 

In [8]:
train_loader, val_loader, test_loader, train_distribution, val_distribution, test_distribution = get_loaders_no_sss(train_dir=TRAIN_DIR, test_dir=TEST_DIR, with_val_loader=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                        pin_memory=PIN_MEMORY, seq_length=SEQ_LENGTH, model_type="seq")

Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Applying this mapping: {'CA': 0, 'CA1': 1, 'Thalamus': 2, 'vM1': 3}
Train dataset size: 21732
Test dataset size: 8304
Train dataset length: 21732, Val dataset length: 830, Test dataset length: 7474


In [9]:
for loader_name, loader in zip(['train_loader', 'test_loader','val_loader'], [train_loader, test_loader, val_loader]):
    for i, (inputs, labels) in enumerate(loader):
        print(f"{loader_name} - Batch {i}: inputs shape = {inputs.shape}, labels shape = {labels.shape}")
        if i == 0:  # Only print the first batch for brevity
            break

train_loader - Batch 0: inputs shape = torch.Size([64, 2000, 1]), labels shape = torch.Size([64])
test_loader - Batch 0: inputs shape = torch.Size([64, 2000, 1]), labels shape = torch.Size([64])
val_loader - Batch 0: inputs shape = torch.Size([64, 2000, 1]), labels shape = torch.Size([64])


In [10]:
model = LSTM_Ieeg_Model(device=DEVICE,
                        input_size=1, 
                        num_classes=NUM_CLASSES, 
                        lstm_layers=8, 
                        lstm_h_size=128,
                        fc_neurons=[1024,256], 
                        bidirectional=False).to(DEVICE)
print(f'Model size: {get_model_size(model):.3f} MB')
model

Model size: 5.292 MB


LSTM_Ieeg_Model(
  (activation): ReLU()
  (lstm): LSTM(1, 128, num_layers=8, batch_first=True, dropout=0.1)
  (fc_layers): ModuleList(
    (0): Linear(in_features=128, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=256, bias=True)
  )
  (output_layer): Linear(in_features=256, out_features=4, bias=True)
  (dropout_layer): Dropout(p=0.1, inplace=False)
)

In [11]:
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion =  nn.CrossEntropyLoss()  

In [12]:
mlflow.set_experiment(EXPERIMENT_NAME)

<Experiment: artifact_location='mlflow-artifacts:/20', creation_time=1719540405626, experiment_id='20', last_update_time=1719540405626, lifecycle_stage='active', name='IEEG_MODELS_COMP_FINAL_2', tags={}>

In [13]:
run_name = "run_LSTM_2000_no_sss"
model_name = "LSTM_2000_no_sss"
results_dir = "../models/results"
with mlflow.start_run(run_name=run_name) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", model_name)
    mlflow.log_param("input_size", SEQ_LENGTH)
    mlflow.log_param("num_classes", NUM_CLASSES)
    # mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    # Train and Evaluate the Model
    train_model(model, train_loader,val_loader, optimizer, criterion, NUM_EPOCHS, DEVICE, 
                save_checkpoint_interval=10, checkpoint_dir=CHECKPOINTS_PATH, 
                model_name=model_name, early_stopping_patience=20, cnn=False)
    _,_ = evaluate_model(model, test_loader, class_mapping, DEVICE, 
                            results_dir=results_dir,
                            img_path='../plots', 
                            run_name=run_name,
                            cnn=False)

    # Log the model
    mlflow.pytorch.log_model(model, model_name)

2024/06/27 21:35:33 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


Epoch 1/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 2/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 3/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 4/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 5/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 6/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 7/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 8/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 9/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 10/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 11/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 12/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 13/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 14/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 15/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 16/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 17/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 18/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 19/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 20/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Checkpoint saved successfully.


Epoch 21/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 22/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 23/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 24/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Epoch 25/50:   0%|          | 0/339 [00:00<?, ?batch/s]

Early stopping at epoch 25 due to no improvement in validation loss.


Evaluating:   0%|          | 0/116 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.44%
Precision: 0.1912, Recall: 0.4372, F1 Score: 0.2660


2024/06/27 21:58:43 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2024/06/27 21:58:43 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
