# CytoGPS Survival Analysis Demo
## Random Survival Forest and Transformer Survival on AnnData

This notebook demonstrates survival analysis on an AnnData object with:
- **Observations:** 1,222 samples
- **Variables:** 2,748 cytogenetic features
- **Task:** Compare Random Survival Forest and Transformer-based Survival models

## Section 1: Load and Explore AnnData Object

In [None]:
# Core scientific stack used throughout the notebook.
import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# NOTE: scikit-learn does not provide `concordance_index_score`; importing it
# raised ImportError and stopped the notebook at the first cell. The
# concordance index is computed with `sksurv.metrics.concordance_index_censored`,
# which is imported in Section 3.
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

In [None]:
# Create a synthetic AnnData object matching your specifications.
#
# Everything below is drawn from NumPy's global random generator, so the
# generated values depend on the seed set earlier AND on the order of the
# np.random calls in this cell. Reordering entries changes the data.
n_obs = 1222
n_vars = 2748

# Generate random feature matrix (standard normal, n_obs x n_vars)
X = np.random.randn(n_obs, n_vars)

# Create observation metadata.
# NOTE: every column is sampled independently of X and of the other columns,
# so the features carry no real survival signal in this synthetic demo.
obs_data = {
    'Line_Number': np.arange(n_obs),
    'pat_id': np.random.randint(1, 300, n_obs),
    'Clone_Number': np.random.randint(1, 10, n_obs),
    'Karyotype_Revised': [f'kar_{i}' for i in range(n_obs)],
    'Clone_Code': np.random.randint(1, 100, n_obs),
    # Exponential draw shifted by 10, so every survival time is at least 10.
    'survival_time': np.random.exponential(scale=50, size=n_obs) + 10,  # months
    # Stored as 0/1 integers (1 = event observed), not booleans.
    'event_occurred': np.random.binomial(1, 0.6, n_obs),  # 60% event rate
    'sex': np.random.choice(['M', 'F'], n_obs),
    'age': np.random.uniform(30, 85, n_obs),
    'n_abs_aberr': np.random.randint(0, 15, n_obs),
    'n_abs_aberr_iscn': np.random.randint(0, 15, n_obs),
    'n_abs_max_clone': np.random.randint(0, 15, n_obs),
    'ck_abs': np.random.randint(0, 10, n_obs),
    'ck_abs_iscn': np.random.randint(0, 10, n_obs),
    'ck_abs_max_clone': np.random.randint(0, 10, n_obs),
    'diff_max_clone': np.random.randint(0, 5, n_obs),
    'ck_category': np.random.choice(['good', 'intermediate', 'poor'], n_obs),
    'ck_category_iscn': np.random.choice(['good', 'intermediate', 'poor'], n_obs),
    'ck_category_max_clone': np.random.choice(['good', 'intermediate', 'poor'], n_obs),
    'clonal_evolution': np.random.binomial(1, 0.5, n_obs),
    # 'NA' here is the literal string, not a missing value.
    'TP53mutStatus': np.random.choice(['WT', 'MUT', 'NA'], n_obs),
    'IGHVmutStatus': np.random.choice(['M', 'UM', 'NA'], n_obs),
    'percent_path_cells_ip': np.random.uniform(0, 100, n_obs),
    'material': np.random.choice(['PB', 'BM', 'LN'], n_obs),
    'bal_transloc': np.random.randint(0, 5, n_obs),
    'unbal_transloc': np.random.randint(0, 5, n_obs),
    'deletions': np.random.randint(0, 5, n_obs),
    'der_chrom': np.random.randint(0, 5, n_obs),
    'only_bal_abs_trisomies': np.random.randint(0, 5, n_obs),
    'comma_abs_count': np.random.randint(0, 100, n_obs),
    'character_count': np.random.randint(100, 500, n_obs),
    'any_uncertain_abs': np.random.binomial(1, 0.3, n_obs),
    'any_detailed_abs': np.random.binomial(1, 0.5, n_obs),
    'LGF_subcytoband_sum': np.random.uniform(0, 50, n_obs),
    'Loss_cytoband_sum': np.random.uniform(0, 50, n_obs),
    'Gain_cytoband_sum': np.random.uniform(0, 50, n_obs),
    'Fusion_cytoband_sum': np.random.uniform(0, 50, n_obs),
    'LGF_cytoband_sum': np.random.uniform(0, 50, n_obs),
}

