In [None]:
import numpy as np
import torch
import torch.nn as nn
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from helper_funcs import get_data_loaders, get_model, train_model, evaluate
from Transformers import ProjectionHead
plt.style.use('default')
import torch.optim as optim
import json
from sklearn.model_selection import StratifiedShuffleSplit
from Transformers import MedicalTimeSeriesDatasetTuple, collate_fn
from torch.utils.data import WeightedRandomSampler, DataLoader



In [None]:
# Central experiment configuration. Every later cell reads from (and some
# cells overwrite "model_type" in) this dictionary, and it is dumped to JSON
# next to each saved model, so all values must stay JSON-serialisable.
config = {
    "model_type": "contrast",  # One of 'time_grid', 'tuple', 'contrast', 'linear_probe'
    "scaler_name": "MinMaxScaler",  # Name of the scaler used (part of the processed-data file names)

    "seed": 42,
    # Prefer CUDA, then Apple MPS, then CPU. torch.backends.mps.is_available()
    # is the long-standing availability check; torch.mps.is_available() only
    # exists in recent PyTorch releases and raises AttributeError on older ones.
    "device": (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    ),
    "epochs": 30,  # Adjust as needed
    "batch_size": 64,
    "learning_rate": 1e-4,
    "weight_decay": 1e-5,
    "patience": 8,  # For early stopping based on validation performance
    "optimizer": "adam",  # Choose 'adam', 'sgd', or 'adamw'
    "loss_function": "bce",  # Choose 'bce', 'focal', or 'mse'

    # Paths (MODIFY THESE)
    "data_dir": "data",  # Directory containing train/val/test data
    "output_dir": "data/model_outputs",  # Directory to save models and results

    # Model Specific Hyperparameters (adjust based on model_type and tuning)
    # Grid Transformer
    "grid_d_model": 128,
    "grid_nhead": 4,
    "grid_num_layers": 2,
    "grid_dropout": 0.2,
    "grid_feature_dim": 41,  # Should match your data

    # Tuple Transformer
    "tuple_d_model": 128,
    "tuple_nhead": 8,
    "tuple_num_encoder_layers": 3,
    "tuple_dim_feedforward": 256,  # Typically 2-4x d_model
    "tuple_dropout": 0.2,
    "tuple_num_modalities": 41,  # 40 variables + 1 for padding (if PAD_INDEX_Z=0) Adjust if needed!
    "PAD_INDEX_Z": 0,  # Padding index for time and value features
    "tuple_modality_emb_dim": 64,
    "tuple_max_seq_len": 768,  # From transformers.py or defined here

    # Contrastive Learning
    "proj_dim": 128,  # Output width of the projection head
}

In [None]:

