In [None]:
from google.colab import drive
import os
# Mount Google Drive into the Colab filesystem so the project folder becomes reachable.
drive.mount('/content/drive')
# Root folder of the project on Drive; a later cell builds the results path from it.
project_dir = '/content/drive/My Drive/ssm_ehr'
# Sanity check: prints True when the project folder is visible after mounting.
print(os.path.exists(project_dir))

Mounted at /content/drive
True


In [None]:
# Show the kernel's current working directory.
print(os.getcwd())

/content


In [None]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
!pip uninstall mamba-ssm causal-conv1d
!pip install causal-conv1d && pip install mamba-ssm

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.4.0
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl (799.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m799.1/799.1 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.0
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp310-cp310-linux_x86_64.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m111.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.4.0
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.0%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m102.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)
  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_c

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba  # Assuming Mamba is installed
import math
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  def backward(ctx, grad_output):
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  def backward(ctx, dout, *args):


In [None]:
from sklearn.metrics import average_precision_score

In [None]:
class MoEMambaAttentionClassifier(nn.Module):
    """Mixture-of-experts Mamba encoder with attention pooling and a static-feature branch.

    The time series is masked by its observation indicators, projected to
    ``hidden_dim`` and fed to ``num_experts`` Mamba experts. A gating network
    mixes the expert outputs, multi-head self-attention refines the sequence,
    and a learned attention pooling collapses the time axis. The pooled vector
    is concatenated with an encoding of the static features and classified.

    Args:
        ts_feature_dim: Number of time-series variables per time step.
        static_feature_dim: Number of static features per sample.
        hidden_dim: Width of the encoder; must be divisible by the 4 attention heads.
        num_classes: Number of output classes (size of the logit vector).
        num_experts: Number of Mamba experts mixed by the gate.
        max_time_steps: Currently unused; kept so existing call sites keep working.

    The forward pass returns raw, unnormalised logits of shape
    ``(batch_size, num_classes)``. This is what ``nn.CrossEntropyLoss`` and the
    ``softmax(dim=1)`` / ``argmax(dim=1)`` evaluation code in this notebook expect.
    """

    def __init__(self, ts_feature_dim, static_feature_dim, hidden_dim, num_classes, num_experts=4, max_time_steps=1000):
        super().__init__()

        self.num_experts = num_experts
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        # Gating network: already ends in a softmax, so its output is a
        # probability distribution over experts and must not be normalised again.
        self.gating_network = nn.Sequential(
            nn.Linear(ts_feature_dim, num_experts),
            nn.Softmax(dim=-1)
        )

        # Input projection to match hidden_dim
        self.input_projection = nn.Linear(ts_feature_dim, hidden_dim)

        # Experts: one Mamba block each, followed by layer norm and dropout.
        self.experts = nn.ModuleList([
            nn.Sequential(
                Mamba(d_model=hidden_dim, d_state=hidden_dim, d_conv=4, expand=2),
                nn.LayerNorm(hidden_dim),
                nn.Dropout(0.3)
            )
            for _ in range(num_experts)
        ])

        # NOTE: `projection` is not used in forward(); it is kept so the set of
        # registered parameters stays the same as before.
        self.projection = nn.Linear(hidden_dim, hidden_dim)
        self.multihead_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, dropout=0.2, batch_first=True)
        self.attention_layer = nn.Linear(hidden_dim, 1)
        self.static_fc = nn.utils.weight_norm(nn.Linear(static_feature_dim, hidden_dim))
        self.static_norm = nn.LayerNorm(hidden_dim)

        # Classification head. It emits one logit per class and has no final
        # activation: the loss (CrossEntropyLoss) and the evaluation code apply
        # softmax themselves, so a Sigmoid here would be applied on top of it.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, ts_values, ts_indicators, ts_time, static):
        """Compute class logits for a batch.

        Args:
            ts_values: Time-series values, (batch_size, seq_len, ts_feature_dim).
            ts_indicators: Mask with the same shape as ``ts_values``; it is
                multiplied into the values, so zeros blank out entries.
            ts_time: Time stamps of the steps. Accepted for interface
                compatibility but not used by the current architecture.
            static: Static features, (batch_size, static_feature_dim).

        Returns:
            Tensor of raw logits with shape (batch_size, num_classes).
        """
        # Mask missing time-series values
        ts_values = ts_values * ts_indicators

        # Gate on the time-averaged raw inputs. The gating network's own
        # softmax is the only normalisation; a second softmax would flatten
        # the distribution towards uniform.
        # NOTE(review): padded steps (zeros from the collate function) are
        # included in this mean and in the attention below - confirm intended.
        gating_weights = self.gating_network(ts_values.mean(dim=1))  # (batch_size, num_experts)

        # Project input to match hidden_dim for experts
        ts_values_projected = self.input_projection(ts_values)  # (batch_size, seq_len, hidden_dim)

        # Run every expert on the same projected sequence and stack the results.
        expert_outputs = torch.stack(
            [expert(ts_values_projected) for expert in self.experts], dim=1
        )  # (batch_size, num_experts, seq_len, hidden_dim)

        # Gate-weighted sum over the expert axis.
        ts_encoded = torch.einsum('be,besh->bsh', gating_weights, expert_outputs)

        # Multi-head self-attention with a residual connection.
        attn_output, _ = self.multihead_attention(ts_encoded, ts_encoded, ts_encoded)
        ts_encoded = ts_encoded + attn_output

        # Learned attention pooling over the time axis.
        attn_scores = self.attention_layer(ts_encoded).squeeze(-1)
        attn_weights = F.softmax(attn_scores, dim=1)
        ts_attended = torch.sum(ts_encoded * attn_weights.unsqueeze(-1), dim=1)

        # Static features
        static_encoded = F.relu(self.static_fc(static))
        static_encoded = self.static_norm(static_encoded)

        # Concatenate both branches and classify.
        combined = torch.cat([ts_attended, static_encoded], dim=1)
        return self.classifier(combined)