obs_df = pd.DataFrame(obs_data)

# Create variable names
var_names = [f'feature_{i}' for i in range(n_vars)]

# Create obsm data (observation matrices for different cytoband categories).
# Each matrix has n_obs rows; the column counts must match the name lists in
# uns_data below.
obsm_data = {
    'Fusion_cytoband': np.random.randn(n_obs, 100),
    'Fusion_subcytoband': np.random.randn(n_obs, 80),
    'Gain_cytoband': np.random.randn(n_obs, 120),
    'Gain_subcytoband': np.random.randn(n_obs, 100),
    'Loss_cytoband': np.random.randn(n_obs, 110),
    'Loss_subcytoband': np.random.randn(n_obs, 90),
}

# Create uns data (unstructured annotation): one list of column names per
# obsm matrix, keyed as '<obsm key>_varnames'.
uns_data = {
    'Fusion_cytoband_varnames': [f'Fusion_cytoband_{i}' for i in range(100)],
    'Fusion_subcytoband_varnames': [f'Fusion_subcytoband_{i}' for i in range(80)],
    'Gain_cytoband_varnames': [f'Gain_cytoband_{i}' for i in range(120)],
    'Gain_subcytoband_varnames': [f'Gain_subcytoband_{i}' for i in range(100)],
    'Loss_cytoband_varnames': [f'Loss_cytoband_{i}' for i in range(110)],
    'Loss_subcytoband_varnames': [f'Loss_subcytoband_{i}' for i in range(90)],
}

# Create AnnData object
adata = ad.AnnData(X=X, obs=obs_df, var=pd.DataFrame(index=var_names), obsm=obsm_data, uns=uns_data)

# Report the structure and a summary of the survival columns
print("AnnData Object Created Successfully!")
print(f"\nAnnData Structure:")
print(adata)
print(f"\nShape: {adata.n_obs} observations × {adata.n_vars} variables")
print(f"\nObservation Metadata (obs):")
print(adata.obs.head())
print(f"\nSurvival Data Summary:")
print(adata.obs[['survival_time', 'event_occurred']].describe())

## Section 2: Prepare Data for Survival Analysis

In [None]:
# Extract features from adata.X
X_features = adata.X.astype(np.float32)

# Extract survival outcomes
survival_time = adata.obs['survival_time'].values
event_occurred = adata.obs['event_occurred'].values

print(f"Feature matrix shape: {X_features.shape}")
print(f"Survival time shape: {survival_time.shape}")
print(f"Event occurred shape: {event_occurred.shape}")
print(f"\nSurvival Statistics:")
print(f"  Mean survival time: {survival_time.mean():.2f} months")
print(f"  Median survival time: {np.median(survival_time):.2f} months")
print(f"  Event rate: {event_occurred.mean():.2%}")

# Handle missing values (none in this synthetic case, but good practice)
print(f"\nMissing values in features: {np.isnan(X_features).sum()}")
print(f"Missing values in survival_time: {np.isnan(survival_time).sum()}")
print(f"Missing values in event_occurred: {np.isnan(event_occurred).sum()}")

# Split data into training (80%) and testing (20%) sets BEFORE scaling.
# Fitting the scaler on the full matrix would leak test-set means/variances
# into the training features and bias the held-out evaluation.
X_train_raw, X_test_raw, time_train, time_test, event_train, event_test = train_test_split(
    X_features, survival_time, event_occurred,
    test_size=0.2, random_state=42
)

# Normalize features: statistics come from the training rows only and are
# then applied unchanged to the test rows.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Full matrix on the training scale, kept so later code that refers to
# `X_features_normalized` keeps working.
X_features_normalized = scaler.transform(X_features)

# Statistics below describe the training split, which is what the scaler
# standardises to (approximately) zero mean and unit variance.
print(f"\nFeatures normalized:")
print(f"  Mean: {X_train.mean():.6f}")
print(f"  Std: {X_train.std():.6f}")

print(f"\nData split:")
print(f"  Training set: {X_train.shape[0]} samples")
print(f"  Testing set: {X_test.shape[0]} samples")

## Section 3: Train Random Survival Forest Model

In [None]:
# Install scikit-survival if not already installed
import subprocess
import sys