# --- Reproducibility: seed every random number generator used in this notebook ---
def set_seed(seed_value=42):
    """Seed Python, NumPy and PyTorch random number generators.

    The same seed is applied to Python's ``random`` module, NumPy's global
    generator and PyTorch's CPU generator. When CUDA is available, the CUDA
    generators (current device and all devices) are seeded too, and cuDNN is
    put into deterministic mode with benchmarking switched off.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    # CPU-side generators first: Python, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # covers multi-GPU setups
        # Ask cuDNN for deterministic kernels and disable its autotuner.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Stricter determinism is deliberately left disabled here:
        #   torch.use_deterministic_algorithms(True) can raise when an op has
        #   no deterministic implementation, and deterministic cuBLAS matmuls
        #   may need os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'.

    print(f"Seed set globally to {seed_value}")



# Seed all generators right away, before any data loader or model is created
# in the cells below. `seed` is reused later as the random_state for sampling.
seed = config["seed"]
set_seed(seed)


In [None]:
def info_nce_loss(embeddings1, embeddings2, temperature = 0.2):
    """Symmetric InfoNCE loss between two batches of paired embeddings.

    Row ``i`` of ``embeddings1`` and row ``i`` of ``embeddings2`` form the
    positive pair; every other row in the batch serves as a negative.

    Args:
        embeddings1: Tensor of shape (batch_size, proj_dim) for view 1.
        embeddings2: Tensor of shape (batch_size, proj_dim) for view 2.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean of the view1->view2 and view2->view1
        cross-entropy losses.
    """
    normalize = nn.functional.normalize
    cross_entropy = nn.functional.cross_entropy

    # Unit-length rows turn the dot products below into cosine similarities.
    view1 = normalize(embeddings1, dim=1)
    view2 = normalize(embeddings2, dim=1)

    # (batch_size, batch_size) logits; the diagonal holds the positive pairs.
    logits = torch.matmul(view1, view2.T) / temperature
    targets = torch.arange(embeddings1.shape[0], device=embeddings1.device)

    view1_to_view2 = cross_entropy(logits, targets)
    view2_to_view1 = cross_entropy(logits.T, targets)
    return (view1_to_view2 + view2_to_view1) / 2.0

In [None]:
def train_contrastive(encoder, projection_head, dataloader, optimizer, device, temperature=0.2):
    """Run one epoch of contrastive pre-training and return the mean batch loss.

    Each batch must unpack as ``(view1, view2, _)`` where both views are
    4-tuples ``(t_seq, z_seq, v_seq, attn_mask)``; the third element is
    ignored. Both views go through ``encoder.get_representation`` and then
    ``projection_head``; the two projections are compared with
    ``info_nce_loss``.

    Args:
        encoder: Model exposing ``get_representation(t_seq, z_seq, v_seq, attn_mask)``.
        projection_head: Module applied to the encoder representation.
        dataloader: Iterable yielding batches in the layout described above.
        optimizer: Optimizer that updates the encoder and projection head.
        device: Device the batch tensors are moved to.
        temperature: Temperature forwarded to ``info_nce_loss``
            (default 0.2, the value that used to be hard-coded).

    Returns:
        Average loss over all batches of the epoch (Python float).

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    encoder.train()
    projection_head.train()

    running_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader, desc="Training"):
        (t_seq1, z_seq1, v_seq1, attn_mask1), (t_seq2, z_seq2, v_seq2, attn_mask2), _ = batch
        view1 = [tensor.to(device) for tensor in (t_seq1, z_seq1, v_seq1, attn_mask1)]
        view2 = [tensor.to(device) for tensor in (t_seq2, z_seq2, v_seq2, attn_mask2)]

        optimizer.zero_grad()
        # Keep the encode-encode-project-project order so the random number
        # stream (e.g. dropout) is consumed exactly as before.
        rep1 = encoder.get_representation(*view1)
        rep2 = encoder.get_representation(*view2)
        proj1 = projection_head(rep1)
        proj2 = projection_head(rep2)

        loss = info_nce_loss(proj1, proj2, temperature=temperature)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("train_contrastive received an empty dataloader; there is nothing to train on.")
    return running_loss / num_batches





# --------------------------

In [None]:
# Contrastive pre-training: optimise the encoder and a projection head jointly
# with the InfoNCE objective, then persist both together with the config.
contrastive_loader, _, _ = get_data_loaders(config)

device = config["device"]

# Make sure the output directory exists before training starts, so a missing
# folder cannot throw away a finished run when the weights are saved.
os.makedirs(config["output_dir"], exist_ok=True)

encoder = get_model(config)
encoder.to(device)

# The head is built exactly once; the previous second construction only
# discarded the first instance.
projection_head = ProjectionHead(config["tuple_d_model"], config["proj_dim"])
projection_head.to(device)

optimizer_pretrain = torch.optim.AdamW(
    list(encoder.parameters()) + list(projection_head.parameters()),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)

for epoch in range(config["epochs"]):
    print(f"\nEpoch {epoch+1}/{config['epochs']}")

    train_loss = train_contrastive(encoder, projection_head, contrastive_loader, optimizer_pretrain, device)
    print(f"Epoch {epoch+1}/{config['epochs']}, Loss: {train_loss:.4f}")

# Save the encoder and projection head.
# NOTE: model_name already ends in ".pth" and the paths below append another
# ".pth". The resulting "...pth.pth" file names are kept on purpose because
# later cells load exactly these paths.
model_name = f"V3_contrastive_encoder_model_epoch_{config['epochs']}.pth"
torch.save(encoder.state_dict(), f"{config['output_dir']}/{model_name}.pth")
torch.save(projection_head.state_dict(), f"{config['output_dir']}/{model_name}_projection_head.pth")
print(f"Model saved as {model_name}")

# Save the configuration used for this run next to the weights.
with open(f"{config['output_dir']}/config_{model_name}.json", "w") as f:
    json.dump(config, f, indent=4)



In [None]:
from sklearn.metrics import  confusion_matrix, roc_auc_score, average_precision_score

