In [1]:
from google.colab import drive
import os
drive.mount('/content/drive')
project_dir = '/content/drive/My Drive/ssm_ehr'
print(os.path.exists(project_dir))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
True


In [2]:
print(os.getcwd())

/content


In [3]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
!pip uninstall mamba-ssm causal-conv1d
!pip install causal-conv1d && pip install mamba-ssm

Looking in indexes: https://download.pytorch.org/whl/cu121
[0mCollecting causal-conv1d
  Using cached causal_conv1d-1.4.0.tar.gz (9.3 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ninja (from causal-conv1d)
  Downloading ninja-1.11.1.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.2-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.9/422.9 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: causal-conv1d
  Building wheel for causal-conv1d (setup.py) ... [?25l[?25hdone
  Created wheel for causal-conv1d: filename=causal_conv1d-1.4.0-cp310-cp310-linux_x86_64.whl size=104867883 sha256=b5e7cf7e964b5e99275d97ba1e1b0ee4e3073f4593743ba1f1c6aa394a3008cc
  Stored in directory: /root/.cache/pip/wheels/e3/dd/4c/205f24e151736bd22f5980738dd10a19af6f093b6f4dcab006
Successfully built causal-conv1d
Installi

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba  # Assuming Mamba is installed
import math
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  def backward(ctx, grad_output):
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  def backward(ctx, dout, *args):


In [29]:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0).to(device)  # Shape: (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]  # Add positional encoding


class MambaAttentionClassifier(nn.Module):
    def __init__(self, ts_feature_dim, static_feature_dim, hidden_dim, num_classes):
        """
        Args:
            ts_feature_dim (int): Number of features in time-series data (e.g., 37).
            static_feature_dim (int): Number of features in static data (e.g., 8).
            hidden_dim (int): Dimension of hidden states in the model.
            num_classes (int): Number of output classes (e.g., 2 for binary classification).
        """
        super(MambaAttentionClassifier, self).__init__()

        # Time-series processing with Mamba
        self.positional_encoding = PositionalEncoding(d_model=ts_feature_dim)
        self.mamba_layer = Mamba(
            d_model=ts_feature_dim,  # Include time as an additional feature
            d_state=hidden_dim,         # Mamba's internal state size
            d_conv=4,                   # Convolution width for local dependencies
            expand=2                    # Expansion factor
        )

        self.projection = nn.Linear(ts_feature_dim, hidden_dim)
        self.mamba_norm = nn.LayerNorm(hidden_dim)  # Layer normalization for stability


        # Multi-head attention
        self.multihead_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, batch_first=True)

        # Static feature processing
        self.static_fc = nn.Linear(static_feature_dim, hidden_dim)
        self.static_norm = nn.LayerNorm(hidden_dim)

        # Fully connected layers for classification
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, ts_values, ts_indicators, ts_time, static):
        """
        Args:
            ts_values (torch.Tensor): Time-series data (batch_size, seq_len, ts_feature_dim).
            ts_indicators (torch.Tensor): Indicator for missing time-series data (batch_size, seq_len, ts_feature_dim).
            ts_time (torch.Tensor): Time-series timestamps (batch_size, seq_len).
            static (torch.Tensor): Static features (batch_size, static_feature_dim).
        Returns:
            torch.Tensor: Class probabilities (batch_size, num_classes).
        """
        # Ensure the shape of ts_indicators matches ts_values
        assert ts_values.shape == ts_indicators.shape, "Shape mismatch between ts_values and ts_indicators"

        # Handle missing data: Mask out the missing time-series values using ts_indicators
        ts_values = ts_values * ts_indicators  # Element-wise multiplication to mask missing data

        # Add time as an additional feature and apply positional encoding
        ts_time = ts_time.unsqueeze(-1)  # (batch_size, seq_len, 1)
        ts_combined = torch.cat([ts_values, ts_time], dim=-1)  # (batch_size, seq_len, ts_feature_dim + 1)
        ts_combined = self.positional_encoding(ts_combined)

        # Process time-series data with Mamba
        ts_encoded = self.mamba_layer(ts_combined)  # (batch_size, seq_len, hidden_dim)
        ts_encoded = self.projection(ts_encoded)  # (batch_size, seq_len, hidden_dim)
        ts_encoded = self.mamba_norm(ts_encoded)  # Normalize the Mamba output


        # Apply multi-head attention
        ts_encoded, _ = self.multihead_attention(ts_encoded, ts_encoded, ts_encoded)

        # Compute attention weights
        attn_weights = F.softmax(torch.mean(ts_encoded, dim=-1, keepdim=True), dim=1)  # (batch_size, seq_len, 1)
        ts_attended = torch.sum(attn_weights * ts_encoded, dim=1)  # (batch_size, hidden_dim)

        # Process static features
        static_encoded = F.relu(self.static_fc(static))  # (batch_size, hidden_dim)
        static_encoded = self.static_norm(static_encoded)

        # Concatenate attended time-series and static features
        combined = torch.cat([ts_attended, static_encoded], dim=1)  # (batch_size, hidden_dim * 2)

        # Classification
        output = self.classifier(combined)  # (batch_size, num_classes)

        return output


In [None]:

class MambaAttentionClassifier(nn.Module):
    def __init__(self, ts_feature_dim, static_feature_dim, hidden_dim, num_classes):
        """
        Args:
            ts_feature_dim (int): Number of features in time-series data (e.g., 37).
            static_feature_dim (int): Number of features in static data (e.g., 8).
            hidden_dim (int): Dimension of hidden states in the model.
            num_classes (int): Number of output classes (e.g., 2 for binary classification).
        """
        super(MambaAttentionClassifier, self).__init__()

        # Time-series processing with Mamba
        self.mamba_layer = Mamba(
            d_model=ts_feature_dim,  # Include time as an additional feature
            d_state=hidden_dim,         # Mamba's internal state size
            d_conv=4,                   # Convolution width for local dependencies
            expand=2                    # Expansion factor
        )

        # Static feature processing
        self.static_fc = nn.Linear(static_feature_dim, hidden_dim)

        # Attention layer to weight time-series features
        self.attention = nn.Linear(ts_feature_dim, 1)

        # Fully connected layers for classification
        self.classifier = nn.Sequential(
            nn.Linear(ts_feature_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, ts_values, ts_indicators, ts_time, static):
      """
      Args:
          ts_values (torch.Tensor): Time-series data (batch_size, seq_len, ts_feature_dim).
          ts_indicators (torch.Tensor): Indicator for missing time-series data (batch_size, seq_len, ts_feature_dim).
          ts_time (torch.Tensor): Time-series timestamps (batch_size, seq_len).
          static (torch.Tensor): Static features (batch_size, static_feature_dim).
      Returns:
          torch.Tensor: Class probabilities (batch_size, num_classes).
      """
      # Ensure the shape of ts_indicators matches ts_values
      assert ts_values.shape == ts_indicators.shape, "Shape mismatch between ts_values and ts_indicators"

      # Handle missing data: Mask out the missing time-series values using ts_indicators
      ts_values = ts_values * ts_indicators  # Element-wise multiplication to mask missing data

      ts_time = ts_time.unsqueeze(-1)  # (batch_size, seq_len, 1)
      ts_combined = torch.cat([ts_values, ts_time], dim=-1)  # (batch_size, seq_len, ts_feature_dim + 1)

      # Process time-series data with Mamba
      ts_encoded = self.mamba_layer(ts_combined)  # (batch_size, seq_len, hidden_dim)

      # print('ts_encoded shape',ts_encoded.shape)

      # Reshape ts_encoded for the attention layer
      batch_size, seq_len, hidden_dim = ts_encoded.shape
      ts_encoded_flat = ts_encoded.view(-1, hidden_dim)  # Flatten to (batch_size * seq_len, hidden_dim)

      # Compute attention scores
      attn_scores = self.attention(ts_encoded_flat)  # (batch_size * seq_len, 1)
      attn_scores = attn_scores.view(batch_size, seq_len, 1)  # Reshape back to (batch_size, seq_len, 1)

      # Compute attention weights
      attn_weights = F.softmax(attn_scores, dim=1)  # (batch_size, seq_len, 1)

      # Apply attention weights to the Mamba output
      ts_attended = torch.sum(attn_weights * ts_encoded, dim=1)  # (batch_size, hidden_dim)

      # Process static features
      static_encoded = F.relu(self.static_fc(static))  # (batch_size, hidden_dim)

      # Concatenate attended time-series and static features
      combined = torch.cat([ts_attended, static_encoded], dim=1)  # (batch_size, hidden_dim * 2)

      # print('ts_attended shape',ts_attended.shape)
      # print('static_encoded shape', static_encoded.shape)
      # print('combined shape', combined.shape)

      # Classification
      output = self.classifier(combined)  # (batch_size, num_classes)

      return output



In [7]:
# drive.mount('/content/drive')
# project_dir = '/content/drive/My Drive/ssm_ehr'
train_data = np.load('/content/drive/MyDrive/ssm_ehr/datasets/split_1/train_physionet2012_1.npy', allow_pickle=True)
test_data = np.load('/content/drive/MyDrive/ssm_ehr/datasets/split_1/train_physionet2012_1.npy', allow_pickle=True)
val_data = np.load('/content/drive/MyDrive/ssm_ehr/datasets/split_1/train_physionet2012_1.npy', allow_pickle=True)

In [8]:

def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with variable-length time-series data and static features.

    Args:
        batch (list of tuples): Each tuple contains (ts_values, ts_indicators, ts_time, static, labels).

    Returns:
        tuple: Padded time-series values, indicators, times, static features, and labels.
    """
    ts_values = [sample[0].clone().detach().float() for sample in batch]
    ts_indicators = [sample[1].clone().detach().float() for sample in batch]
    ts_times = [sample[2].clone().detach().float() for sample in batch]
    static = torch.stack([sample[3].clone().detach().float() for sample in batch])
    labels = torch.tensor([sample[4] for sample in batch], dtype=torch.float32)

    # Pad ts_values, ts_indicators, and ts_time
    ts_values_padded = pad_sequence(ts_values, batch_first=True)
    ts_indicators_padded = pad_sequence(ts_indicators, batch_first=True)
    ts_times_padded = pad_sequence(ts_times, batch_first=True)

    return ts_values_padded, ts_indicators_padded, ts_times_padded, static, labels

In [9]:

class ICUTimeSeriesDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return (
            torch.tensor(sample['ts_values'], dtype=torch.float32),  # Time-series values
            torch.tensor(sample['ts_indicators'], dtype=torch.float32),  # Missing indicators
            torch.tensor(sample['ts_times'], dtype=torch.float32),  # Time steps
            torch.tensor(sample['static'], dtype=torch.float32),  # Static features
            torch.tensor(sample['labels'], dtype=torch.float32)  # Label
        )

# Create train, validation, and test datasets
# train_data = np.load('Data/P12Data_1/split_1/train_physionet2012_1.npy', allow_pickle=True)
# test_data = np.load('Data/P12Data_1/split_1/test_physionet2012_1.npy', allow_pickle=True)
# val_data = np.load('Data/P12Data_1/split_1/validation_physionet2012_1.npy', allow_pickle=True)

train_dataset = ICUTimeSeriesDataset(train_data)
val_dataset = ICUTimeSeriesDataset(val_data)
test_dataset = ICUTimeSeriesDataset(test_data)

# Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define model
model = MambaAttentionClassifier(
    ts_feature_dim=38,
    static_feature_dim=8,
    hidden_dim=64,
    num_classes=2
)
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(100):  # Adjust epochs as needed
    model.train()
    for ts_values, ts_indicators, ts_time, static, labels in train_loader:
        ts_values,ts_indicators, ts_time , static, labels = ts_values.to(device), ts_indicators.to(device), ts_time.to(device), static.to(device), labels.to(device)

        # Forward pass
        outputs = model(ts_values, ts_indicators, ts_time, static)
        loss = criterion(outputs, labels.long())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}")

Epoch [1/100], Loss: 0.6214
Epoch [2/100], Loss: 0.2667
Epoch [3/100], Loss: 0.3618
Epoch [4/100], Loss: 0.2018
Epoch [5/100], Loss: 0.2690
Epoch [6/100], Loss: 0.3858
Epoch [7/100], Loss: 0.2078
Epoch [8/100], Loss: 0.1243
Epoch [9/100], Loss: 0.6300
Epoch [10/100], Loss: 0.1670
Epoch [11/100], Loss: 0.3674
Epoch [12/100], Loss: 0.1851
Epoch [13/100], Loss: 0.2427
Epoch [14/100], Loss: 0.2393
Epoch [15/100], Loss: 0.2748
Epoch [16/100], Loss: 0.5315
Epoch [17/100], Loss: 0.2089
Epoch [18/100], Loss: 0.1379
Epoch [19/100], Loss: 0.1769
Epoch [20/100], Loss: 0.2095
Epoch [21/100], Loss: 0.2463
Epoch [22/100], Loss: 0.0853
Epoch [23/100], Loss: 0.3786
Epoch [24/100], Loss: 0.2540
Epoch [25/100], Loss: 0.0718
Epoch [26/100], Loss: 0.2231
Epoch [27/100], Loss: 0.3472
Epoch [28/100], Loss: 0.0942
Epoch [29/100], Loss: 0.1795
Epoch [30/100], Loss: 0.5082
Epoch [31/100], Loss: 0.1834
Epoch [32/100], Loss: 0.1393
Epoch [33/100], Loss: 0.0387
Epoch [34/100], Loss: 0.1058
Epoch [35/100], Loss: 0

In [38]:
# Evaluation loop
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for ts_values, ts_indicators, ts_time, static, labels in val_loader:
        ts_values,ts_indicators,ts_time, static, labels = ts_values.to(device), ts_indicators.to(device),ts_time.to(device), static.to(device), labels.to(device)
        outputs = model(ts_values, ts_indicators,ts_time, static)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.long()).sum().item()

    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

Validation Accuracy: 99.16%


In [39]:
def evaluate_model(model, data_loader, device):
    """
    Evaluates the model on the given data loader and calculates evaluation metrics.

    Args:
        model (torch.nn.Module): Trained model.
        data_loader (torch.utils.data.DataLoader): Data loader for validation/test set.
        device (torch.device): Device to perform computation on (CPU/GPU).

    Returns:
        dict: A dictionary containing evaluation metrics.
    """
    model.eval()  # Set model to evaluation mode
    y_true = []
    y_pred = []
    y_prob = []

    with torch.no_grad():
        for ts_values, ts_indicators, ts_time, static, labels in data_loader:
            # Move data to device
            ts_values,ts_indicators, ts_time, static, labels = ts_values.to(device), ts_indicators.to(device), ts_time.to(device), static.to(device), labels.to(device)

            # Get model predictions
            outputs = model(ts_values, ts_indicators, ts_time, static)  # Raw logits
            probabilities = torch.softmax(outputs, dim=1)[:, 1]  # Probability for class 1
            predictions = torch.argmax(outputs, dim=1)  # Predicted class labels

            # Collect predictions and ground truth
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())
            y_prob.extend(probabilities.cpu().numpy())

    # Calculate evaluation metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=1)
    recall = recall_score(y_true, y_pred, zero_division=1)
    f1 = f1_score(y_true, y_pred, zero_division=1)
    roc_auc = roc_auc_score(y_true, y_prob)

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
        "ROC-AUC": roc_auc,
    }

In [40]:
# Evaluate on the validation or test set
metrics = evaluate_model(model, val_loader, device)

# Print metrics
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")

Accuracy: 0.9916
Precision: 0.9612
Recall: 0.9791
F1-Score: 0.9701
ROC-AUC: 0.9992