try:
    from sksurv.ensemble import RandomSurvivalForest
    from sksurv.metrics import concordance_index_censored
except ImportError:
    print("Installing scikit-survival...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-survival", "-q"])
    from sksurv.ensemble import RandomSurvivalForest
    from sksurv.metrics import concordance_index_censored


def _make_survival_array(event, time):
    """Return the structured (event: bool, time: float) array sksurv expects.

    Args:
        event: 0/1 or boolean event indicators, one per sample.
        time: observed times, same length as ``event``.

    Returns:
        A 1-D NumPy structured array with fields ``'event'`` and ``'time'``.
    """
    y = np.empty(len(time), dtype=[('event', bool), ('time', float)])
    y['event'] = np.asarray(event).astype(bool)
    y['time'] = time
    return y


# Prepare structured array for survival data (required by sksurv)
structured_time_train = _make_survival_array(event_train, time_train)
structured_time_test = _make_survival_array(event_test, time_test)

print("Training Random Survival Forest Model...")

# Initialize and train Random Survival Forest
rsf = RandomSurvivalForest(
    n_estimators=100,
    min_samples_split=10,
    min_samples_leaf=5,
    max_depth=15,
    random_state=42,
    n_jobs=-1,
    verbose=0
)

rsf.fit(X_train, structured_time_train)

print("Random Survival Forest training completed!")
print(f"Model parameters:")
print(f"  Number of trees: {rsf.n_estimators}")
print(f"  Max depth: {rsf.max_depth}")
print(f"  Min samples split: {rsf.min_samples_split}")

# Generate predictions on test set (higher value = higher predicted risk)
rsf_predictions = rsf.predict(X_test)

# Calculate concordance index for RSF.
# concordance_index_censored only accepts a boolean event indicator; passing
# the raw 0/1 integer array raises ValueError, so use the boolean field of
# the structured test array.
rsf_cindex = concordance_index_censored(
    structured_time_test['event'], structured_time_test['time'], rsf_predictions
)[0]

print(f"\nRandom Survival Forest Performance:")
print(f"  Concordance Index: {rsf_cindex:.4f}")
print(f"  Predictions shape: {rsf_predictions.shape}")
print(f"  Prediction range: [{rsf_predictions.min():.4f}, {rsf_predictions.max():.4f}]")

## Section 4: Train Transformer Survival Model

In [None]:
# Install torch and pycox for Transformer Survival models
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader
except ImportError:
    print("Installing PyTorch...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "-q"])
    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader

# pycox names its discrete-time label transform `LabTransDiscreteTime`.
# Importing the non-existent `LabelTransformDiscrete` raised ImportError even
# with pycox installed, so the cell failed after the pip fallback. The alias
# keeps the name this notebook originally exposed.
try:
    from pycox.models import DeepHitSingle
    from pycox.preprocessing.label_transforms import LabTransDiscreteTime as LabelTransformDiscrete
except ImportError:
    print("Installing pycox...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pycox", "-q"])
    from pycox.models import DeepHitSingle
    from pycox.preprocessing.label_transforms import LabTransDiscreteTime as LabelTransformDiscrete

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _to_float_tensor(array):
    """Copy a NumPy array into a float32 tensor on the selected device.

    torch.tensor with an explicit dtype replaces the legacy
    torch.FloatTensor constructor and converts the float64 times and integer
    event flags to float32 explicitly.
    """
    return torch.tensor(np.asarray(array), dtype=torch.float32, device=device)


# Convert to torch tensors
X_train_torch = _to_float_tensor(X_train)
X_test_torch = _to_float_tensor(X_test)
time_train_torch = _to_float_tensor(time_train)
time_test_torch = _to_float_tensor(time_test)
event_train_torch = _to_float_tensor(event_train)
event_test_torch = _to_float_tensor(event_test)