def evaluate(model, dataloader, criterion, device, model_type, linear_probe=None, threshold=0.5):
    """Evaluate a binary classifier and return loss plus ranking metrics.

    Note: this definition shadows the ``evaluate`` imported from
    ``helper_funcs`` at the top of the notebook.

    Batch layout depends on ``model_type``:
      * ``"tuple"``: ``(t_seq, z_seq, v_seq, attn_mask, labels)``; logits come
        from ``model(t_seq, z_seq, v_seq, attn_mask)``.
      * ``"time_grid"``: ``(x, labels)``; logits come from ``model(x)``.
      * ``"linear_probe"``: same batch layout as ``"tuple"``; logits come from
        ``linear_probe(model.get_representation(...))``.

    Args:
        model: Network (or encoder, for ``"linear_probe"``) to evaluate.
        dataloader: Iterable of batches in the layout described above.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.
        model_type: ``"tuple"``, ``"time_grid"`` or ``"linear_probe"``.
        linear_probe: Classification head; required for ``"linear_probe"``.
        threshold: Probability cut-off used for the confusion matrix
            (default 0.5, the value that used to be hard-coded).

    Returns:
        Tuple ``(avg_loss, auroc, auprc, confusion_matrix)``.

    Raises:
        ValueError: For an unknown ``model_type``, a missing ``linear_probe``
            in ``"linear_probe"`` mode, or an empty dataloader.
    """
    if model_type not in ("tuple", "time_grid", "linear_probe"):
        raise ValueError(f"Unknown model_type: {model_type}")
    if model_type == "linear_probe" and linear_probe is None:
        raise ValueError("model_type 'linear_probe' requires a linear_probe module.")

    model.eval()
    # Explicit None check: container modules define __len__, so relying on
    # truthiness could silently skip eval() for a valid module.
    if linear_probe is not None:
        linear_probe.eval()

    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            if model_type == "time_grid":
                x, labels = batch
                logits = model(x.to(device))
            else:
                # "tuple" and "linear_probe" share the same batch layout.
                t_seq, z_seq, v_seq, attn_mask, labels = batch
                inputs = [tensor.to(device) for tensor in (t_seq, z_seq, v_seq, attn_mask)]
                if model_type == "tuple":
                    logits = model(*inputs)
                else:
                    logits = linear_probe(model.get_representation(*inputs))
            y = labels.to(device)

            total_loss += criterion(logits, y).item()
            num_batches += 1

            # Collect per-sample probabilities and labels as flat arrays.
            all_preds.append(torch.sigmoid(logits).cpu().numpy().ravel())
            all_labels.append(y.cpu().numpy().ravel())

    if num_batches == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("evaluate received an empty dataloader; there is nothing to score.")

    avg_loss = total_loss / num_batches
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    auroc = roc_auc_score(all_labels, all_preds)
    auprc = average_precision_score(all_labels, all_preds)
    confusion_matrix_result = confusion_matrix(all_labels, (all_preds > threshold).astype(int))
    return avg_loss, auroc, auprc, confusion_matrix_result
 

In [None]:
# Linear-probe training: reload the pre-trained encoder, freeze it, and fit a
# single linear layer on top of its representation with early stopping on
# validation AuROC.
config["model_type"] = "contrast"
encoder = get_model(config)
# map_location lets checkpoints written on another device (e.g. a GPU box)
# load on this machine.
encoder.load_state_dict(torch.load(f"{config['output_dir']}/{model_name}.pth", map_location=device))
encoder.to(device)
# The projection head is restored for completeness; the probe below does not use it.
projection_head = ProjectionHead(config["tuple_d_model"], config["proj_dim"])
projection_head.load_state_dict(torch.load(f"{config['output_dir']}/{model_name}_projection_head.pth", map_location=device))
projection_head.to(device)

# Freeze the encoder so that only the probe is trained.
for param in encoder.parameters():
    param.requires_grad = False

# The probe is fed encoder.get_representation(...) (see evaluate), i.e. the
# same tensor the projection head takes as input, so its input width is
# tuple_d_model rather than proj_dim. Both are 128 in the current config, so
# this only matters once they differ.
# NOTE(review): assumes ProjectionHead's first argument is its input width — confirm in Transformers.
linear_probe = nn.Linear(config["tuple_d_model"], 1)
linear_probe.to(device)

