In [28]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.optim as optim
from sklearn.metrics import  confusion_matrix, roc_auc_score, average_precision_score
from sklearn.preprocessing import RobustScaler, StandardScaler, MinMaxScaler
import random
import matplotlib.pyplot as plt
import time
import random
from tqdm import tqdm
import os
from Transformers import MedicalTimeSeriesDatasetTimeGrid, MedicalTimeSeriesDatasetTuple, collate_fn, TimeSeriesGridTransformer, TimeSeriesTupleTransformer

plt.style.use('default')





In [None]:
config = {
    "model_type": "tuple",  # Choose 'time_grid' or 'tuple'
    "scaler_name": "MinMaxScaler", # Name of the scaler used

    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu",
    "epochs": 30, # Adjust as needed
    "batch_size": 32,
    "learning_rate": 1e-4,
    "weight_decay": 1e-5,
    "patience": 7, # For early stopping based on validation performance

    # Paths (MODIFY THESE)
    "data_dir": "data", # Directory containing train/val/test data
    "output_dir": "data/model_outputs", # Directory to save models and results

    # Model Specific Hyperparameters (adjust based on model_type and tuning)
    # Grid Transformer
    "grid_d_model": 128,
    "grid_nhead": 4,
    "grid_num_layers": 2,
    "grid_dropout": 0.2,
    "grid_feature_dim": 41, # Should match your data

    # Tuple Transformer
    "tuple_d_model": 128,
    "tuple_nhead": 8,
    "tuple_num_encoder_layers": 3,
    "tuple_dim_feedforward": 256, # Typically 2-4x d_model
    "tuple_dropout": 0.2,
    "tuple_num_modalities": 41, # 40 variables + 1 for padding (if PAD_INDEX_Z=0) Adjust if needed!
    "PAD_INDEX_Z": 0, # Padding index for time and value features
    "tuple_modality_emb_dim": 64,
    "tuple_max_seq_len": 768, # From transformers.py or defined here
    }

In [30]:

# --- Add this near the top of your script ---
def set_seed(seed_value=42):
    """Sets the seed for reproducibility in PyTorch, NumPy, and Python."""
    random.seed(seed_value)  # Python random module
    np.random.seed(seed_value) # Numpy module
    torch.manual_seed(seed_value) # PyTorch CPU seeding

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value) # if you are using multi-GPU.
        # Configure CuDNN for deterministic operations
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Optional: Newer PyTorch versions might require this for full determinism
        # Note: This can sometimes throw errors if a deterministic implementation isn't available
        # try:
        #     torch.use_deterministic_algorithms(True)
        # except Exception as e:
        #     print(f"Warning: Could not enable deterministic algorithms: {e}")
        # Optional: Sometimes needed for deterministic matrix multiplication
        # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    print(f"Seed set globally to {seed_value}")



# --- Call this function very early in your script ---
seed = config["seed"]
set_seed(seed_value=seed)


Seed set globally to 42


