In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

In [17]:
class AttentionClassifier(nn.Module):
    def __init__(self, ts_feature_dim, static_feature_dim, hidden_dim, num_classes):
        """
        Args:
            ts_feature_dim (int): Number of features in time-series data (e.g., 37).
            static_feature_dim (int): Number of features in static data (e.g., 8).
            hidden_dim (int): Dimension of hidden states in the model.
            num_classes (int): Number of output classes (e.g., 2 for binary classification).
        """
        super(AttentionClassifier, self).__init__()

        # Time-series processing layers
        self.ts_rnn = nn.LSTM(
            input_size=ts_feature_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )  # Bidirectional LSTM for time-series features

        # Static feature processing
        self.static_fc = nn.Linear(static_feature_dim, hidden_dim)

        # Attention layer to weight time-series features
        self.attention = nn.Linear(hidden_dim * 2, 1)  # Bi-LSTM has 2*hidden_dim output

        # Fully connected layers for classification
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2 + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, ts_values, ts_indicators, ts_time, static):
        """
        Args:
            ts_values (torch.Tensor): Time-series data (batch_size, seq_len, ts_feature_dim).
            ts_indicators (torch.Tensor): Indicator for missing time-series data (batch_size, seq_len, ts_feature_dim).
            static (torch.Tensor): Static features (batch_size, static_feature_dim).
        Returns:
            torch.Tensor: Class probabilities (batch_size, num_classes).
        """
        # Ensure the shape of ts_indicators matches ts_values
        assert ts_values.shape == ts_indicators.shape, "Shape mismatch between ts_values and ts_indicators"
        
        # Handle missing data: Mask out the missing time-series values using ts_indicators
        ts_values = ts_values * ts_indicators  # Element-wise multiplication to mask missing data

        ts_time = ts_time.unsqueeze(-1)  # (batch_size, seq_len, 1)
        ts_combined = torch.cat([ts_values, ts_time], dim=-1)  # (batch_size, seq_len, ts_feature_dim + 1)


        # Process time-series data with LSTM
        ts_encoded, _ = self.ts_rnn(ts_combined)  # ts_encoded: (batch_size, seq_len, hidden_dim*2)

        # Compute attention weights over time steps (dim=1)
        attn_weights = F.softmax(self.attention(ts_encoded), dim=1)  # (batch_size, seq_len, 1)

        # Apply attention weights to the LSTM output, but only over the time steps
        ts_attended = torch.sum(attn_weights * ts_encoded, dim=1)  # (batch_size, hidden_dim*2)

        # Process static features
        static_encoded = F.relu(self.static_fc(static))  # (batch_size, hidden_dim)

        # Concatenate attended time-series and static features
        combined = torch.cat([ts_attended, static_encoded], dim=1)  # (batch_size, hidden_dim*3)

        # Classification
        output = self.classifier(combined)  # (batch_size, num_classes)

        return output

In [None]:

def custom_collate_fn(batch):
    ts_values = [sample[0].clone().detach().float() for sample in batch]
    ts_indicators = [sample[1].clone().detach().float() for sample in batch]
    ts_time = [sample[2].clone().detach().float() for sample in batch]
    static = torch.stack([sample[3].clone().detach().float() for sample in batch])
    labels = torch.tensor([sample[4] for sample in batch], dtype=torch.float32)


    # Pad ts_values and ts_indicators
    ts_values_padded = pad_sequence(ts_values, batch_first=True)
    ts_indicators_padded = pad_sequence(ts_indicators, batch_first=True)

    return ts_values_padded, ts_indicators_padded, static, labels

In [21]:
def custom_collate_fn(batch):
    """
    Custom collate function to handle batches with variable-length time-series data and static features.
    
    Args:
        batch (list of tuples): Each tuple contains (ts_values, ts_indicators, ts_time, static, labels).
    
    Returns:
        tuple: Padded time-series values, indicators, times, static features, and labels.
    """
    ts_values = [sample[0].clone().detach().float() for sample in batch]
    ts_indicators = [sample[1].clone().detach().float() for sample in batch]
    ts_times = [sample[2].clone().detach().float() for sample in batch]
    static = torch.stack([sample[3].clone().detach().float() for sample in batch])
    labels = torch.tensor([sample[4] for sample in batch], dtype=torch.float32)

    # Pad ts_values, ts_indicators, and ts_time
    ts_values_padded = pad_sequence(ts_values, batch_first=True)
    ts_indicators_padded = pad_sequence(ts_indicators, batch_first=True)
    ts_times_padded = pad_sequence(ts_times, batch_first=True)

    return ts_values_padded, ts_indicators_padded, ts_times_padded, static, labels

