# TTML Survival Analysis Examples

This notebook demonstrates how to apply the TTML model to survival analysis tasks on the SUPPORT dataset. We'll cover:

1. Time-to-Event Prediction
   - Survival time prediction
   - Survival curve estimation
   - Risk score calculation

2. Competing Risks Analysis
   - Multiple event types
   - Cause-specific hazards
   - Cumulative incidence functions

In [None]:
import sys
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score
from lifelines import KaplanMeierFitter

# Import TTML modules
from tabular_transformer.models import TabularTransformer
from tabular_transformer.models.task_heads import SurvivalHead, CompetingRisksHead
from tabular_transformer.training import Trainer
from tabular_transformer.inference import predict
from tabular_transformer.explainability import global_explanations, local_explanations
from tabular_transformer.utils.config import TransformerConfig
from tabular_transformer.data.dataset import TabularDataset

# Import data utilities
from data_utils import download_support_dataset

## Part 1: Time-to-Event Prediction

First, we'll work with the SUPPORT dataset to predict survival times and estimate survival curves.

In [None]:
# Load the SUPPORT data (the helper is asked not to save a CSV) and print a
# quick summary: overall shape, per-column dtypes, and the relative frequency
# of each value of the 'death' event indicator.
support_df = download_support_dataset(save_csv=False)
print("SUPPORT dataset shape:", support_df.shape)
# Heading and payload go out in one call; sep="\n" keeps them on separate lines.
print("\nFeature types:", support_df.dtypes, sep="\n")
print(
    "\nEvent distribution:",
    support_df['death'].value_counts(normalize=True),
    sep="\n",
)

In [None]:
# Column roles. 'time' and 'death' are the survival targets. 'cause' is used
# as an additional target by the competing-risks section of this notebook, so
# it is excluded from the model inputs as well: feeding an outcome column in
# as a feature would leak the label. Names that are absent from the dataframe
# are simply never matched by the filters below.
time_column = 'time'
event_column = 'death'
cause_column = 'cause'
target_names = {time_column, event_column, cause_column}

# Split the remaining columns by dtype: int64/float64 columns are treated as
# numeric features, object columns as categorical features.
numeric_features = [
    col
    for col in support_df.select_dtypes(include=['int64', 'float64']).columns
    if col not in target_names
]
categorical_features = [
    col
    for col in support_df.select_dtypes(include=['object']).columns
    if col not in target_names
]

# Create train/test datasets (80/20 split, fixed seed for reproducibility).
# The third return value is not used in this notebook.
train_dataset, test_dataset, _ = TabularDataset.from_dataframe(
    dataframe=support_df,
    numeric_columns=numeric_features,
    categorical_columns=categorical_features,
    target_columns={
        'time': [time_column],
        'event': [event_column]
    },
    validation_split=0.2,
    random_state=42
)

In [None]:
# Feature dimensions come from the preprocessor fitted on the training split,
# so the encoder's input layout always agrees with the prepared data.
feature_dims = train_dataset.preprocessor.get_feature_dimensions()
numeric_dim = feature_dims['numeric_dim']
categorical_dims = feature_dims['categorical_dims']
categorical_embedding_dims = feature_dims['categorical_embedding_dims']

# Model configuration
config = TransformerConfig(
    embed_dim=128,
    num_heads=8,
    num_layers=4,
    dropout=0.2,
    variational=False
)

# Initialize transformer encoder
encoder = TabularTransformer(
    numeric_dim=numeric_dim,
    categorical_dims=categorical_dims,
    categorical_embedding_dims=categorical_embedding_dims,
    config=config
)

# Initialize survival head. The head's input width is read from the config
# instead of repeating the literal 128, so changing embed_dim above cannot
# silently leave the head with a mismatched input size.
survival_head = SurvivalHead(
    input_dim=config.embed_dim,
    num_time_bins=50  # Number of time intervals for discrete-time survival
)