In [None]:
# drive.mount('/content/drive')
# project_dir = '/content/drive/My Drive/ssm_ehr'
# Load the train / test / validation arrays of split 1 from Drive.
# allow_pickle=True is required for object arrays but executes pickled code on
# load, so only use it on files you trust.
train_data = np.load('/content/drive/MyDrive/ssm_ehr/datasets/split_1/train_physionet2012_1.npy', allow_pickle=True)
test_data = np.load('/content/drive/MyDrive/ssm_ehr/datasets/split_1/test_physionet2012_1.npy', allow_pickle=True)
val_data = np.load('/content/drive/MyDrive/ssm_ehr/datasets/split_1/validation_physionet2012_1.npy', allow_pickle=True)

In [None]:

def custom_collate_fn(batch):
    """
    Merge variable-length samples into one zero-padded batch.

    Args:
        batch (list of tuples): Samples laid out as
            (ts_values, ts_indicators, ts_time, static, labels).

    Returns:
        tuple: (ts_values, ts_indicators, ts_times, static, labels). The three
        time-series entries are padded with zeros along their first axis up to
        the longest sample and stacked batch-first; static features are
        stacked; labels become a 1-D float32 tensor.
    """
    def _float_copies(field):
        # Detached float32 copy of one field from every sample, in batch order.
        return [item[field].clone().detach().float() for item in batch]

    values_seqs = _float_copies(0)
    indicator_seqs = _float_copies(1)
    time_seqs = _float_copies(2)
    static_batch = torch.stack(_float_copies(3))
    label_batch = torch.tensor([item[4] for item in batch], dtype=torch.float32)

    # Sequences differ in length, so pad each group to a common length.
    padded_values, padded_indicators, padded_times = (
        pad_sequence(seqs, batch_first=True)
        for seqs in (values_seqs, indicator_seqs, time_seqs)
    )

    return padded_values, padded_indicators, padded_times, static_batch, label_batch