In [None]:
def get_data_loaders(config):
    """Loads data and creates DataLoader objects."""
    
    data_path = config["data_dir"]

    y_train = np.load(os.path.join(data_path, 'outcomes_sorted_set-a.npy'), allow_pickle=True)
    y_val = np.load(os.path.join(data_path, 'outcomes_sorted_set-b.npy'),allow_pickle=True)
    y_test = np.load(os.path.join(data_path, 'outcomes_sorted_set-c.npy'), allow_pickle=True)

    if config["model_type"] == "time_grid":
        data_format = config["model_type"]
    elif config["model_type"] == "tuple" or config["model_type"] == "contrast" or config["model_type"] == "linear_probe":
        data_format = "tuple"
    else:
        raise ValueError(f"Unknown model_type: {config['model_type']}")
    scaler_name = config["scaler_name"]
    X_train = np.load(os.path.join(data_path,f"{data_format}_processed_{scaler_name}_set-a.npy"), allow_pickle=True)
    X_val = np.load(os.path.join(data_path,f"{data_format}_processed_{scaler_name}_set-b.npy"), allow_pickle=True)
    X_test = np.load(os.path.join(data_path,f"{data_format}_processed_{scaler_name}_set-c.npy"), allow_pickle=True)

  
    if config["model_type"] == "time_grid":
      # Feature dimension check
        config["grid_feature_dim"] = X_train.shape[2] # Update based on loaded data
        print(f"Grid Data Shapes: Train X: {X_train.shape}, Val X: {X_val.shape}, Test X: {X_test.shape}")

        train_dataset = MedicalTimeSeriesDatasetTimeGrid(X_train, y_train)
        val_dataset = MedicalTimeSeriesDatasetTimeGrid(X_val, y_val)
        test_dataset = MedicalTimeSeriesDatasetTimeGrid(X_test, y_test)
        collate_func = None # Default collate for grid data

    elif config["model_type"] == "tuple":
      
        # Modality count check (find max index + 1, assuming 0 is padding or valid)
        all_z = [t[1] for patient in X_train for t in patient]
        num_modalities = max(all_z) + 1 if all_z else 1
        if num_modalities > config["tuple_num_modalities"]:
             print(f"Warning: Found {num_modalities-1} as max modality index. Updating config.")
             config["tuple_num_modalities"] = num_modalities
        elif num_modalities < config["tuple_num_modalities"]:
             print(f"Warning: Max modality index is {num_modalities-1}, but config expects {config['tuple_num_modalities']-1}. Using config value.")

        print(f"Tuple Data: Train Patients: {len(X_train)}, Val Patients: {len(X_val)}, Test Patients: {len(X_test)}")
        print(f"Using {config['tuple_num_modalities']} modalities (including padding index {config['PAD_INDEX_Z']}")


        train_dataset = MedicalTimeSeriesDatasetTuple(X_train, y_train, max_seq_len=config["tuple_max_seq_len"])
        val_dataset = MedicalTimeSeriesDatasetTuple(X_val, y_val, max_seq_len=config["tuple_max_seq_len"])
        test_dataset = MedicalTimeSeriesDatasetTuple(X_test, y_test, max_seq_len=config["tuple_max_seq_len"])
        collate_func = collate_fn # Use the custom collate_fn

    else:
        raise ValueError(f"Unknown model_type: {config['model_type']}")

    print(f"Labels - Train: {len(y_train)} (Positive: {y_train.sum()}), Val: {len(y_val)} (Positive: {y_val.sum()}), Test: {len(y_test)} (Positive: {y_test.sum()})")

    # --- Weighted Sampler for Imbalance ---
    class_counts = np.bincount(y_train.astype(int))
    class_weights = 1. / class_counts
    sample_weights = np.array([class_weights[int(t)] for t in y_train])
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    # --- DataLoaders ---
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], sampler=sampler, collate_fn=collate_func)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"] * 2, shuffle=False, collate_fn=collate_func) # Larger batch size for validation
    test_loader = DataLoader(test_dataset, batch_size=config["batch_size"] * 2, shuffle=False, collate_fn=collate_func)

    print("DataLoaders created.")
    return train_loader, val_loader, test_loader

def get_model(config):
    """Initializes the model based on the specified type."""
    
    if config["model_type"] == "time_grid":
        model = TimeSeriesGridTransformer(
            feature_dim=config["grid_feature_dim"],
            d_model=config["grid_d_model"],
            nhead=config["grid_nhead"],
            num_layers=config["grid_num_layers"],
            dropout=config["grid_dropout"]
        )
    elif config["model_type"] == "tuple":
        model = TimeSeriesTupleTransformer(
            num_modalities=config["tuple_num_modalities"],
            d_model=config["tuple_d_model"],
            nhead=config["tuple_nhead"],
            num_encoder_layers=config["tuple_num_encoder_layers"],
            dim_feedforward=config["tuple_dim_feedforward"],
            num_classes=1, # Binary classification
            dropout=config["tuple_dropout"],
            max_seq_len=config["tuple_max_seq_len"],
            modality_emb_dim=config["tuple_modality_emb_dim"]
        )
    else:
        raise ValueError(f"Unknown model_type: {config['model_type']}")

    model.to(config["device"])
    print(f"Model ({config['model_type']}) created and moved to {config['device']}.")
 
    return model