In [None]:
# Build both loaders in one pass; only the training split is shuffled.
train_loader, test_loader = (
    split.create_dataloader(batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

# Trainer wiring. optimizer=None / device=None defer both choices to the
# Trainer (it creates the optimizer and uses CUDA if available).
trainer_settings = dict(
    encoder=encoder,
    task_head=survival_head,
    optimizer=None,
    device=None,
)
trainer = Trainer(**trainer_settings)

# Fit for at most 25 epochs, with an early-stopping patience of 3.
# NOTE(review): the test loader doubles as the validation loader here, so
# early stopping is driven by the same split that is evaluated afterwards.
fit_settings = dict(
    train_loader=train_loader,
    val_loader=test_loader,
    num_epochs=25,
    early_stopping_patience=3,
)
history = trainer.train(**fit_settings)

In [None]:
# Run inference on the held-out split.
predictions = trainer.predict(test_loader)

# Predicted survival probabilities plus the observed times and event flags.
survival_probs = predictions['main']['survival_probabilities']
time_test = test_dataset.targets['time']
event_test = test_dataset.targets['event']

# Calculate concordance index
c_index = survival_head.calculate_concordance_index(
    survival_probs,
    time_test,
    event_test
)

print(f"Concordance Index: {c_index:.4f}")

# Plot Kaplan-Meier curves for different risk groups.
# Risk is the complement of the survival probability at the last time point:
# a patient unlikely to survive to the end of the horizon is HIGH risk.
# Binning the raw survival probability instead would label the lowest-survival
# tercile "Low" and invert the groups.
# (Assumes 'survival_probabilities' holds S(t) per time bin, as the key says.)
# detach()/cpu() make the conversion work for CUDA or grad-tracking tensors.
final_survival = survival_probs[:, -1].detach().cpu().numpy()
risk_scores = 1.0 - final_survival
risk_groups = pd.qcut(risk_scores, q=3, labels=['Low', 'Medium', 'High'])

plt.figure(figsize=(10, 6))
ax = plt.gca()  # all three curves are drawn on this one axes
kmf = KaplanMeierFitter()

for group in ['Low', 'Medium', 'High']:
    # Plain boolean array selecting the patients in this tercile.
    mask = np.asarray(risk_groups == group)
    kmf.fit(
        time_test[mask],
        event_test[mask],
        label=f'{group} Risk'
    )
    kmf.plot(ax=ax)

plt.title('Survival Curves by Risk Group')
plt.xlabel('Time')
plt.ylabel('Survival Probability')
plt.grid(True)
plt.show()

## Feature Importance for Survival

Let's analyze which features are most important for survival prediction.

In [None]:
# Calculate and plot feature importance
feature_importance = global_explanations.calculate_feature_importance(
    encoder=encoder,
    task_head=survival_head,
    dataset=test_dataset,
    feature_names=numeric_features + categorical_features
)

plt.figure(figsize=(12, 6))
feature_importance.sort_values().plot(kind='barh')
plt.title('Feature Importance for Survival Prediction')
plt.xlabel('Importance Score')
plt.tight_layout()
plt.show()

## Part 2: Competing Risks Analysis

Now we'll demonstrate competing risks analysis by considering different causes of death.

In [None]:
# Create train/test datasets with cause information
train_dataset_cr, test_dataset_cr, _ = TabularDataset.from_dataframe(
    dataframe=support_df,
    numeric_columns=numeric_features,
    categorical_columns=categorical_features,
    target_columns={
        'time': [time_column],
        'event': [event_column],
        'cause': ['cause']  # Additional cause column
    },
    validation_split=0.2,
    random_state=42
)

In [None]:
# Initialize competing risks head
competing_risks_head = CompetingRisksHead(
    input_dim=128,  # Should match config.embed_dim
    num_risks=3,  # Number of competing events
    num_time_bins=50
)

# Create data loaders
train_loader_cr = train_dataset_cr.create_dataloader(batch_size=64, shuffle=True)
test_loader_cr = test_dataset_cr.create_dataloader(batch_size=64, shuffle=False)

# Initialize trainer
trainer_cr = Trainer(
    encoder=encoder,  # Reuse encoder from survival analysis
    task_head=competing_risks_head,
    optimizer=None,  # Will be created by trainer
    device=None  # Will use CUDA if available
)

# Train the model
history_cr = trainer_cr.train(
    train_loader=train_loader_cr,
    val_loader=test_loader_cr,
    num_epochs=25,
    early_stopping_patience=3
)

In [None]:
# Predict on the competing-risks test split and pull out the per-risk
# probabilities (risk index on axis 1, as used by the slicing below).
predictions_cr = trainer_cr.predict(test_loader_cr)
risk_probs = predictions_cr['main']['risk_probabilities']

# One cumulative-incidence curve per competing risk, evaluated on 100 evenly
# spaced time points from 0 to the largest observed test time.
plt.figure(figsize=(12, 6))
time_points = np.linspace(0, max(test_dataset_cr.targets['time']), 100)

for cause_number in range(1, 4):  # causes are labelled 1..3 in the legend
    cause_probs = risk_probs[:, cause_number - 1, :]
    cif = competing_risks_head.calculate_cumulative_incidence(
        cause_probs,
        time_points,
    )
    # Average over axis 0 to draw a single curve for this cause.
    mean_incidence = cif.mean(axis=0)
    plt.plot(time_points, mean_incidence, label=f'Cause {cause_number}')

plt.title('Cumulative Incidence Functions')
plt.xlabel('Time')
plt.ylabel('Cumulative Incidence')
plt.legend()
plt.grid(True)
plt.show()

## Individual Patient Analysis

Let's examine predictions for individual patients to understand risk factors.

In [None]:
# Get local explanations for a few examples
sample_indices = np.random.choice(len(test_dataset_cr), 3, replace=False)

for idx in sample_indices:
    explanation = local_explanations.explain_prediction(
        encoder=encoder,
        task_head=competing_risks_head,
        instance_idx=idx,
        dataset=test_dataset_cr,
        feature_names=numeric_features + categorical_features
    )
    
    print(f"\nPatient {idx+1}:")
    print(f"Observed time: {test_dataset_cr.targets['time'][idx]:.1f}")
    print(f"Event occurred: {bool(test_dataset_cr.targets['event'][idx])}")
    if test_dataset_cr.targets['event'][idx]:
        print(f"Cause: {test_dataset_cr.targets['cause'][idx]}")
    
    print("\nTop risk factors:")
    sorted_features = sorted(explanation.items(), key=lambda x: abs(x[1]), reverse=True)[:5]
    for feature, contribution in sorted_features:
        print(f"{feature}: {contribution:.4f}")

## Conclusion

This notebook demonstrated the survival analysis capabilities of the TTML model:

1. Time-to-Event Prediction
   - Successfully predicted survival times
   - Generated interpretable survival curves
   - Evaluated ranking quality with the concordance index

2. Competing Risks Analysis
   - Modeled multiple causes of events
   - Estimated cause-specific risks
   - Provided individual risk assessments

The TTML model showed strong performance in both standard survival analysis and competing risks scenarios, while providing valuable insights through its explainability features.