# Define simple Transformer-based Survival Model
class TransformerSurvivalModel(nn.Module):
    """Transformer-encoder network mapping a feature vector to a risk score.

    Each sample is projected to ``hidden_dim``, passed through a Transformer
    encoder as a length-1 sequence, and reduced to one strictly positive
    scalar per sample.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the projection and of the encoder (``d_model``).
        num_heads: number of attention heads in each encoder layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability used in the encoder and the output head.
    """

    def __init__(self, input_dim, hidden_dim=256, num_heads=8, num_layers=2, dropout=0.1):
        super().__init__()

        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
            activation='relu'
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output layers
        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x):
        """Return one positive risk score per row of ``x``.

        Args:
            x: float tensor of shape (batch, input_dim).

        Returns:
            Tensor of shape (batch,) with strictly positive risk scores.
        """
        # Project input to hidden dimension
        x = self.input_proj(x)

        # Add sequence dimension for transformer (batch, seq, features).
        # The sequence has a single token, so self-attention only attends to
        # that token itself.
        x = x.unsqueeze(1)  # (batch, 1, hidden_dim)

        # Apply transformer
        x = self.transformer_encoder(x)

        # Pool: take the mean of the sequence (identity for a length-1 sequence)
        x = x.mean(dim=1)  # (batch, hidden_dim)

        # Output layer - positive output for survival risk.
        # Softplus instead of ReLU: ReLU clamps negative pre-activations to
        # exactly 0 with zero gradient, which produces tied risk scores and
        # outputs that can never recover during training. Softplus stays
        # strictly positive and differentiable everywhere.
        x = self.fc_layers(x)
        x = nn.functional.softplus(x)

        return x.squeeze(1)

# Build the Transformer survival network and move it to the selected device
print("\nInitializing Transformer Survival Model...")
_transformer_kwargs = dict(
    input_dim=X_train.shape[1],  # one input unit per feature column
    hidden_dim=256,
    num_heads=8,
    num_layers=2,
    dropout=0.1,
)
transformer_model = TransformerSurvivalModel(**_transformer_kwargs)
transformer_model = transformer_model.to(device)

# Loss function: Cox partial likelihood
class CoxLoss(nn.Module):
    """Negative Cox partial log-likelihood, averaged over observed events.

    The model output is treated as a log-risk score ``s_i``. For every
    subject with an observed event the contribution is
    ``s_i - log(sum_{j in R_i} exp(s_j))``, where the risk set ``R_i`` is the
    subjects in the batch whose time is at least ``t_i``.

    The previous version mixed ``log(score)`` in the numerator with
    ``exp(score)`` in the denominator, which is not a partial likelihood and
    gives NaN for negative scores.
    """

    def forward(self, hazard, time, event):
        """Compute the loss for one batch.

        Args:
            hazard: (batch,) tensor of risk scores (higher = higher risk).
            time: (batch,) tensor of observed times.
            event: (batch,) tensor, 1 where the event was observed, else 0.

        Returns:
            Scalar tensor; 0 when the batch contains no observed event.
        """
        # Sort by descending time so the first i+1 sorted subjects form the
        # risk set of the i-th one. Tied times get no special handling.
        order = torch.argsort(time, descending=True)
        log_risk = hazard[order]
        event = event[order].to(log_risk.dtype)

        # log of the running sum of exp(log_risk), computed stably.
        log_risk_set = torch.logcumsumexp(log_risk, dim=0)

        partial_log_likelihood = torch.sum(event * (log_risk - log_risk_set))

        # Average over observed events; the clamp avoids dividing by zero
        # when every subject in the batch is censored.
        n_events = torch.sum(event).clamp(min=1.0)
        return -partial_log_likelihood / n_events

criterion = CoxLoss()
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001, weight_decay=1e-5)

# Training loop
print("Training Transformer Survival Model...")
num_epochs = 50
batch_size = 32

# Create data loaders. The Cox loss builds its risk sets within each
# mini-batch, so shuffling changes which subjects are compared every epoch.
train_dataset = TensorDataset(X_train_torch, time_train_torch, event_train_torch)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# One entry per epoch: the sum of the per-batch losses.
train_losses = []
for epoch in range(num_epochs):
    transformer_model.train()
    epoch_loss = 0

    for batch_X, batch_time, batch_event in train_loader:
        optimizer.zero_grad()

        # Forward pass
        hazard = transformer_model(batch_X)

        # Calculate loss
        loss = criterion(hazard, batch_time, batch_event)

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    train_losses.append(epoch_loss)

    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.6f}")

print("Transformer Survival Model training completed!")

# Generate predictions on test set (eval mode disables dropout)
transformer_model.eval()
with torch.no_grad():
    transformer_predictions = transformer_model(X_test_torch).cpu().numpy()

