In [1]:
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm.notebook import tqdm
import os
from sklearn.preprocessing import LabelEncoder
import torch
import math 

In [2]:
mlflow.set_tracking_uri("http://mlflow:5000")

print('tracking uri:', mlflow.get_tracking_uri())

tracking uri: http://mlflow:5000


# More Complex ANN

In [66]:
# Configuration
DATA_DIR = '../data/data_normalized_exp2'
SEQ_LENGTH = 5000
BATCH_SIZE = 64
NUM_EPOCHS = 200
LEARNING_RATE = 0.0005
EXPERIMENT_NAME = "IEEG_Classification_Baseline"
RUN_NAME = "ANN-MoreComplex_NO_Response"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_SIZE = SEQ_LENGTH

In [9]:
class IeegDataset(Dataset):
    def __init__(self, data_dir, seq_length=5000):
        self.data_dir = data_dir
        self.signals = os.listdir(self.data_dir)
        self.seq_length = seq_length
        self.data = []
        self.labels = []
        
        self.classes = [f.split('_')[0] for f in self.signals]

        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(self.classes)

        for file in self.signals:
            file_path = os.path.join(data_dir, file)
            df = pd.read_csv(file_path)

            for column in df.columns: 
                for idx in range(0,math.floor(df.shape[0]/self.seq_length)):
                    # signal_1 = test[column].values[i*5000:i*5000+5000]

                    signal_window = df[column].values[idx*self.seq_length:idx*self.seq_length+self.seq_length]

    
                    class_label = self.label_encoder.transform([file.split('_')[0]])
                    self.data.append(signal_window)
                    self.labels.append(class_label)
        
        self.data = torch.tensor(self.data, dtype=torch.float32)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
        return(len(self.data))
    def __getitem__(self, index) -> torch.tensor:
        return self.data[index] , self.labels[index]
    def get_class_mapping(self):
        return {i: class_name for i, class_name in enumerate(self.label_encoder.classes_)}

In [26]:
# Define Model
class Deep_ANN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Deep_ANN, self).__init__()
        self.fc1 = nn.Linear(input_size, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 256)
        self.fc4 = nn.Linear(256, 64)
        self.fc5 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.relu(self.fc3(x))
        x = self.dropout(x)
        x = self.relu(self.fc4(x))
        x = self.dropout(x)
        x = self.fc5(x)
        return x

In [10]:
# Data Loaders and Partitioning
def create_data_loaders(dataset, batch_size):
    labels = np.array([dataset[i][1].item() for i in range(len(dataset))])
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_index, test_index = next(sss.split(np.zeros(len(labels)), labels))

    train_dataset = Subset(dataset, train_index)
    test_dataset = Subset(dataset, test_index)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [11]:
# Plotting Functions
def plot_pie_chart(counts, class_names, title):
    fig, ax = plt.subplots()
    ax.pie(counts, labels=class_names, autopct='%1.1f%%', startangle=90, counterclock=False)
    ax.axis('equal')
    plt.title(title)
    return fig

In [21]:
# Training Function
def train_model(model, train_loader, optimizer, criterion, num_epochs, device):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        y_true_train = []
        y_pred_train = []

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.squeeze())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            y_true_train.extend(labels.squeeze().cpu().numpy())
            y_pred_train.extend(predicted.cpu().numpy())

            avg_loss = running_loss / len(train_loader)
            train_accuracy = accuracy_score(y_true_train, y_pred_train)
            precision, recall, f1, _ = precision_recall_fscore_support(y_true_train, y_pred_train, average='weighted', zero_division=0)

            progress_bar.set_postfix(loss=avg_loss, accuracy=train_accuracy, precision=precision, recall=recall, f1=f1)

        mlflow.log_metric("train_loss", avg_loss, step=epoch)
        mlflow.log_metric("train_accuracy", train_accuracy, step=epoch)
        mlflow.log_metric("train_precision", precision, step=epoch)
        mlflow.log_metric("train_recall", recall, step=epoch)
        mlflow.log_metric("train_f1", f1, step=epoch)