optimizer_probe = torch.optim.AdamW(linear_probe.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
criterion = nn.BCEWithLogitsLoss()
config["model_type"] = "linear_probe"  # Set the model type to linear probe for training
# Get the data loaders for training, validation, and testing
train_loader, val_loader, test_loader = get_data_loaders(config)

best_val_auroc = -1.0
epochs_no_improve = 0
best_model_path = os.path.join(config["output_dir"], f"{model_name}_linear_probe.pth")

print("Training Linear Probe...")
for epoch in range(config["epochs"]):
    print(f"\nEpoch {epoch+1}/{config['epochs']}")

    train_loss = train_model(encoder, train_loader, criterion, optimizer_probe, device, model_type=config["model_type"], linear_probe=linear_probe)
    print(f"Epoch {epoch+1}/{config['epochs']}, Loss: {train_loss:.4f}")

    val_loss, val_auroc, val_auprc, _ = evaluate(encoder, val_loader, criterion, device, model_type=config["model_type"], linear_probe=linear_probe)
    print(f"Validation Loss: {val_loss:.4f}, AuROC: {val_auroc:.4f}, AuPRC: {val_auprc:.4f}")

    if val_auroc > best_val_auroc:
        print(f"Validation AuROC improved ({best_val_auroc:.4f} -> {val_auroc:.4f}). Saving model...")
        best_val_auroc = val_auroc
        torch.save(linear_probe.state_dict(), best_model_path)
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"Validation AuROC did not improve. ({epochs_no_improve}/{config['patience']})")

    # Early stopping: give up after `patience` epochs without a new best AuROC.
    if epochs_no_improve >= config["patience"]:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

In [None]:

# Restore the best linear-probe checkpoint (selected on validation AuROC) and
# score it once on the held-out test set. map_location keeps the load working
# when the checkpoint was written on a different device.
linear_probe.load_state_dict(torch.load(best_model_path, map_location=config["device"]))
print("\n--- Evaluating on Test Set ---")
config["model_type"] = "linear_probe"  # Set the model type to linear probe for evaluation
test_loss, test_auroc, test_auprc, cm = evaluate(encoder, test_loader, criterion, config["device"], config["model_type"], linear_probe=linear_probe)
print("\n--- Test Results ---")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test AuROC: {test_auroc:.4f}")
print(f"Test AuPRC: {test_auprc:.4f}")
print(f"Confusion Matrix:\n{cm}")
print("--------------------")
# Cast metrics to built-in floats so JSON serialisation never depends on the
# concrete numeric type returned by the metric functions.
results = {
    "test_loss": float(test_loss),
    "test_auroc": float(test_auroc),
    "test_auprc": float(test_auprc),
    "confusion_matrix": cm.tolist(),  # Convert to list for JSON serialization
}
# Save results to a JSON file
results_path = os.path.join(config["output_dir"], f"results_{model_name}_linear_probe.json")
with open(results_path, "w") as f:
    json.dump(results, f, indent=4)

## 3.2. Simulating sparse labels

In [None]:
def get_stratified_subset_indices(labels, subset_size, random_state=None):
    """Draw one stratified sample of indices from ``labels``.

    A single ``StratifiedShuffleSplit`` split is generated and its training
    part is returned, so the class proportions of ``labels`` are approximately
    preserved in the subset.

    Args:
        labels (array-like): Class labels used for stratification.
        subset_size: Forwarded as ``train_size`` to ``StratifiedShuffleSplit``.
        random_state (int, optional): Seed for a reproducible draw.

    Returns:
        The training-side index array of the single generated split.
    """
    splitter = StratifiedShuffleSplit(n_splits=1, train_size=subset_size, random_state=random_state)
    # Only the labels drive the split; the features are a same-length placeholder.
    placeholder_features = np.zeros(len(labels))
    subset_index, _ = next(splitter.split(placeholder_features, labels))
    return subset_index

In [None]:
# Load outcomes (labels) and processed features for the three data splits:
# set-a -> train, set-b -> validation, set-c -> test.
# SECURITY NOTE: allow_pickle=True unpickles objects stored in the files, which
# can execute arbitrary code; only load files from a trusted source.
seed = config["seed"]
data_path = config["data_dir"]

y_train, y_val, y_test = (
    np.load(os.path.join(data_path, outcome_file), allow_pickle=True)
    for outcome_file in (
        'outcomes_sorted_set-a.npy',
        'outcomes_sorted_set-b.npy',
        'outcomes_sorted_set-c.npy',
    )
)