# Calculate concordance index for Transformer.
# concordance_index_censored requires a boolean event indicator; the raw 0/1
# integer array raises ValueError, so cast it explicitly.
transformer_cindex = concordance_index_censored(
    np.asarray(event_test).astype(bool), time_test, transformer_predictions
)[0]

print(f"\nTransformer Survival Model Performance:")
print(f"  Concordance Index: {transformer_cindex:.4f}")
print(f"  Predictions shape: {transformer_predictions.shape}")
print(f"  Prediction range: [{transformer_predictions.min():.4f}, {transformer_predictions.max():.4f}]")

## Section 5: Compare Model Performance

In [None]:
from sksurv.metrics import integrated_brier_score, brier_score
from scipy.special import expit

# Calculate additional survival metrics
print("="*70)
print("COMPREHENSIVE MODEL PERFORMANCE COMPARISON")
print("="*70)

# 1. Concordance Index (higher is better, max=1.0)
print("\n1. CONCORDANCE INDEX (Harrell's C-index)")
print("-" * 70)
print(f"   Random Survival Forest:     {rsf_cindex:.4f}")
print(f"   Transformer Survival:       {transformer_cindex:.4f}")
print(f"   Baseline (random):          0.5000")

# Determine which model is better. Equal scores are reported as a tie
# instead of being credited to the Transformer with a 0.0000 margin.
cindex_gap = rsf_cindex - transformer_cindex
if cindex_gap > 0:
    print(f"   ✓ RSF is better by: {cindex_gap:.4f}")
elif cindex_gap < 0:
    print(f"   ✓ Transformer is better by: {-cindex_gap:.4f}")
else:
    print("   ✓ Both models have the same concordance index")

# 2. Risk Score Statistics
print("\n2. RISK SCORE STATISTICS")
print("-" * 70)

rsf_risk_scores = rsf_predictions
transformer_risk_scores = transformer_predictions

# (leading separator, heading, scores): the second section is preceded by a
# blank line, exactly as in the printed report layout.
for lead, heading, scores in [
    ("", "Random Survival Forest", rsf_risk_scores),
    ("\n", "Transformer Survival", transformer_risk_scores),
]:
    print(f"{lead}   {heading}:")
    print(f"      Mean:              {scores.mean():.4f}")
    print(f"      Std Dev:           {scores.std():.4f}")
    print(f"      Min:               {scores.min():.4f}")
    print(f"      Max:               {scores.max():.4f}")
    print(f"      Median:            {np.median(scores):.4f}")


def _risk_summary(scores):
    """Return [mean, std, min, max] of a risk-score array for the table."""
    return [scores.mean(), scores.std(), scores.min(), scores.max()]


# 3. Create performance comparison dataframe
comparison_df = pd.DataFrame({
    'Metric': ['Concordance Index', 'Mean Risk Score', 'Std Dev Risk Score', 'Min Risk Score', 'Max Risk Score'],
    'RSF': [rsf_cindex, *_risk_summary(rsf_risk_scores)],
    'Transformer': [transformer_cindex, *_risk_summary(transformer_risk_scores)],
})

print("\n3. PERFORMANCE COMPARISON TABLE")
print("-" * 70)
print(comparison_df.to_string(index=False))

# 4. Stratify patients by risk quartiles
print("\n4. PATIENT STRATIFICATION BY RISK QUARTILES")
print("-" * 70)

for model_name, risk_scores in [("RSF", rsf_risk_scores), ("Transformer", transformer_risk_scores)]:
    # Groups are (-inf, q25], (q25, q50], (q50, q75], (q75, inf). With many
    # tied scores the groups can be unequal or empty.
    quartiles = np.percentile(risk_scores, [25, 50, 75])
    q1 = (risk_scores <= quartiles[0]).sum()
    q2 = ((risk_scores > quartiles[0]) & (risk_scores <= quartiles[1])).sum()
    q3 = ((risk_scores > quartiles[1]) & (risk_scores <= quartiles[2])).sum()
    q4 = (risk_scores > quartiles[2]).sum()

    print(f"\n   {model_name}:")
    print(f"      Q1 (Low risk):     {q1:4d} patients ({q1/len(risk_scores):.1%})")
    print(f"      Q2 (Mid-Low risk): {q2:4d} patients ({q2/len(risk_scores):.1%})")
    print(f"      Q3 (Mid-High risk):{q3:4d} patients ({q3/len(risk_scores):.1%})")
    print(f"      Q4 (High risk):    {q4:4d} patients ({q4/len(risk_scores):.1%})")