# Evaluation Function
def evaluate_model(model, test_loader, dataset, device, img_path, run_name):
    model.eval()
    y_true_test = []
    y_pred_test = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            y_true_test.extend(labels.squeeze().cpu().numpy())
            y_pred_test.extend(predicted.cpu().numpy())

    test_accuracy = accuracy_score(y_true_test, y_pred_test)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true_test, y_pred_test, average='weighted', zero_division=0)

    print(f'Accuracy of the model on the test data: {test_accuracy:.2f}%')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

    mlflow.log_metric("test_accuracy", test_accuracy)
    mlflow.log_metric("test_precision", precision)
    mlflow.log_metric("test_recall", recall)
    mlflow.log_metric("test_f1", f1)

    # Confusion matrix
    cm = confusion_matrix(y_true_test, y_pred_test)
    cm_df = pd.DataFrame(cm, index=dataset.label_encoder.classes_, columns=dataset.label_encoder.classes_)

    plt.figure(figsize=(10, 7))
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    img_file = os.path.join(img_path, f"confusion_matrix_{run_name}.png")
    plt.savefig(img_file)
    mlflow.log_artifact(img_file)
    plt.close()

In [11]:
dataset = IeegDataset(DATA_DIR, SEQ_LENGTH)
train_loader, test_loader = create_data_loaders(dataset, BATCH_SIZE)
NUM_CLASSES = len(dataset.label_encoder.classes_)



  self.data = torch.tensor(self.data, dtype=torch.float32)


In [27]:
model = Deep_ANN(INPUT_SIZE, NUM_CLASSES).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
model