scaler_name = config["scaler_name"]
# File-name prefix of the processed features; the cells below build
# MedicalTimeSeriesDatasetTuple datasets from these arrays.
data_format = "tuple"
X_train, X_val, X_test = (
    np.load(os.path.join(data_path, feature_file), allow_pickle=True)
    for feature_file in (
        f"{data_format}_processed_{scaler_name}_set-a.npy",
        f"{data_format}_processed_{scaler_name}_set-b.npy",
        f"{data_format}_processed_{scaler_name}_set-c.npy",
    )
)

In [None]:
# Label-scarce training sets: stratified subsets of the training split with
# 100, 500 and 1000 samples. Everything is driven by `subset_sizes`, so adding
# or removing a size here is enough (the list used to be defined but ignored).
subset_sizes = ["100", "500", "1000"]

# Indices of each subset, keyed by size label. The same seed is used for every
# size; the subsets are drawn independently and need not be nested.
subset_indices = {
    size: get_stratified_subset_indices(y_train, int(size), random_state=seed)
    for size in subset_sizes
}

# size label -> (features, labels) of that subset.
subsets = {
    size: (X_train[indices], y_train[indices])
    for size, indices in subset_indices.items()
}

In [None]:
# Build the data loaders for the label-scarce experiments: one class-balanced
# training loader per subset, plus validation and test loaders shared by all.
collate_func = collate_fn  # Use the custom collate_fn

val_dataset = MedicalTimeSeriesDatasetTuple(X_val, y_val, max_seq_len=config["tuple_max_seq_len"])
test_dataset = MedicalTimeSeriesDatasetTuple(X_test, y_test, max_seq_len=config["tuple_max_seq_len"])

val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, collate_fn=collate_func)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, collate_fn=collate_func)

train_loaders = {}
for size, (X_train_subset, y_train_subset) in subsets.items():
    print(f"Subset size: {size}")

    tmp_train_dataset = MedicalTimeSeriesDatasetTuple(
        X_train_subset,
        y_train_subset,
        max_seq_len=config["tuple_max_seq_len"],
    )

    # Weight each sample by the inverse frequency of its class so that the
    # sampler (drawing with replacement) sees the classes in balance.
    label_ids = y_train_subset.astype(int)
    class_counts = np.bincount(label_ids)
    class_weights = 1.0 / class_counts
    sample_weights = class_weights[label_ids]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(
        tmp_train_dataset,
        batch_size=config["batch_size"],
        sampler=sampler,
        collate_fn=collate_func,
    )
    train_loaders[size] = train_loader
    
    

In [None]:
# Supervised baseline: train a fresh Tuple Transformer on every label-scarce
# subset, select the best epoch by validation AuROC and report test metrics.
config["model_type"] = "tuple"  # Set the model type to tuple for training
device = config["device"]

# Make sure checkpoints, configs and results have somewhere to go.
os.makedirs(config["output_dir"], exist_ok=True)