print("\n" + "="*70)

## Section 6: Visualize Survival Predictions

In [None]:
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test

colors = ['green', 'yellow', 'orange', 'red']
group_labels = ['Q1 (Low)', 'Q2 (Mid-Low)', 'Q3 (Mid-High)', 'Q4 (High)']


def _plot_km_by_risk_quartile(ax, fitter, risk_scores, title):
    """Draw one Kaplan-Meier curve per risk quartile of the test set on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        fitter: KaplanMeierFitter instance, re-fitted once per group.
        risk_scores: 1-D array of predicted risk, aligned with ``time_test``
            and ``event_test``.
        title: axes title.

    Returns:
        Tuple ``(quartiles, groups)``: the 25/50/75th percentile cut points
        and the 0..3 group index assigned to each sample.
    """
    quartiles = np.percentile(risk_scores, [25, 50, 75])
    # Three cut points give group indices 0..3 (not 1..4). right=True makes
    # the bins (-inf, q25], (q25, q50], (q50, q75], (q75, inf), matching the
    # `<=` rule used in the quartile counts of Section 5.
    groups = np.digitize(risk_scores, quartiles, right=True)

    for group_idx, (label, color) in enumerate(zip(group_labels, colors)):
        mask = groups == group_idx
        if not mask.any():
            # Heavily tied scores can leave a quartile empty; a KM fit on
            # empty data would fail, so skip that group.
            continue
        fitter.fit(time_test[mask], event_test[mask], label=label)
        fitter.plot_survival_function(ax=ax, ci_show=False, linewidth=2.5, color=color)

    ax.set_xlabel('Time (months)', fontsize=11, fontweight='bold')
    ax.set_ylabel('Survival Probability', fontsize=11, fontweight='bold')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.grid(alpha=0.3)
    ax.legend(loc='lower left', fontsize=10)
    ax.set_ylim([0, 1.05])
    return quartiles, groups


# Create figure with multiple subplots
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Survival Analysis: Random Survival Forest vs Transformer Model', fontsize=16, fontweight='bold')

# 1. Risk Score Distribution Comparison
ax = axes[0, 0]
ax.hist(rsf_risk_scores, bins=30, alpha=0.6, label='RSF', color='steelblue', edgecolor='black')
ax.hist(transformer_risk_scores, bins=30, alpha=0.6, label='Transformer', color='coral', edgecolor='black')
ax.set_xlabel('Risk Score', fontsize=11, fontweight='bold')
ax.set_ylabel('Number of Patients', fontsize=11, fontweight='bold')
ax.set_title('Distribution of Risk Scores', fontsize=12, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(alpha=0.3)

# 2. Risk Score Correlation
ax = axes[0, 1]
ax.scatter(rsf_risk_scores, transformer_risk_scores, alpha=0.5, s=30, color='purple')
correlation = np.corrcoef(rsf_risk_scores, transformer_risk_scores)[0, 1]
ax.set_xlabel('RSF Risk Score', fontsize=11, fontweight='bold')
ax.set_ylabel('Transformer Risk Score', fontsize=11, fontweight='bold')
ax.set_title(f'Model Agreement (Correlation: {correlation:.3f})', fontsize=12, fontweight='bold')
ax.grid(alpha=0.3)

# Add diagonal line spanning the combined x/y range
lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),
    np.max([ax.get_xlim(), ax.get_ylim()]),
]
ax.plot(lims, lims, 'k--', alpha=0.5, zorder=0)

# 3. Kaplan-Meier curves for RSF risk quartiles
kmf = KaplanMeierFitter()
rsf_quartiles, rsf_risk_groups = _plot_km_by_risk_quartile(
    axes[1, 0], kmf, rsf_risk_scores, 'Kaplan-Meier: RSF Risk Stratification'
)

# 4. Kaplan-Meier curves for Transformer risk quartiles
ax = axes[1, 1]
kmf_t = KaplanMeierFitter()
transformer_quartiles, transformer_risk_groups = _plot_km_by_risk_quartile(
    ax, kmf_t, transformer_risk_scores, 'Kaplan-Meier: Transformer Risk Stratification'
)

plt.tight_layout()
plt.savefig('survival_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("Survival comparison plots saved as 'survival_comparison.png'")