In [None]:
def train_model(model,
                dataloader,
                criterion,
                optimizer,
                device,
                model_type):
    
    model.train()
    total_loss = 0.0
   
    for batch in tqdm(dataloader, desc="Training", leave=False):
        
        if model_type == "tuple":
            t_seq, z_seq, v_seq, attn_mask, labels = batch
            t_seq, z_seq, v_seq = t_seq.to(device), z_seq.to(device), v_seq.to(device)
            attn_mask, labels = attn_mask.to(device), labels.to(device)
            
            logits = model(t_seq, z_seq, v_seq, attn_mask) # Pass correct mask polarity
            y = labels # Ensure y is of shape (batch_size, 1)
            
        elif model_type == "time_grid":
            x, label = batch
        
            x = x.to(device)
            y = label.to(device)
            logits = model(x)


        else:
            raise ValueError(f"Unknown model_type: {model_type}")
        
        optimizer.zero_grad()
        loss = criterion(logits, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)    
   
    return avg_loss


def evaluate(model, dataloader, criterion, device, model_type):
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            if model_type == "tuple":
                t_seq, z_seq, v_seq, attn_mask, labels = batch
                t_seq, z_seq, v_seq, attn_mask, labels = t_seq.to(device), z_seq.to(device), v_seq.to(device), attn_mask.to(device), labels.to(device)
                logits = model(t_seq, z_seq, v_seq, attn_mask)
                y = labels

                # Store original labels for metrics


            elif model_type == "time_grid":
                x, labels = batch
                x = x.to(device)
                y = labels.to(device)
                logits = model(x)

                # Store original labels for metrics


            else:
                raise ValueError(f"Unknown model_type: {model_type}")
        
            loss = criterion(logits, y)
            total_loss += loss.item()

            probs = torch.sigmoid(logits).cpu().numpy()
            all_preds.extend(probs.flatten().flatten())
            all_labels.extend(y.cpu().numpy().flatten())

    avg_loss = total_loss / len(dataloader)
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    auroc = roc_auc_score(all_labels, all_preds)
    auprc = average_precision_score(all_labels, all_preds)
    confusion_matrix_result = confusion_matrix(all_labels, (all_preds > 0.5).astype(int))
    return avg_loss, auroc, auprc, confusion_matrix_result
 

In [33]:
# 1. Load Data
train_loader, val_loader, test_loader = get_data_loaders(config)

# 2. Model Initialization, Criterion, Optimizer
model = get_model(config)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=config["patience"])

# 3. Training Loop
best_val_auroc = -1.0
epochs_no_improve = 0
best_model_path = os.path.join(config["output_dir"], f"best_model_{config['model_type']}.pth")


for epoch in range(config["epochs"]):
    print(f"\nEpoch {epoch+1}/{config['epochs']}")

    train_loss = train_model(model, train_loader, criterion, optimizer, config["device"], config["model_type"])
    val_loss, val_auroc, val_auprc, _ = evaluate(model, val_loader, criterion, config["device"], config["model_type"])

    scheduler.step(val_auroc)

    if val_auroc > best_val_auroc:
        print(f"Validation AuROC improved ({best_val_auroc:.4f} -> {val_auroc:.4f}). Saving model...")
        best_val_auroc = val_auroc
        torch.save(model.state_dict(), best_model_path)
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"Validation AuROC did not improve. ({epochs_no_improve}/{config['patience']})")

    if epochs_no_improve >= config["patience"]:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break