Deep_ANN(
  (fc1): Linear(in_features=5000, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=5, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
)

In [28]:
# Train and evaluate the model while tracking with MLflow
# Ensure the experiment is created or active
experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
else:
    experiment_id = experiment.experiment_id
    if experiment.lifecycle_stage == 'deleted':
        mlflow.tracking.MlflowClient().restore_experiment(experiment_id)

# Start MLflow experiment
mlflow.set_experiment(EXPERIMENT_NAME)
with mlflow.start_run(run_name=RUN_NAME) as run:
    # Log parameters
    mlflow.log_param("epochs", NUM_EPOCHS)
    mlflow.log_param("batch_size", BATCH_SIZE)
    mlflow.log_param("learning_rate", LEARNING_RATE)
    mlflow.log_param("model", "Deep_ANN")
    mlflow.log_param("input_size", INPUT_SIZE)
    mlflow.log_param("num_classes", NUM_CLASSES)
    mlflow.log_param("dropout", 0.1)

    mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

    print("Training in: {}".format(DEVICE))
    train_model(model, train_loader, optimizer, criterion, NUM_EPOCHS, device=DEVICE)
    evaluate_model(model, test_loader, dataset,  DEVICE, img_path='../plots', run_name=RUN_NAME)

    # Log the model
    mlflow.pytorch.log_model(model, "model")

Training in: cuda


Epoch 1/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 2/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 3/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 4/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 5/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 6/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 7/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 8/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 9/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 10/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 11/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 12/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 13/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 14/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 15/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 16/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 17/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 18/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 19/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 20/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 21/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 22/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 23/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 24/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 25/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 26/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 27/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 28/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 29/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 30/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 31/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 32/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 33/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 34/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 35/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 36/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 37/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 38/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 39/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 40/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 41/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 42/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 43/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 44/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 45/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 46/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 47/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 48/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 49/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 50/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 51/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 52/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 53/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 54/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 55/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 56/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 57/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 58/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 59/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 60/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 61/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 62/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 63/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 64/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 65/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 66/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 67/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 68/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 69/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 70/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 71/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 72/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 73/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 74/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 75/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 76/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 77/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 78/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 79/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 80/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 81/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 82/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 83/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 84/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 85/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 86/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 87/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 88/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 89/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 90/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 91/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 92/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 93/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 94/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 95/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 96/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 97/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 98/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 99/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 100/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 101/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 102/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 103/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 104/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 105/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 106/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 107/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 108/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 109/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 110/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 111/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 112/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 113/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 114/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 115/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 116/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 117/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 118/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 119/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 120/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 121/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 122/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 123/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 124/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 125/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 126/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 127/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 128/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 129/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 130/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 131/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 132/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 133/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 134/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 135/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 136/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 137/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 138/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 139/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 140/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 141/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 142/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 143/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 144/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 145/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 146/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 147/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 148/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 149/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 150/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 151/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 152/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 153/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 154/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 155/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 156/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 157/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 158/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 159/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 160/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 161/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 162/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 163/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 164/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 165/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 166/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 167/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 168/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 169/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 170/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 171/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 172/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 173/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 174/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 175/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 176/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 177/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 178/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 179/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 180/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 181/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 182/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 183/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 184/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 185/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 186/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 187/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 188/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 189/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 190/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 191/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 192/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 193/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 194/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 195/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 196/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 197/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 198/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 199/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Epoch 200/200:   0%|          | 0/82 [00:00<?, ?batch/s]

Evaluating:   0%|          | 0/21 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.69%
Precision: 0.6838, Recall: 0.6908, F1 Score: 0.6861




# CONV + BiLSTM 

In [77]:
torch.cuda.empty_cache()

In [51]:
# Configuration
DATA_DIR = '../data/data_normalized_exp2'
SEQ_LENGTH = 5000
BATCH_SIZE = 32
NUM_EPOCHS = 40
LEARNING_RATE = 0.00005
EXPERIMENT_NAME = "IEEG_Classification_CONV+BiLSTM"
RUN_NAME = "CONV+BiLSTM_First"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_SIZE = SEQ_LENGTH

In [12]:
dataset = IeegDataset(DATA_DIR, SEQ_LENGTH)
train_loader, test_loader = create_data_loaders(dataset, BATCH_SIZE)
NUM_CLASSES = len(dataset.label_encoder.classes_)


  self.data = torch.tensor(self.data, dtype=torch.float32)


In [38]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, input_channels, conv_neurons):
        super(ConvBlock, self).__init__()
        self.conv_layers = nn.ModuleList()
        in_channels = input_channels

        for out_channels in conv_neurons:
            self.conv_layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1))
            in_channels = out_channels
        
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        
    def forward(self, x):
        for conv in self.conv_layers:
            x = self.pool(self.relu(conv(x)))
        return x