for size, train_loader in train_loaders.items():
    print(f"Training Tuple Transformer on subset size: {size}")

    # Initialize the model
    model = get_model(config)
    model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

    # The scheduler is stepped with validation AuROC, which is maximised, so
    # it must run in 'max' mode. With 'min' it treated every AuROC improvement
    # as a bad epoch and reduced the learning rate while the model was still
    # getting better.
    # NOTE(review): the scheduler patience equals the early-stopping patience,
    # so an LR reduction will rarely, if ever, happen before training stops;
    # consider a smaller scheduler patience.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=config["patience"])

    # Training loop state
    best_val_auroc = -1.0
    epochs_no_improve = 0

    # A dedicated name is used so the global `model_name` of the contrastive
    # encoder (needed by the linear-probe cells) is not overwritten.
    # The artefact names keep their historical form, including the doubled
    # ".pth" in the checkpoint file name.
    run_name = f"tuple_transformer_{size}_{config['epochs']}.pth"
    best_model_path = os.path.join(config["output_dir"], f"{run_name}.pth")
    config_file_path = os.path.join(config["output_dir"], f"config_{run_name}.json")
    with open(config_file_path, 'w') as f:
        json.dump(config, f, indent=4)

    # Train the model
    for epoch in range(config["epochs"]):
        print(f"\nEpoch {epoch+1}/{config['epochs']}")

        train_loss = train_model(model, train_loader, criterion, optimizer, config["device"], config["model_type"])
        val_loss, val_auroc, val_auprc, _ = evaluate(model, val_loader, criterion, config["device"], config["model_type"])
        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, AuROC: {val_auroc:.4f}, AuPRC: {val_auprc:.4f}")

        scheduler.step(val_auroc)

        if val_auroc > best_val_auroc:
            print(f"Validation AuROC improved ({best_val_auroc:.4f} -> {val_auroc:.4f}). Saving model...")
            best_val_auroc = val_auroc
            torch.save(model.state_dict(), best_model_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"Validation AuROC did not improve. ({epochs_no_improve}/{config['patience']})")

        if epochs_no_improve >= config["patience"]:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break

    # Load the best checkpoint and evaluate it on the test set.
    print("\n--- Loading Best Model for Testing ---")
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=config["device"]))

        print("\n--- Evaluating on Test Set ---")
        test_loss, test_auroc, test_auprc, cm = evaluate(model, test_loader, criterion, config["device"], config["model_type"])
        print("\n--- Test Results ---")
        print(f"Test Loss: {test_loss:.4f}")
        print(f"Test AuROC: {test_auroc:.4f}")
        print(f"Test AuPRC: {test_auprc:.4f}")
        print(f"Confusion Matrix:\n{cm}")
        print("--------------------")
        results = {
            "Test Loss": test_loss,
            "Test AuROC": test_auroc,
            "Test AuPRC": test_auprc,
            "Confusion Matrix": cm.tolist()  # Convert to list for JSON serialization
        }
        results_file_path = os.path.join(config["output_dir"], f"results_{run_name}.json")
        with open(results_file_path, 'w') as f:
            json.dump(results, f, indent=4)
    else:
        print(f"Best model not found at {best_model_path}. Skipping evaluation.")
        


                

#### Linear probe on the label subsets using the self-supervised (contrastive) encoder

In [None]:
# Linear probe on label-scarce subsets: reuse the contrastively pre-trained
# encoder (frozen) and fit a fresh linear layer for every subset size.
config["model_type"] = "contrast"
encoder = get_model(config)
# NOTE(review): the checkpoint name is pinned to the 30-epoch pre-training run
# and does not follow config["epochs"] — confirm this is intended.
encoder_name = "V3_contrastive_encoder_model_epoch_30.pth"
# map_location keeps the load working when the checkpoint was written on a
# different device.
encoder.load_state_dict(torch.load(f"{config['output_dir']}/{encoder_name}.pth", map_location=device))
encoder.to(device)

# Loop-invariant setup, done once instead of once per subset: freeze the
# encoder, switch to linear-probe mode and build the shared val/test loaders.
for param in encoder.parameters():
    param.requires_grad = False
config["model_type"] = "linear_probe"
_, val_loader, test_loader = get_data_loaders(config)