In [24]:

class ICUTimeSeriesDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        return (
            torch.tensor(sample['ts_values'], dtype=torch.float32),  # Time-series values
            torch.tensor(sample['ts_indicators'], dtype=torch.float32),  # Missing indicators
            torch.tensor(sample['ts_times'], dtype=torch.float32),  # Time steps
            torch.tensor(sample['static'], dtype=torch.float32),  # Static features
            torch.tensor(sample['labels'], dtype=torch.float32)  # Label
        )

# Create train, validation, and test datasets
train_data = np.load('Data/P12Data_1/split_1/train_physionet2012_1.npy', allow_pickle=True)
test_data = np.load('Data/P12Data_1/split_1/test_physionet2012_1.npy', allow_pickle=True)
val_data = np.load('Data/P12Data_1/split_1/validation_physionet2012_1.npy', allow_pickle=True)

train_dataset = ICUTimeSeriesDataset(train_data)
val_dataset = ICUTimeSeriesDataset(val_data)
test_dataset = ICUTimeSeriesDataset(test_data)

# Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define model
model = AttentionClassifier(
    ts_feature_dim=38,
    static_feature_dim=8,
    hidden_dim=64,
    num_classes=2
)
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(10):  # Adjust epochs as needed
    model.train()
    for ts_values, ts_indicators, ts_time, static, labels in train_loader:
        ts_values,ts_indicators, ts_time , static, labels = ts_values.to(device), ts_indicators.to(device), ts_time.to(device), static.to(device), labels.to(device)

        # Forward pass
        outputs = model(ts_values, ts_indicators, ts_time, static)
        loss = criterion(outputs, labels.long())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}")

Epoch [1/10], Loss: 0.4365
Epoch [2/10], Loss: 0.2594
Epoch [3/10], Loss: 0.2410
Epoch [4/10], Loss: 0.2922
Epoch [5/10], Loss: 0.3249
Epoch [6/10], Loss: 0.3173
Epoch [7/10], Loss: 0.5327
Epoch [8/10], Loss: 0.3559
Epoch [9/10], Loss: 0.2313
Epoch [10/10], Loss: 0.2953


In [28]:
# Evaluation loop
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for ts_values, ts_indicators, ts_time, static, labels in val_loader:
        ts_values,ts_indicators,ts_time, static, labels = ts_values.to(device), ts_indicators.to(device),ts_time.to(device), static.to(device), labels.to(device)
        outputs = model(ts_values, ts_indicators,ts_time, static)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.long()).sum().item()

    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

Validation Accuracy: 85.24%


In [29]:
def evaluate_model(model, data_loader, device):
    """
    Evaluates the model on the given data loader and calculates evaluation metrics.
    
    Args:
        model (torch.nn.Module): Trained model.
        data_loader (torch.utils.data.DataLoader): Data loader for validation/test set.
        device (torch.device): Device to perform computation on (CPU/GPU).
    
    Returns:
        dict: A dictionary containing evaluation metrics.
    """
    model.eval()  # Set model to evaluation mode
    y_true = []
    y_pred = []
    y_prob = []

    with torch.no_grad():
        for ts_values, ts_indicators, ts_time, static, labels in data_loader:
            # Move data to device
            ts_values,ts_indicators, ts_time, static, labels = ts_values.to(device), ts_indicators.to(device), ts_time.to(device), static.to(device), labels.to(device)
            
            # Get model predictions
            outputs = model(ts_values, ts_indicators, ts_time, static)  # Raw logits
            probabilities = torch.softmax(outputs, dim=1)[:, 1]  # Probability for class 1
            predictions = torch.argmax(outputs, dim=1)  # Predicted class labels
            
            # Collect predictions and ground truth
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())
            y_prob.extend(probabilities.cpu().numpy())
    
    # Calculate evaluation metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=1)
    recall = recall_score(y_true, y_pred, zero_division=1)
    f1 = f1_score(y_true, y_pred, zero_division=1)
    roc_auc = roc_auc_score(y_true, y_prob)

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-Score": f1,
        "ROC-AUC": roc_auc,
    }

In [30]:
# Evaluate on the validation or test set
metrics = evaluate_model(model, val_loader, device)

# Print metrics
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")

Accuracy: 0.8524
Precision: 0.5909
Recall: 0.2694
F1-Score: 0.3701
ROC-AUC: 0.8475