# 4. Load Best Model and Evaluate on Test Set
print("\n--- Loading Best Model for Testing ---")
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=config["device"]))

    print("\n--- Evaluating on Test Set ---")
    test_loss, test_auroc, test_auprc, cm = evaluate(model, test_loader, criterion, config["device"], config["model_type"])
    print("\n--- Test Results ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test AuROC: {test_auroc:.4f}")
    print(f"Test AuPRC: {test_auprc:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    print("--------------------")
else:
    print(f"Warning: Best model file not found at {best_model_path}. Testing with the last state.")
    # Optionally evaluate the final model state if no best model was saved
    print("\n--- Evaluating Last Model State on Test Set ---")
    test_loss, test_auroc, test_auprc, cm = evaluate(model, test_loader, criterion, config["device"], config["model_type"])
    print("\n--- Test Results (Last Epoch Model) ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test AuROC: {test_auroc:.4f}")
    print(f"Test AuPRC: {test_auprc:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    print("--------------------")


Tuple Data: Train Patients: 4000, Val Patients: 4000, Test Patients: 4000
Using 41 modalities (including padding index 0
Labels - Train: 4000 (Positive: 554), Val: 4000 (Positive: 568), Test: 4000 (Positive: 585)
DataLoaders created.
Model (tuple) created and moved to cuda.

Epoch 1/30


                                                           

Validation AuROC improved (-1.0000 -> 0.6673). Saving model...

Epoch 2/30


                                                           

Validation AuROC improved (0.6673 -> 0.6804). Saving model...

Epoch 3/30


                                                           

Validation AuROC improved (0.6804 -> 0.6841). Saving model...

Epoch 4/30


                                                           

Validation AuROC improved (0.6841 -> 0.7030). Saving model...

Epoch 5/30


                                                           

Validation AuROC did not improve. (1/7)

Epoch 6/30


                                                           

Validation AuROC improved (0.7030 -> 0.7135). Saving model...

Epoch 7/30


                                                           

Validation AuROC improved (0.7135 -> 0.7199). Saving model...

Epoch 8/30


                                                           

Validation AuROC did not improve. (1/7)

Epoch 9/30


                                                           

Validation AuROC improved (0.7199 -> 0.7308). Saving model...

Epoch 10/30


                                                           

Validation AuROC improved (0.7308 -> 0.7319). Saving model...

Epoch 11/30


                                                           

Validation AuROC did not improve. (1/7)

Epoch 12/30


                                                           

Validation AuROC improved (0.7319 -> 0.7338). Saving model...

Epoch 13/30


                                                           

Validation AuROC improved (0.7338 -> 0.7341). Saving model...

Epoch 14/30


                                                           

Validation AuROC improved (0.7341 -> 0.7356). Saving model...

Epoch 15/30


                                                           

Validation AuROC did not improve. (1/7)

Epoch 16/30


                                                           

Validation AuROC improved (0.7356 -> 0.7396). Saving model...

Epoch 17/30


                                                           

Validation AuROC did not improve. (1/7)

Epoch 18/30


                                                           

Validation AuROC did not improve. (2/7)

Epoch 19/30


                                                           

Validation AuROC did not improve. (3/7)

Epoch 20/30


                                                           

Validation AuROC did not improve. (4/7)

Epoch 21/30


                                                           

Validation AuROC did not improve. (5/7)

Epoch 22/30


                                                           

Validation AuROC did not improve. (6/7)

Epoch 23/30


                                                           

Validation AuROC improved (0.7396 -> 0.7403). Saving model...

Epoch 24/30


                                                           

Validation AuROC improved (0.7403 -> 0.7406). Saving model...

Epoch 25/30


                                                           

Validation AuROC improved (0.7406 -> 0.7406). Saving model...

Epoch 26/30


                                                           

Validation AuROC did not improve. (1/7)

Epoch 27/30


                                                           

Validation AuROC did not improve. (2/7)

Epoch 28/30


                                                           

Validation AuROC did not improve. (3/7)

Epoch 29/30


                                                           

Validation AuROC did not improve. (4/7)

Epoch 30/30


                                                           

Validation AuROC did not improve. (5/7)

--- Loading Best Model for Testing ---

--- Evaluating on Test Set ---


                                                           


--- Test Results ---
Test Loss: 0.5856
Test AuROC: 0.7525
Test AuPRC: 0.3277
Confusion Matrix:
[[2163 1252]
 [ 150  435]]
--------------------