In [None]:

class ICUTimeSeriesDataset(Dataset):
    """Map-style dataset over a sequence of records indexable by field name.

    Each item is a 5-tuple of float32 tensors built from the record's
    'ts_values', 'ts_indicators', 'ts_times', 'static' and 'labels' entries,
    in that order.
    """

    def __init__(self, data):
        # Keep a reference to the records; nothing is converted up front.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        # Order matters: the collate function reads the fields by position.
        field_order = ('ts_values', 'ts_indicators', 'ts_times', 'static', 'labels')
        return tuple(
            torch.tensor(record[key], dtype=torch.float32) for key in field_order
        )


# Wrap the arrays loaded above as datasets, one per partition.
train_dataset = ICUTimeSeriesDataset(train_data)
val_dataset = ICUTimeSeriesDataset(val_data)
test_dataset = ICUTimeSeriesDataset(test_data)

# Dataloaders: only the training loader shuffles; all three pad their batches
# through custom_collate_fn.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# The train/test functions below read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# Re-select the device so this cell can be run without the previous one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define model: 37 time-series variables, 8 static features, 2 target classes.

model = MoEMambaAttentionClassifier(
    ts_feature_dim=37,
    static_feature_dim=8,
    hidden_dim=16,
    num_classes=2
)

def train(model, train_loader, val_loader, num_epochs = 100):
    """Train `model` with cross-entropy and report validation metrics every epoch.

    The model is expected to return raw class logits of shape
    (batch_size, num_classes). The loss is computed on those logits directly,
    never on thresholded predictions, which would cut the gradient.
    AUC / AUPR / accuracy use the softmax probability of class 1, so they
    assume a two-class problem.

    Relies on the module-level `device`.

    Args:
        model: Network called as model(ts_values, ts_indicators, ts_time, static).
        train_loader: Loader yielding (ts_values, ts_indicators, ts_time, static, labels).
        val_loader: Loader with the same batch layout, used for validation.
        num_epochs: Number of passes over `train_loader`.

    Returns:
        tuple: (model, train_losses, val_losses). `train_losses` holds one float
        per epoch (mean batch loss). `val_losses` holds one detached 0-d CPU
        tensor per epoch; tensors are kept because the plotting helper in this
        notebook calls `.cpu()` on each entry.
    """
    model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        loss_train = 0
        for ts_values, ts_indicators, ts_time, static, labels in train_loader:
            ts_values, ts_indicators, ts_time, static, labels = (
                ts_values.to(device),
                ts_indicators.to(device),
                ts_time.to(device),
                static.to(device),
                labels.to(device).long(),  # CrossEntropyLoss needs integer class indices
            )

            # Forward pass: feed the logits themselves to the loss.
            logits = model(ts_values, ts_indicators, ts_time, static)
            loss = criterion(logits, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
            optimizer.step()

            loss_train += loss.item()

        train_losses.append(loss_train / len(train_loader))

        # ---- validation pass (single loop; logits are collected once) ----
        model.eval()
        label_chunks = []
        logit_chunks = []
        with torch.no_grad():
            for ts_values, ts_indicators, ts_time, static, labels in val_loader:
                ts_values, ts_indicators, ts_time, static, labels = (
                    ts_values.to(device),
                    ts_indicators.to(device),
                    ts_time.to(device),
                    static.to(device),
                    labels.to(device).long(),
                )
                label_chunks.append(labels)
                logit_chunks.append(model(ts_values, ts_indicators, ts_time, static))

            labels_list = torch.cat(label_chunks, dim=0)
            logits_list = torch.cat(logit_chunks, dim=0)

            # Loss on float logits; probabilities only for the metrics.
            val_loss = criterion(logits_list, labels_list)
            probs = torch.softmax(logits_list, dim=1)

        y_true = labels_list.cpu().numpy()
        positive_probs = probs[:, 1].cpu().numpy()
        auc_score = roc_auc_score(y_true, positive_probs)
        aupr_score = average_precision_score(y_true, positive_probs)
        accuracy = accuracy_score(y_true, positive_probs >= 0.5)

        val_losses.append(val_loss.detach().cpu())
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {loss_train/len(train_loader):.4f}, Val Loss: {val_loss:.4f}, AUC: {auc_score:.4f}, AUPR: {aupr_score:.4f}, Accuracy: {accuracy:.4f}")

    return model, train_losses, val_losses

def test(model, test_loader):
    """Evaluate `model` on `test_loader` and print loss, AUC, AUPR and accuracy.

    Uses the module-level `device`. The model output is passed to
    CrossEntropyLoss and to softmax over dim 1; column 1 is taken as the
    positive-class score.

    Returns:
        tuple: (test_losses, auc_score, aupr_score, accuracy), where
        `test_losses` is a one-element list holding the mean batch loss.
    """
    model.eval().to(device)  # evaluation mode, on the target device

    criterion = nn.CrossEntropyLoss()
    test_losses = []
    collected_labels = torch.LongTensor([]).to(device)
    collected_outputs = torch.FloatTensor([]).to(device)

    with torch.no_grad():
        running_loss = 0
        for batch in test_loader:
            # Move the five batch tensors to the device; labels become class indices.
            ts_values, ts_indicators, ts_time, static, labels = (
                part.to(device) for part in batch
            )
            labels = labels.long()

            batch_outputs = model(ts_values, ts_indicators, ts_time, static)
            running_loss += criterion(batch_outputs, labels).item()

            # Keep everything for the dataset-level metrics below.
            collected_labels = torch.cat((collected_labels, labels), dim=0)
            collected_outputs = torch.cat((collected_outputs, batch_outputs), dim=0)

        # Mean of the per-batch losses.
        test_losses.append(running_loss / len(test_loader))

        positive_probs = torch.nn.functional.softmax(collected_outputs, dim=1)[:, 1]
        y_true = collected_labels.cpu().numpy()
        y_score = positive_probs.cpu().numpy()
        auc_score = roc_auc_score(y_true, y_score)
        aupr_score = average_precision_score(y_true, y_score)
        predicted_labels = (positive_probs >= 0.5).cpu().numpy().astype(int)
        accuracy = accuracy_score(y_true, predicted_labels)

    # Print test results
    print(f"Test Loss: {test_losses[-1]:.4f}, AUC: {auc_score:.4f}, AUPR: {aupr_score:.4f}, Accuracy: {accuracy:.4f}")

    return test_losses, auc_score, aupr_score, accuracy

# Train for 50 epochs on the current loaders, then evaluate on the test loader.
model, train_losses, val_losses = train(model, train_loader, val_loader, 50)
print()
test_losses, auc_score, aupr_score, accuracy = test(model, test_loader)

RuntimeError: "host_softmax" not implemented for 'Long'

In [None]:

import matplotlib.pyplot as plt

def plot_losses(train_losses, validation_losses):
    """Plot training and validation loss curves on one figure.

    Args:
        train_losses: One loss value per epoch.
        validation_losses: One loss value per epoch.

    Both sequences may mix plain numbers and single-element tensors (on any
    device, with or without gradient tracking); everything is converted to
    Python floats before plotting.
    """
    def _as_float(value):
        # Calling .cpu() unconditionally fails on plain floats, so branch on type.
        if torch.is_tensor(value):
            return value.detach().cpu().item()
        return float(value)

    train_losses = [_as_float(v) for v in train_losses]
    validation_losses = [_as_float(v) for v in validation_losses]

    # Set the figure size and style
    plt.figure(figsize=(10, 6))

    # Plot training loss with markers
    plt.plot(train_losses, label='Training Loss', color='tab:blue', marker='o', markersize=6, linestyle='-', linewidth=2)

    # Plot validation loss with markers and different style
    plt.plot(validation_losses, label='Validation Loss', color='tab:orange', marker='s', markersize=6, linestyle='--', linewidth=2)

    # Add labels, title, and legend with improved styles
    plt.xlabel('Epochs', fontsize=14, fontweight='bold')
    plt.ylabel('Loss', fontsize=14, fontweight='bold')
    plt.title('Training and Validation Loss Over Epochs', fontsize=16, fontweight='bold')

    # Display the legend with adjusted positioning
    plt.legend(loc='upper right', fontsize=12)

    # Adjust x and y ticks for better readability
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Show the plot
    plt.tight_layout()
    plt.show()

# Draw the loss curves returned by the most recent train() call.
plot_losses(train_losses, val_losses)

In [None]:
# Evaluation loop: plain accuracy of the current `model` on the validation loader.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for ts_values, ts_indicators, ts_time, static, labels in val_loader:
        ts_values,ts_indicators,ts_time, static, labels = ts_values.to(device), ts_indicators.to(device),ts_time.to(device), static.to(device), labels.to(device)
        outputs = model(ts_values, ts_indicators,ts_time, static)
        # torch.max over dim 1 returns (values, indices); the indices are used
        # as predicted class ids. Assumes `outputs` is (batch, num_classes) - TODO confirm.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.long()).sum().item()

    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

In [None]:
def evaluate_model(model, data_loader, device):
    """
    Run the model over a data loader and summarise its classification quality.

    Args:
        model (torch.nn.Module): Trained model.
        data_loader (torch.utils.data.DataLoader): Loader yielding
            (ts_values, ts_indicators, ts_time, static, labels) batches.
        device (torch.device): Device the batches are moved to.

    Returns:
        dict: "Accuracy", "Precision", "Recall", "F1-Score" and "ROC-AUC".
    """
    model.eval()  # evaluation mode
    truths, hard_preds, positive_probs = [], [], []

    with torch.no_grad():
        for batch in data_loader:
            ts_values, ts_indicators, ts_time, static, labels = (
                part.to(device) for part in batch
            )

            scores = model(ts_values, ts_indicators, ts_time, static)

            # Ground truth, argmax class, and softmax probability of class 1.
            truths.extend(labels.cpu().numpy())
            hard_preds.extend(torch.argmax(scores, dim=1).cpu().numpy())
            positive_probs.extend(torch.softmax(scores, dim=1)[:, 1].cpu().numpy())

    # zero_division=1: an undefined precision/recall/F1 is reported as 1.
    metrics = {}
    metrics["Accuracy"] = accuracy_score(truths, hard_preds)
    metrics["Precision"] = precision_score(truths, hard_preds, zero_division=1)
    metrics["Recall"] = recall_score(truths, hard_preds, zero_division=1)
    metrics["F1-Score"] = f1_score(truths, hard_preds, zero_division=1)
    metrics["ROC-AUC"] = roc_auc_score(truths, positive_probs)
    return metrics

Evaluate performance on all splits
- Save data in the same format as baseline models.

In [None]:
import time as time


def train(model, train_loader, val_loader, num_epochs = 100):
    """Train with cross-entropy, validating and timing every epoch.

    Uses the module-level `device`. Progress is printed only on epochs whose
    zero-based index is a multiple of 20.

    Returns:
        tuple: (model, train_losses, val_losses, AUC_scores, times). Each list
        has one entry per epoch: mean training batch loss, validation loss (as
        the tensor returned by the criterion), validation ROC-AUC, and
        wall-clock seconds for training plus validation.
    """
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

    train_losses, val_losses, AUC_scores, times = [], [], [], []

    for epoch in range(num_epochs):
        start_time = time.time()

        # ---- optimisation pass ----
        model.train()
        loss_train = 0
        for batch in train_loader:
            ts_values, ts_indicators, ts_time, static, labels = (
                part.to(device) for part in batch
            )

            loss = criterion(model(ts_values, ts_indicators, ts_time, static), labels.long())

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
            optimizer.step()

            loss_train += loss.item()

        train_losses.append(loss_train/len(train_loader))

        # ---- validation pass: gather all outputs, then score them at once ----
        model.eval().to(device)
        gathered_labels = torch.LongTensor([]).to(device)
        gathered_outputs = torch.FloatTensor([]).to(device)
        with torch.no_grad():
            for batch in val_loader:
                ts_values, ts_indicators, ts_time, static, labels = (
                    part.to(device) for part in batch
                )
                labels = labels.long()
                gathered_labels = torch.cat((gathered_labels, labels), dim=0)
                gathered_outputs = torch.cat(
                    (gathered_outputs, model(ts_values, ts_indicators, ts_time, static)), dim=0
                )

            positive_probs = torch.nn.functional.softmax(gathered_outputs, dim=1)[:, 1]
            y_true = gathered_labels.cpu().numpy()
            y_score = positive_probs.cpu().numpy()
            auc_score = roc_auc_score(y_true, y_score)
            aupr_score = average_precision_score(y_true, y_score)
            accuracy = accuracy_score(y_true, (positive_probs >= 0.5).cpu().numpy())
            AUC_scores.append(auc_score)

        val_loss = criterion(gathered_outputs, gathered_labels)
        val_losses.append(val_loss)

        # The epoch time covers both the training and the validation pass.
        delta = time.time() - start_time
        times.append(delta)

        if epoch % 20 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {loss_train/len(train_loader):.4f}, Val Loss: {val_loss:.4f}, AUC: {auc_score:.4f}, AUPR: {aupr_score:.4f}, Accuracy: {accuracy:.4f}, Time: {delta}")

    return model, train_losses, val_losses, AUC_scores, times

def test(model, test_loader):
    """Score `model` on the test loader: cross-entropy loss, AUC, AUPR, accuracy.

    Uses the module-level `device`. The positive-class score is column 1 of
    the softmax over the model outputs.

    Returns:
        tuple: (test_losses, auc_score, aupr_score, accuracy); `test_losses`
        is a list with a single entry, the mean of the per-batch losses.
    """
    model.eval().to(device)

    criterion = nn.CrossEntropyLoss()
    test_losses = []
    truth_so_far = torch.LongTensor([]).to(device)
    outputs_so_far = torch.FloatTensor([]).to(device)

    with torch.no_grad():
        loss_test = 0
        for ts_values, ts_indicators, ts_time, static, labels in test_loader:
            ts_values = ts_values.to(device)
            ts_indicators = ts_indicators.to(device)
            ts_time = ts_time.to(device)
            static = static.to(device)
            labels = labels.to(device).long()  # class indices for the criterion

            outputs = model(ts_values, ts_indicators, ts_time, static)
            loss_test += criterion(outputs, labels).item()

            # Accumulate batch results for the dataset-level metrics.
            truth_so_far = torch.cat((truth_so_far, labels), dim=0)
            outputs_so_far = torch.cat((outputs_so_far, outputs), dim=0)

        test_losses.append(loss_test / len(test_loader))

        class_probs = torch.nn.functional.softmax(outputs_so_far, dim=1)
        y_true = truth_so_far.cpu().numpy()
        y_score = class_probs[:, 1].cpu().numpy()
        auc_score = roc_auc_score(y_true, y_score)
        aupr_score = average_precision_score(y_true, y_score)
        predicted_labels = (class_probs[:, 1] >= 0.5).cpu().numpy().astype(int)
        accuracy = accuracy_score(y_true, predicted_labels)

    print(f"Test Loss: {test_losses[-1]:.4f}, AUC: {auc_score:.4f}, AUPR: {aupr_score:.4f}, Accuracy: {accuracy:.4f}")

    return test_losses, auc_score, aupr_score, accuracy


In [None]:
import pandas as pd
import json

# Empty per-split training log: one row per epoch is filled in by the split loop.
training_log = pd.DataFrame(columns=["epoch",	"train_loss",	"val_loss",	"auc_score", "time"])
# Template for the per-split test metrics written to JSON by the split loop.
test_results = {
    "test_loss": 0,
    "accuracy": 0,
    "AUPRC": 0,
    "AUROC": 0,
}



# Split ids 1..5; they are interpolated into the dataset and result paths.
splits = range(1, 6)

#### Loop over splits. Collect and save results

In [None]:
# Train, test and save results for every data split.
# Each split gets a fresh model, a fresh training log and a fresh metrics dict,
# so nothing leaks from one split into the next.
save_path = project_dir + "/results/SMART_M_timed/"
num_epochs = 100


def _to_python(value):
    """Return a plain Python number for a single-element tensor; pass other values through."""
    return value.detach().cpu().item() if torch.is_tensor(value) else value


for split in splits:

  # Load data
  # NOTE: allow_pickle=True executes pickled code on load - only use trusted files.
  train_data = np.load(f'/content/drive/MyDrive/ssm_ehr/datasets/split_{split}/train_physionet2012_{split}.npy', allow_pickle=True)
  test_data = np.load(f'/content/drive/MyDrive/ssm_ehr/datasets/split_{split}/test_physionet2012_{split}.npy', allow_pickle=True)
  val_data = np.load(f'/content/drive/MyDrive/ssm_ehr/datasets/split_{split}/validation_physionet2012_{split}.npy', allow_pickle=True)

  train_dataset = ICUTimeSeriesDataset(train_data)
  val_dataset = ICUTimeSeriesDataset(val_data)
  test_dataset = ICUTimeSeriesDataset(test_data)

  # Dataloader
  train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
  val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)
  test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

  # Reinstantiate a model after each split
  model = MoEMambaAttentionClassifier(
    ts_feature_dim=37,
    static_feature_dim=8,
    hidden_dim=16,
    num_classes=2
  )

  # Training loop
  model, train_losses, val_losses, AUC_scores, times = train(model, train_loader, val_loader, num_epochs = num_epochs)

  # Build the log from what train() returned. The epoch column follows the
  # actual number of recorded epochs instead of a hard-coded 100, and tensors
  # (possibly on the GPU) are converted to plain numbers for the CSV.
  training_log = pd.DataFrame({
      "epoch": list(range(1, len(train_losses) + 1)),
      "train_loss": [_to_python(loss) for loss in train_losses],
      "val_loss": [_to_python(loss) for loss in val_losses],
      "auc_score": [_to_python(auc) for auc in AUC_scores],
      "time": [_to_python(t) for t in times],
  })

  # Testing
  test_losses, auc_score, aupr_score, accuracy = test(model, test_loader)
  test_results = {
      # test() returns its loss as a one-element list; store the scalar so the
      # JSON matches the scalar fields of the other metrics.
      "test_loss": _to_python(test_losses[-1]),
      "accuracy": _to_python(accuracy),
      "AUPRC": _to_python(aupr_score),
      "AUROC": _to_python(auc_score),
  }

  # Save results; create the split folder first so a fresh Drive does not fail.
  os.makedirs(save_path+f"split_{split}", exist_ok=True)

  train_fp = save_path+f"split_{split}/training_log.csv"
  training_log.to_csv(train_fp, index=False)

  test_fp = save_path+f"split_{split}/test_results.json"
  json_results = json.dumps(test_results, indent=4)
  with open(test_fp, 'w') as file:
    file.write(json_results)

  print(f"Successfully saved data fromm split: {split}")