for size, train_loader in train_loaders.items():
    print(f"Training Linear Probe on subset size: {size}")

    # The probe is fed encoder.get_representation(...) (see evaluate), i.e.
    # the tensor the projection head takes as input, so its input width is
    # tuple_d_model rather than proj_dim (both are 128 in the current config).
    # NOTE(review): assumes ProjectionHead's first argument is its input width — confirm in Transformers.
    linear_probe = nn.Linear(config["tuple_d_model"], 1)
    linear_probe.to(device)

    optimizer_probe = torch.optim.AdamW(linear_probe.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    criterion = nn.BCEWithLogitsLoss()

    best_val_auroc = -1.0
    epochs_no_improve = 0
    probe_name = f"{encoder_name}_linear_probe_subset_{size}"
    best_model_path = os.path.join(config["output_dir"], f"{probe_name}.pth")
    config_file_path = os.path.join(config["output_dir"], f"config_{probe_name}.json")
    with open(config_file_path, 'w') as f:
        json.dump(config, f, indent=4)

    print("Training Linear Probe...")
    for epoch in range(config["epochs"]):
        print(f"\nEpoch {epoch+1}/{config['epochs']}")

        train_loss = train_model(encoder, train_loader, criterion, optimizer_probe, device, model_type=config["model_type"], linear_probe=linear_probe)
        print(f"Epoch {epoch+1}/{config['epochs']}, Loss: {train_loss:.4f}")

        val_loss, val_auroc, val_auprc, _ = evaluate(encoder, val_loader, criterion, device, model_type=config["model_type"], linear_probe=linear_probe)
        if val_auroc > best_val_auroc:
            print(f"Validation AuROC improved ({best_val_auroc:.4f} -> {val_auroc:.4f}). Saving model...")
            best_val_auroc = val_auroc
            torch.save(linear_probe.state_dict(), best_model_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"Validation AuROC did not improve. ({epochs_no_improve}/{config['patience']})")

        # Early stopping: give up after `patience` epochs without a new best AuROC.
        if epochs_no_improve >= config["patience"]:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break

    # Restore the best probe for this subset and score it on the test set.
    linear_probe.load_state_dict(torch.load(best_model_path, map_location=device))
    print("\n--- Evaluating on Test Set ---")
    test_loss, test_auroc, test_auprc, cm = evaluate(encoder, test_loader, criterion, config["device"], config["model_type"], linear_probe=linear_probe)
    print("\n--- Test Results ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test AuROC: {test_auroc:.4f}")
    print(f"Test AuPRC: {test_auprc:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    print("--------------------")
    results = {
        "Test Loss": test_loss,
        "Test AuROC": test_auroc,
        "Test AuPRC": test_auprc,
        "Confusion Matrix": cm.tolist()  # Convert to list for JSON serialization
    }
    results_file_path = os.path.join(config["output_dir"], f"results_{probe_name}.json")
    with open(results_file_path, 'w') as f:
        json.dump(results, f, indent=4)

## 3.3 Visualising learned representations

In [None]:

# Collect encoder representations and labels for the validation split; the
# cells below project and score them.
config["model_type"] = "tuple"
train_loader, val_loader, test_loader = get_data_loaders(config)

config["model_type"] = "contrast"

encoder = get_model(config)
# map_location keeps the load working when the checkpoint was written on a
# different device.
encoder.load_state_dict(torch.load(f"{config['output_dir']}/V3_contrastive_encoder_model_epoch_30.pth.pth", map_location=device))
encoder.to(device)

encoder.eval()
embeddings = []
labels = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Generating embeddings", leave=False):
        t_seq, z_seq, v_seq, attn_mask, label = batch
        t_seq, z_seq, v_seq, attn_mask = t_seq.to(device), z_seq.to(device), v_seq.to(device), attn_mask.to(device)
        rep = encoder.get_representation(t_seq, z_seq, v_seq, attn_mask)
        embeddings.append(rep.cpu().numpy())
        labels.append(label.cpu().numpy())

embeddings = np.concatenate(embeddings, axis=0)
# reshape(-1) always yields a 1-D label vector; squeeze() would collapse a
# single-sample result to a 0-d array.
labels = np.concatenate(labels, axis=0).reshape(-1)

In [None]:
from sklearn.manifold import TSNE

# Project the encoder embeddings to 2-D for plotting. t-SNE is stochastic, so
# the run is tied to the notebook seed to make the figure reproducible.
tsne = TSNE(n_components=2, random_state=config["seed"])
tsne_embeddings = tsne.fit_transform(embeddings)

In [None]:
import matplotlib.pyplot as plt

# Scatter plot of the 2-D t-SNE projection, coloured by outcome label.
fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(tsne_embeddings[:, 0], tsne_embeddings[:, 1], c=labels, cmap='viridis', alpha=0.7)
fig.colorbar(scatter, ax=ax, label='Label (e.g., Death=1, Survival=0)')
ax.set_title("t-SNE Visualization of Patient Embeddings")
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")
plt.show()

In [None]:
# NOTE(review): this cell draws the same figure as the previous plotting cell;
# consider removing one of the two.
plt.figure(figsize=(8, 6))
scatter = plt.scatter(
    tsne_embeddings[:, 0],
    tsne_embeddings[:, 1],
    c=labels,
    cmap='viridis',
    alpha=0.7,
)
plt.colorbar(scatter, label='Label (e.g., Death=1, Survival=0)')
plt.title("t-SNE Visualization of Patient Embeddings")
plt.xlabel("Dimension 1")
plt.ylabel("Dimension 2")
plt.show()

In [None]:
from sklearn.metrics import silhouette_score

# Silhouette score of the class labels, computed on `embeddings` (the encoder
# representations gathered above), not on the 2-D `tsne_embeddings`.
# labels: corresponding class label for each row of `embeddings`
sil_score = silhouette_score(embeddings, labels)
print("Silhouette Score:", sil_score)