class ConvBiLSTM(nn.Module):
    def __init__(self, input_channels, input_size, conv_neurons, hidden_size, num_layers, num_classes):
        super(ConvBiLSTM, self).__init__()
        self.conv_block = ConvBlock(input_channels, conv_neurons)
        
        # Calculate the output size after the convolutional block
        conv_output_size = input_size // (2**len(conv_neurons))  # Pooling reduces the size by 2 per layer
        lstm_input_size = conv_neurons[-1] * conv_output_size  # Number of channels from the last conv layer
        
        self.bilstm = nn.LSTM(input_size=lstm_input_size, hidden_size=hidden_size, num_layers=num_layers, 
                              bidirectional=False, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # 2 for bidirectional
        
    def forward(self, x):
        x = self.conv_block(x)
        
        # Flatten the output of the conv_block correctly
        x = x.permute(0, 2, 1)  # Change from (batch_size, channels, seq_length) to (batch_size, seq_length, channels)
        x = x.contiguous().view(x.size(0), -1, x.size(2) * x.size(1))  # Ensure the shape matches LSTM input requirements
        
        # Apply biLSTM
        x, _ = self.bilstm(x)
        
        # Only take the output of the last LSTM cell
        x = self.fc(x[:, -1, :])
        return x

# Example usage
input_channels = 1  # Number of input channels, e.g., 1 for single-channel iEEG signals
input_size = 5000  # Sequence length
conv_neurons = [64, 64]  # Number of neurons for each conv layer
hidden_size = 64  # Hidden size for LSTM
num_layers = 4  # Number of LSTM layers
num_classes = 5  # Number of output classes



# OTHER MODEL

In [46]:
import torch
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath



class TransformerStem(nn.Module):
    def __init__(self, in_channels=1, out_channels=16):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels, out_channels//2, kernel_size=10, stride=2, bias=False, padding=4)
        self.act1 = nn.GELU()
        self.bn1 =  nn.BatchNorm1d(out_channels//2)

        self.conv2 = nn.Conv1d(out_channels//2, out_channels, kernel_size=3, stride=1, bias=False,padding=1)
        self.act2 = nn.GELU()
        self.bn2 =  nn.BatchNorm1d(out_channels)

        self.conv3 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=2, bias=False,padding=1)
        self.act3 = nn.GELU()
        self.bn3 =  nn.BatchNorm1d(out_channels)

        self.init_weight()

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.bn1(self.act1(self.conv1(x)))
        x = self.bn2(self.act2(self.conv2(x)))
        x = self.bn3(self.act3(self.conv3(x)))

        return x

def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
                nn.Linear(oup, _make_divisible(inp // reduction, 8)),
                nn.SiLU(),
                nn.Linear(_make_divisible(inp // reduction, 8), oup),
                nn.Sigmoid()
        )
 
    def forward(self, x):
        b, c, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1)
        return x * y


class MBConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride, expand_ratio, fused):
        super(MBConv, self).__init__()
        assert stride in [1, 2]
        hidden_dim = round(in_channels * expand_ratio)
        self.identity = stride == 1 and in_channels == out_channels
        if fused:
            self.conv = nn.Sequential(
                # Fused-MBConv
                nn.Conv1d(in_channels, hidden_dim, 3, stride, 1, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.SiLU(),
                SELayer(in_channels, hidden_dim),
                # pw-linear
                nn.Conv1d(hidden_dim, out_channels, 1, 1, 0, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        else:
             self.conv = nn.Sequential(
                nn.Conv1d(in_channels, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.SiLU(),
                nn.Conv1d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.SiLU(),
                SELayer(in_channels, hidden_dim),
                nn.Conv1d(hidden_dim, out_channels, 1, 1, 0, bias=False),
                nn.BatchNorm1d(out_channels),
            )
 
    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)


class ConvEmbedding(nn.Module):
    def __init__(self,in_channels,out_channels,depths,stage=1):
        super(ConvEmbedding, self).__init__()
        self.stage = stage
        
        self.mbconvs = nn.ModuleList()
        if self.stage == 1:
            self.mbconvs.append(MBConv(in_channels,out_channels,stride=2, expand_ratio=2, fused=True))
            self.mbconvs.append(MBConv(out_channels,out_channels,stride=1, expand_ratio=2, fused=True))
        else:
            self.mbconvs.append(MBConv(in_channels,out_channels,stride=2, expand_ratio=2, fused=False))
            for _ in range(depths-1):
                self.mbconvs.append(MBConv(out_channels,out_channels,stride=1, expand_ratio=2, fused=False))

        self.norm = nn.LayerNorm(out_channels)
        self.proj = nn.Conv1d(out_channels, out_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        for _, mbconvs in enumerate(self.mbconvs):
            x = mbconvs(x)
        x = self.norm(self.proj(x).transpose(1,2)).transpose(1,2)
    
        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def conv_separable(in_channels, out_channels, stride):
        return nn.Sequential(
            nn.Conv1d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm1d(in_channels),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm1d(out_channels)
        )


class DWCONV(nn.Module):

    def __init__(self, in_channels, out_channels, stride = 1):
        super(DWCONV, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size = 3,
            stride = stride, padding = 1, groups = in_channels, bias = False)
        self.gelu1 = nn.ReLU() 
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.init_weight()

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.gelu1(x)
        result = self.bn1(x)
        return result

class LIL(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(LIL, self).__init__()
        self.DWConv = DWCONV(in_channels, out_channels)

    def forward(self, x):
        result = self.DWConv(x) + x
        return result

class RFFN(nn.Module):

    def __init__(self, in_channels, R):
        super(RFFN, self).__init__()
        exp_channels = int(in_channels * R)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, exp_channels, kernel_size = 1),
            nn.BatchNorm1d(exp_channels),
            nn.ReLU()
        )

        self.dwconv = nn.Sequential(
            DWCONV(exp_channels, exp_channels),
            nn.BatchNorm1d(exp_channels),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(exp_channels, in_channels, 1),
            nn.BatchNorm1d(in_channels)
        )

    def forward(self, x):
        result = x + self.conv2(self.dwconv(self.conv1(x)))
        return result


class DANE(nn.Module):
    def __init__(self, channel, reduction=16):
        super(DANE, self).__init__()
        self.channel = channel
        self.fc_spatial = nn.Sequential(
            nn.LayerNorm(channel),
            nn.Linear(channel, 1, bias=False),
        )
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc_channel = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.SiLU(inplace=True),
            nn.LayerNorm(channel//reduction),
            nn.Linear(channel // reduction, channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_channel,x_spatial):
        #B L C
        x_spatial_mask = self.fc_spatial(x_spatial.transpose(1,2)).transpose(1,2) # B 1 L
        x_channel_mask = self.fc_channel(self.avg_pool(x_channel).transpose(1,2)).transpose(1,2) # B C 1
        x_mask = self.sigmoid(x_spatial_mask.expand_as(x_spatial) + x_channel_mask.expand_as(x_spatial))
        return x_spatial * x_mask + x_channel * (1 - x_mask)



class ConvBlock(nn.Module):

    def __init__(self, inplanes, stride, groups=1,norm_layer=nn.BatchNorm1d):
        super(ConvBlock, self).__init__()
        self.inplanes=inplanes
        self.stride = stride
        self.conv1x1_1 = nn.Sequential(nn.Conv1d(inplanes,inplanes,kernel_size=1,stride=1,padding=0,groups=groups,bias=False),
                                        norm_layer(inplanes),
                                        nn.SiLU(inplace=True),
                                        nn.Conv1d(inplanes,inplanes,kernel_size=3,stride=1,padding=1,groups=inplanes,bias=False),
                                        norm_layer(inplanes),
                                        nn.SiLU(inplace=True),
                                        nn.Conv1d(inplanes,inplanes,kernel_size=3,stride=1,padding=1,groups=inplanes,bias=False),
                                        norm_layer(inplanes),
                                        nn.SiLU(inplace=True)
                                        )

        self.conv1 = nn.Sequential(nn.Conv1d(inplanes,inplanes,kernel_size=3,stride=stride,padding=1,groups=inplanes,bias=False),
                                    norm_layer(inplanes),
                                    nn.SiLU(inplace=True)
                                    )
        self.conv1x1_2 = nn.Sequential(nn.Conv1d(inplanes,inplanes,kernel_size=1,stride=1,padding=0,groups=groups,bias=False),
                                        norm_layer(inplanes),
                                        nn.SiLU(inplace=True)
                                        )

    def forward(self, x):
        out = self.conv1x1_1(x)
        x_out = out
        out = self.conv1(out)
        out = self.conv1x1_2(out)
        return x_out,out


class SMHSA(nn.Module):

    def __init__(self, channels, d_k, d_v, stride, heads, dropout,qkv_bias=False,attn_drop=0., proj_drop=0.):
        super(SMHSA, self).__init__()
        self.dwconv_k = DWCONV(channels, channels, stride = stride)
        self.dwconv_v = DWCONV(channels, channels, stride = stride)
        self.fc_q = nn.Linear(channels, heads * d_k, bias=qkv_bias)
        self.fc_k = nn.Linear(channels, heads * d_k, bias=qkv_bias)
        self.fc_v = nn.Linear(channels, heads * d_v, bias=qkv_bias)
        self.fc_o = nn.Linear(heads * d_k, channels)

        self.channels = channels
        self.d_k = d_k
        self.d_v = d_v
        self.stride = stride
        self.heads = heads
        self.dropout = dropout
        self.scaled_factor = self.d_k ** -0.5
        self.num_patches = (self.d_k // self.stride) ** 2

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        b, c, l = x.shape

        x_reshape = x.permute(0, 2, 1) 

        # Get q, k, v
        q = self.fc_q(x_reshape)
        q = q.view(b, l, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()  

        k = self.dwconv_k(x)
        k_b, k_c, k_l = k.shape
        k = k.view(k_b, k_c, k_l).permute(0, 2, 1).contiguous()
        k = self.fc_k(k)
        k = k.view(k_b, k_l, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()  

        v = self.dwconv_v(x)
        v_b, v_c, v_l = v.shape
        v = v.view(v_b, v_c, v_l).permute(0, 2, 1).contiguous()
        v = self.fc_v(v)
        v = v.view(v_b, v_l, self.heads, self.d_v).permute(0, 2, 1, 3).contiguous() 

        attn = torch.einsum('... i d, ... j d -> ... i j', q, k) * self.scaled_factor
        attn = torch.softmax(attn, dim = -1) 

        attn = self.attn_drop(attn)

        result = torch.matmul(attn, v).permute(0, 2, 1, 3)
        result = result.contiguous().view(b, l, self.heads * self.d_v)
        result = self.fc_o(result).view(b, self.channels, -1)
        result = self.proj_drop(result)
        return result



class MyFormerBlock(nn.Module):
    def __init__(self,dim,d_k,num_heads, stride,mlp_ratio=4., 
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,R=1):
        super(MyFormerBlock,self).__init__()  

        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        
        self.norm1 = norm_layer(dim)
        self.attn = SMHSA(dim, d_k, d_k, stride, num_heads, 0.0, qkv_bias=qkv_bias,attn_drop=attn_drop, proj_drop=drop)

        self.ffn = RFFN(dim, R)

        self.select = DANE(channel = dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

    def forward(self,x,x_downsample):
        x = self.norm1(x.transpose(1,2)).transpose(1,2)
        x = self.attn(x)
        x = self.ffn(x)
        x = self.select(x, x_downsample)
        x = x + self.drop_path(self.mlp(self.norm2(x.transpose(1,2)))).transpose(1,2)
        return x


class BasicLayer(nn.Module):
    def __init__(self, dim, d_k, depth, num_heads,stride,mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0.,norm_layer=nn.LayerNorm,pe=True):
        super(BasicLayer,self).__init__()    
        self.dim = dim
        self.depth = depth
        self.pe = pe

        self.convlayer = ConvBlock(inplanes=dim,stride = 2)
        self.maxpool = nn.MaxPool1d(kernel_size = 3, stride = 2, padding = 1)
        self.multiresolution_con = conv_separable(2 * dim, dim, 1)

        if self.pe:
            self.lil = LIL(self.dim, self.dim)

        # build transformer encoders
        self.blocks = nn.ModuleList([
                    MyFormerBlock(dim=dim, d_k=d_k,
                                 num_heads=num_heads,
                                 stride=stride, 
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, 
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
                    for i in range(depth)])


    def forward(self,x):
        x_spatial, x_inter = self.convlayer(x)
        x_pool = self.maxpool(x)

        L = x_spatial.shape[2]
        pad_input = (L % 2 == 1)
        if pad_input:
            x_spatial = F.pad(x_spatial, (0, L % 2))

        x0 = x_spatial[:, :, 0::2] 
        x1 = x_spatial[:, :, 1::2]  
        x_spatial = torch.cat([x0, x1], 1)  
        x_spatial = self.multiresolution_con(x_spatial) 
        x = x_pool + x_inter

        if self.pe:
            x = self.lil(x)

        for blk in self.blocks:
            x = blk(x, x_spatial)

        return x



class IEEGHCT(nn.Module):
    def __init__(self,in_channels=1,
                num_classes=3,  # The number of classes for recognition.
                ce_depths=4, # MBConvs numbers of Convolutional Embeddind 
                embed_dim=8,
                d_k=32,
                num_heads=[1, 2, 4, 8], # The number of heads in different stages.
                strides = [4, 4, 2, 2], # SMHSA
                depths=[1, 2, 4, 8],  # The number of blocks in each stage.
                mlp_ratio=4, # The MLP expansion rate.
                qkv_bias=False,  # Whether adding bias to qkv.
                drop_rate=0.,  # Dropout rate.
                attn_drop_rate=0.,  # Dropout rate on attention values.
                norm_layer=nn.LayerNorm,  # The norm layer.
                pe=True, # Positional Embedding, LIL
                ):
        super(IEEGHCT,self).__init__()

        self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.pe = pe
        self.mlp_ratio = mlp_ratio

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stem layer
        self.stem = TransformerStem(in_channels, embed_dim)

        # CE Blocks
        self.conv_embeds = nn.ModuleList()
        for i in range(len(depths)):
            self.conv_embeds.append(ConvEmbedding(embed_dim * 2 ** i,embed_dim * 2 ** (i+1), depths=ce_depths,stage=i+1))

        # Transformer Blocks
        self.layers = nn.ModuleList()
        for k in range(self.num_layers):
            layer = BasicLayer(dim=embed_dim * 2 ** (k+1),d_k=d_k,
                                depth=depths[k],
                                num_heads=num_heads[k],
                                stride=strides[k],
                                mlp_ratio=mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                norm_layer=norm_layer
                                )
            self.layers.append(layer)

        self.norm = norm_layer(embed_dim * 2 ** len(depths))
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(embed_dim * 2 ** len(depths), num_classes) if num_classes > 0 else nn.Identity()


    def forward_features(self, x):
        x = self.stem(x)
        x = self.pos_drop(x)

        for i in range(self.num_layers):
            x = self.conv_embeds[i](x)
            x = self.layers[i](x)

        x = self.norm(x.transpose(1,2)).transpose(1,2)
        return x  


    def forward(self, x):
        x = self.forward_features(x)  
        x = self.avgpool(x)  
        x = torch.flatten(x, 1)
        x = self.head(x)

        return x



In [47]:
model = IEEGHCT(in_channels=1, num_classes=NUM_CLASSES, depths=[1, 2, 4, 2]).to(DEVICE)
model

IEEGHCT(
  (pos_drop): Dropout(p=0.0, inplace=False)
  (stem): TransformerStem(
    (conv1): Conv1d(1, 4, kernel_size=(10,), stride=(2,), padding=(4,), bias=False)
    (act1): GELU(approximate='none')
    (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(4, 8, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (act2): GELU(approximate='none')
    (bn2): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv1d(8, 8, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
    (act3): GELU(approximate='none')
    (bn3): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_embeds): ModuleList(
    (0): ConvEmbedding(
      (mbconvs): ModuleList(
        (0): MBConv(
          (conv): Sequential(
            (0): Conv1d(8, 16, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
            (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine

In [52]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [53]:

# Training Function
def trainCL_model(model, train_loader, optimizer, criterion, num_epochs, device):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        y_true_train = []
        y_pred_train = []

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Ensure inputs have the correct shape
            if inputs.dim() == 2:  # [batch_size, seq_length]
                inputs = inputs.unsqueeze(1)  # [batch_size, 1, seq_length]

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels.squeeze())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            y_true_train.extend(labels.squeeze().cpu().numpy())
            y_pred_train.extend(predicted.cpu().numpy())

            avg_loss = running_loss / len(train_loader)
            train_accuracy = accuracy_score(y_true_train, y_pred_train)
            precision, recall, f1, _ = precision_recall_fscore_support(y_true_train, y_pred_train, average='weighted', zero_division=0)

            progress_bar.set_postfix(loss=avg_loss, accuracy=train_accuracy, precision=precision, recall=recall, f1=f1)

        mlflow.log_metric("train_loss", avg_loss, step=epoch)
        mlflow.log_metric("train_accuracy", train_accuracy, step=epoch)
        mlflow.log_metric("train_precision", precision, step=epoch)
        mlflow.log_metric("train_recall", recall, step=epoch)
        mlflow.log_metric("train_f1", f1, step=epoch)

# Evaluation Function
def evaluateCL_model(model, test_loader, dataset, device, img_path, run_name):
    model.eval()
    y_true_test = []
    y_pred_test = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Ensure inputs have the correct shape
            if inputs.dim() == 2:  # [batch_size, seq_length]
                inputs = inputs.unsqueeze(1)  # [batch_size, 1, seq_length]

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            y_true_test.extend(labels.squeeze().cpu().numpy())
            y_pred_test.extend(predicted.cpu().numpy())

    test_accuracy = accuracy_score(y_true_test, y_pred_test)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true_test, y_pred_test, average='weighted', zero_division=0)

    print(f'Accuracy of the model on the test data: {test_accuracy:.2f}%')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

    mlflow.log_metric("test_accuracy", test_accuracy)
    mlflow.log_metric("test_precision", precision)
    mlflow.log_metric("test_recall", recall)
    mlflow.log_metric("test_f1", f1)

    # Confusion matrix
    cm = confusion_matrix(y_true_test, y_pred_test)
    cm_df = pd.DataFrame(cm, index=dataset.label_encoder.classes_, columns=dataset.label_encoder.classes_)

    plt.figure(figsize=(10, 7))
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    img_file = os.path.join(img_path, f"confusion_matrix_{run_name}.png")
    plt.savefig(img_file)
    mlflow.log_artifact(img_file)
    plt.close()

In [55]:

    # Ensure the experiment is created or active
    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if experiment is None:
        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
    else:
        experiment_id = experiment.experiment_id
        if experiment.lifecycle_stage == 'deleted':
            mlflow.tracking.MlflowClient().restore_experiment(experiment_id)

    # Start MLflow experiment
    mlflow.set_experiment(EXPERIMENT_NAME)

    with mlflow.start_run(run_name=RUN_NAME) as run:
        # Log parameters
        mlflow.log_param("epochs", NUM_EPOCHS)
        mlflow.log_param("batch_size", BATCH_SIZE)
        mlflow.log_param("learning_rate", LEARNING_RATE)
        mlflow.log_param("model", "ConvBiLSTM")
        mlflow.log_param("input_size", INPUT_SIZE)
        mlflow.log_param("num_classes", NUM_CLASSES)
        mlflow.log_dict(dataset.get_class_mapping(), "class_mapping.json")

        # Train and Evaluate the Model
        #trainCL_model(model, train_loader, optimizer, criterion, NUM_EPOCHS, DEVICE)
        evaluateCL_model(model, test_loader, dataset, DEVICE, '../plots', RUN_NAME)

        # Log the model
        mlflow.pytorch.log_model(model, "model")


Evaluating:   0%|          | 0/82 [00:00<?, ?batch/s]

Accuracy of the model on the test data: 0.98%
Precision: 0.9809, Recall: 0.9802, F1 Score: 0.9803
