# Survival model development

Notebook for developing our own survival heads. To do:

- [ ] Negative log partial likelihood (loss function for Cox PH)
- [ ] Deal with ties!
- [ ] Extras:
    - [ ] Baseline hazard and survival curves
    - [ ] ...

In [None]:
import torch
import torch.nn as nn
import math
import ukko 
import importlib
# For preprocessing
print("Loading sklearn")
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper 
import pandas as pd
import numpy as np
print("Libraries loaded")

In [None]:
class SurvivalHead(nn.Module):
    """
    Neural Cox proportional-hazards head.

    Maps a ``d_model``-sized representation to a single positive hazard
    ratio per sample (instead of a classification logit).

    Args:
        d_model: size of the incoming feature vector
        n_features: accepted for signature compatibility; not used by this head
        dropout: dropout probability applied after the hidden activation
    """
    def __init__(self, d_model, n_features, dropout=0.1):
        super().__init__()
        hidden_dim = d_model // 2
        layers = [
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),  # scalar log hazard ratio
        ]
        self.risk_score = nn.Sequential(*layers)

    def forward(self, x):
        # The network produces the log hazard ratio; exponentiate it so the
        # returned value is the (strictly positive) hazard ratio.
        log_hazard_ratio = self.risk_score(x)
        return log_hazard_ratio.exp()

    
def cox_loss(risk_scores, survival_times, events):
    """
    Negative log partial likelihood for the Cox model.

    ``risk_scores`` are hazard ratios, i.e. already exponentiated (this is
    what ``SurvivalHead.forward`` returns), so the log is taken here rather
    than exponentiating a second time. Samples with tied survival times
    share one risk set (Breslow's method).

    Args:
        risk_scores: predicted hazard ratios (> 0), shape [N] or [N, 1]
        survival_times: time to event/censoring, shape [N]
        events: 1 if event occurred, 0 if censored, shape [N]

    Returns:
        Scalar tensor: the negated event log partial likelihoods, averaged
        over all N samples (censored samples contribute 0 to the sum).
    """
    # Work on 1-D tensors; a [N, 1] score column multiplied by [N] events
    # would otherwise silently broadcast to an [N, N] matrix.
    risk_scores = risk_scores.flatten()
    survival_times = survival_times.flatten()
    events = events.flatten()

    # Sort by survival time, longest first, so that a running sum over the
    # sorted samples accumulates everyone still at risk.
    survival_times, indices = torch.sort(survival_times, descending=True)
    log_hazard = torch.log(risk_scores[indices])
    events = events[indices]

    # log(sum of hazard ratios) over the first k sorted samples, computed
    # in a numerically stable way.
    log_cum_risk = torch.logcumsumexp(log_hazard, dim=0)

    # Ties: every sample in a group of equal times uses the cumulative sum
    # taken at the last member of the group, so the whole group is in the
    # risk set regardless of the arbitrary order inside the tie.
    _, group, counts = torch.unique_consecutive(
        survival_times, return_inverse=True, return_counts=True
    )
    group_end = torch.cumsum(counts, dim=0) - 1
    log_risk = log_cum_risk[group_end[group]]

    # Mask for events only
    likelihood = log_hazard - log_risk
    return -torch.mean(likelihood * events)

In [None]:
# Imported in this cell because no earlier cell brings `Dataset` into scope.
from torch.utils.data import Dataset


class SurvivalDataset(Dataset):
    """
    Map-style dataset holding per-sample features, survival times and
    event indicators.

    The three containers are stored as given (no conversion) and are
    expected to be index-aligned and of equal length.
    """
    def __init__(self, features, survival_times, events):
        self.features = features
        self.times = survival_times    # Time to event/censoring
        self.events = events           # Event indicator (1=death, 0=censored)

    def __len__(self):
        # Number of samples; required by DataLoader for map-style datasets.
        return len(self.features)

    def __getitem__(self, idx):
        # One sample as a (features, time, event) triple.
        return self.features[idx], self.times[idx], self.events[idx]

In [None]:
# Load the tidy table and impute every missing value with the sentinel -1
print("Loading tidy data")
df_xy = pd.read_csv("data/df_xy_synth_v1.csv").fillna(-1)

# Define function to convert df into 3-D numpy array
def convert_to_3d_df(df):
    """
    Reshape a wide table with "(feature, timepoint)" columns into 3-D form.

    Column labels are either strings holding a Python tuple literal such as
    "('feature', timepoint)" or the tuples themselves. The input frame is
    left untouched.

    Args:
        df: DataFrame with one column per (feature, timepoint) pair

    Returns:
        df_multiindex: DataFrame of shape [n_rows, n_features * n_timepoints]
            with a (Feature, Timepoint) column MultiIndex over the full
            product of sorted features and sorted timepoints
        data_3d: float array of shape [n_rows, n_features, n_timepoints];
            (feature, timepoint) pairs absent from ``df`` are NaN
    """
    # literal_eval only accepts Python literals, so a crafted column name
    # cannot execute code the way eval() could.
    import ast

    parsed = [
        ast.literal_eval(col) if isinstance(col, str) else col
        for col in df.columns
    ]

    # Sorted unique features/timepoints define the axis order of the array
    features = sorted({feature for feature, _ in parsed})
    timepoints = sorted({timepoint for _, timepoint in parsed})
    feature_indices = {feature: i for i, feature in enumerate(features)}
    timepoint_indices = {timepoint: i for i, timepoint in enumerate(timepoints)}

    n_rows = df.shape[0]
    data_3d = np.full((n_rows, len(features), len(timepoints)), np.nan)

    # Copy each source column into its (feature, timepoint) slot. Columns are
    # read by position so duplicate labels cannot return a whole sub-frame.
    for position, (feature, timepoint) in enumerate(parsed):
        data_3d[:, feature_indices[feature], timepoint_indices[timepoint]] = (
            df.iloc[:, position].to_numpy()
        )

    # Flatten the last two axes into MultiIndex columns for the DataFrame view
    columns = pd.MultiIndex.from_product(
        [features, timepoints], names=["Feature", "Timepoint"]
    )
    df_multiindex = pd.DataFrame(data_3d.reshape(n_rows, -1), columns=columns)

    return df_multiindex, data_3d

# The first three columns are the targets; everything after them is the
# feature block, which is converted to a multiindex df plus a 3-D array.
df_y = df_xy.iloc[:, :3]
feature_block = df_xy.iloc[:, 3:].fillna(-1)
df_x, data_3d = convert_to_3d_df(feature_block)
display(df_x)
display(df_y)


## Custom Cox PH implementation

In [None]:
class LinearCoxPH(nn.Module):
    """
    Classical Cox PH with a linear predictor: h(t|x) = h₀(t)exp(βx).

    The module outputs only the covariate-dependent factor exp(βx).

    Args:
        n_features: number of input covariates
    """
    def __init__(self, n_features):
        super().__init__()
        # β coefficients; no bias term in the linear predictor
        self.beta = nn.Linear(in_features=n_features, out_features=1, bias=False)
        self.modelname = "LinearCoxPH"

    def forward(self, x):
        linear_predictor = self.beta(x)  # βx
        return linear_predictor.exp()    # exp(βx)

In [None]:
class CoxPHModel(nn.Module):
    """
    Cox proportional-hazards model with a small neural predictor.

    A one-hidden-layer network (Linear -> ReLU -> Linear) produces a scalar
    per sample, which is exponentiated to give the hazard ratio.

    Args:
        n_features: number of input covariates
        hidden_size: width of the hidden layer
    """
    def __init__(self, n_features, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.modelname = "CoxPHModel"

        # Network that outputs the log hazard ratio
        layers = (
            nn.Linear(n_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )
        self.hazard_ratio = nn.Sequential(*layers)

    def forward(self, x):
        """
        Compute the hazard ratio of every sample.

        Args:
            x: Input features [batch_size, n_features]

        Returns:
            hazard_ratios: Predicted hazard ratios [batch_size, 1]
        """
        log_hazard_ratio = self.hazard_ratio(x)
        return log_hazard_ratio.exp()

In [None]:
def cox_loss(hazard_ratios, durations, events):
    """
    Negative log partial likelihood for Cox model.

    For every observed event i the contribution is
    ``log(hr_i) - log(sum of hr_j over all j with duration_j >= duration_i)``.
    The inputs are hazard ratios (already exponentiated), so their log is
    taken before the log-sum-exp; samples with tied durations share one
    risk set (Breslow's method).

    Args:
        hazard_ratios: Predicted hazard ratios (> 0) [batch_size, 1]
        durations: Time to event/censoring [batch_size]
        events: Event indicators (1=event, 0=censored) [batch_size]

    Returns:
        loss: Negative log partial likelihood, averaged over the events in
            the batch; a zero scalar (requires_grad=True) if the batch has
            no events.
    """
    hazard_ratios = hazard_ratios.flatten()
    durations = durations.flatten()
    is_event = events.flatten() == 1

    if not bool(is_event.any()):
        # Nothing contributes to the likelihood. Return a zero on the same
        # device/dtype as the predictions so callers can still call
        # .backward() and .item() on it.
        return torch.zeros(
            (),
            dtype=hazard_ratios.dtype,
            device=hazard_ratios.device,
            requires_grad=True,
        )

    # Sort all arrays by duration in descending order, so that a running
    # sum over the sorted samples accumulates everyone still at risk.
    sorted_idx = torch.argsort(durations, descending=True)
    sorted_durations = durations[sorted_idx]
    is_event = is_event[sorted_idx]

    # Back to log scale; the clamp keeps an underflowed ratio of exactly 0
    # from producing -inf.
    tiny = torch.finfo(hazard_ratios.dtype).tiny
    log_hazard = torch.log(hazard_ratios[sorted_idx].clamp_min(tiny))

    # log(sum of hazard ratios) over the first k sorted samples
    log_cum_risk = torch.logcumsumexp(log_hazard, dim=0)

    # Ties: every member of a group of equal durations uses the cumulative
    # sum at the group's last position, so the risk set does not depend on
    # the arbitrary order inside the tie.
    _, tie_group, tie_counts = torch.unique_consecutive(
        sorted_durations, return_inverse=True, return_counts=True
    )
    last_in_group = torch.cumsum(tie_counts, dim=0) - 1
    log_risk = log_cum_risk[last_in_group[tie_group]]

    # Only observed events contribute to the partial likelihood
    partial_likelihood = log_hazard[is_event] - log_risk[is_event]
    return -torch.mean(partial_likelihood)

In [None]:
def train_cox_model(model, train_loader, val_loader, epochs=100, lr=0.001, device='cpu'):
    """
    Fit a Cox proportional hazards model with Adam.

    After each epoch the model is scored on ``val_loader``; whenever the
    mean validation loss improves, the weights are written to
    ``"<model.modelname>.pt"`` in the working directory.

    Args:
        model: module exposing ``modelname`` and returning hazard ratios
        train_loader: DataLoader with (features, durations, events)
        val_loader: DataLoader with (features, durations, events)
        epochs: Number of training epochs
        lr: Learning rate
        device: 'cuda' or 'cpu'

    Returns:
        history: dict with per-epoch mean 'train_loss' and 'val_loss' lists
    """
    def _batch_loss(batch):
        # Cast one (features, durations, events) batch to float, move it to
        # `device`, and score the model's predictions on it.
        features, durations, events = (part.float().to(device) for part in batch)
        return cox_loss(model(features), durations, events)

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')

    print(f"Training {model.modelname}")
    for epoch in range(epochs):
        # --- optimisation pass over the training batches ---
        model.train()
        running_train = 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _batch_loss(batch)
            loss.backward()
            optimizer.step()
            running_train += loss.item()

        # --- evaluation pass, no gradients needed ---
        model.eval()
        running_val = 0
        with torch.no_grad():
            for batch in val_loader:
                running_val += _batch_loss(batch).item()

        # Average over the number of batches and record
        train_loss = running_train / len(train_loader)
        val_loss = running_val / len(val_loader)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        # Checkpoint on improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"{model.modelname}.pt")

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    return history

In [None]:
from torch.utils.data import Dataset

class SurvivalDataset(Dataset):
    """
    Tensor-backed dataset for training Cox PH models.

    Features, durations and event indicators are each converted to a
    ``torch.FloatTensor`` on construction; indexing yields one
    (features, duration, event) triple.
    """
    def __init__(self, features, durations, events):
        self.features, self.durations, self.events = (
            torch.FloatTensor(column) for column in (features, durations, events)
        )

    def __len__(self):
        # One sample per row of the feature tensor
        return len(self.features)

    def __getitem__(self, idx):
        fields = (self.features, self.durations, self.events)
        return tuple(field[idx] for field in fields)


In [None]:
# Test/Example data

# Alternative example dataset from scikit-survival (currently disabled):
# from sksurv.datasets import load_veterans_lung_cancer

# data_x, data_y = load_veterans_lung_cancer()
# data_y

# Use the regression example dataset bundled with lifelines. The cells below
# read its 'T' (duration) and 'E' (event indicator) columns and treat every
# remaining column as a covariate.
from lifelines.datasets import load_regression_dataset
regression_dataset = load_regression_dataset() # a Pandas DataFrame
regression_dataset.head()

In [None]:
# Render the example dataset in the notebook output
display(regression_dataset)

In [None]:
# Example usage
from torch.utils.data import DataLoader

# Split data: random 80% of the rows for training, the rest for validation
df_train = regression_dataset.sample(frac=0.8)
df_val = regression_dataset.drop(df_train.index)

# Every column except the duration ('T') and event ('E') columns is a covariate
feature_cols = [col for col in regression_dataset.columns if col not in ('T', 'E')]

# Create datasets
train_dataset = SurvivalDataset(
    features=df_train[feature_cols].values,
    durations=df_train['T'].values,
    events=df_train['E'].values
)

train_loader = DataLoader(
    train_dataset,
    batch_size=200,
    shuffle=True
)

val_dataset = SurvivalDataset(
    features=df_val[feature_cols].values,
    durations=df_val['T'].values,
    events=df_val['E'].values
)

val_loader = DataLoader(
    val_dataset,
    batch_size=200,
    shuffle=False
)

# Define model
# Derived from the data so the model input size always matches the
# feature matrix built above.
n_features = len(feature_cols)

model = CoxPHModel(n_features=n_features)
# Train model
history = train_cox_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=150
)
print(' ')

model2 = LinearCoxPH(n_features=n_features)
# Train model (note: `history` now holds the curves of this second run)
history = train_cox_model(
    model=model2,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=150
)



In [None]:
def get_risk_groups(model, data_loader, device='cpu'):
    """
    Score every sample in ``data_loader`` and split at the median risk.

    Returns:
        risk_scores: 1-D array of model outputs, in loader order
        times: 1-D array of durations, in loader order
        events: 1-D array of event indicators, in loader order
        high_risk: boolean array, True where the score is >= the median
    """
    model.eval()
    collected = {'risk': [], 'time': [], 'event': []}

    with torch.no_grad():
        for x, durations, events in data_loader:
            predicted = model(x.float().to(device))
            # Flatten each batch so the pieces concatenate into 1-D arrays
            collected['risk'].append(predicted.cpu().numpy().flatten())
            collected['time'].append(durations.numpy().flatten())
            collected['event'].append(events.numpy().flatten())

    # Stitch the per-batch pieces together
    risk_scores = np.concatenate(collected['risk'])
    times = np.concatenate(collected['time'])
    events = np.concatenate(collected['event'])

    # Samples at or above the median score form the high-risk group
    high_risk = risk_scores >= np.median(risk_scores)

    return risk_scores, times, events, high_risk

from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt

def plot_risk_stratification(times, events, high_risk, title="Risk Stratification"):
    """
    Plot Kaplan-Meier curves for high and low risk groups.

    Draws both curves (with censoring marks) on one new figure, adds the
    at-risk table and annotates the log-rank p-value of the two groups.

    Args:
        times: durations, one per sample (array-like)
        events: event indicators, one per sample (array-like)
        high_risk: boolean-like flags, True for the high-risk group
        title: figure title

    Returns:
        The current matplotlib figure.
    """
    # Imported locally so the function does not depend on a module-level
    # `import lifelines` having been executed before it is called.
    from lifelines.plotting import add_at_risk_counts
    from lifelines.statistics import logrank_test

    # Normalise to flat numpy arrays once, so the KM fits and the log-rank
    # test all use the same positional boolean mask, whatever container
    # type (ndarray, Series, ...) was passed in.
    times = np.asarray(times).ravel()
    events = np.asarray(events).ravel()
    is_high = np.asarray(high_risk, dtype=bool).ravel()
    is_low = ~is_high

    # Create figure
    plt.figure(figsize=(10, 6))

    # Plot high risk group
    kmf_high = KaplanMeierFitter()
    kmf_high.fit(times[is_high], events[is_high], label='High Risk')
    kmf_high.plot(show_censors=True, censor_styles={'marker': 'x', 'ms': 15})

    # Plot low risk group
    kmf_low = KaplanMeierFitter()
    kmf_low.fit(times[is_low], events[is_low], label='Low Risk')
    kmf_low.plot(show_censors=True, censor_styles={'marker': 'x', 'ms': 15})

    # Add at-risk counts
    add_at_risk_counts(kmf_high, kmf_low)

    # Customize plot
    plt.title(title)
    plt.xlabel('Time (days)')
    plt.ylabel('Survival Probability')
    plt.grid(True)

    # Add log-rank test comparing the two groups
    log_rank = logrank_test(times[is_high], times[is_low],
                            events[is_high], events[is_low])
    plt.text(0.05, 0.05, f'Log-rank p-value: {log_rank.p_value:.3e}',
             transform=plt.gca().transAxes)

    return plt.gcf()

In [None]:
import lifelines
from lifelines.utils import concordance_index
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild both architectures and restore the best checkpoint of each; the
# checkpoint file is named after the trained model's `modelname`.
best_model = CoxPHModel(n_features=n_features)
best_model2 = LinearCoxPH(n_features=n_features)
for restored, trained in ((best_model, model), (best_model2, model2)):
    restored.load_state_dict(torch.load(f'{trained.modelname}.pt'))
    restored.to(device)  # nn.Module.to moves the module in place

# Per model: risk groups on the training loader, Kaplan-Meier plot, c-index.
# The risk score is negated for the c-index because a higher hazard ratio
# is meant to predict a shorter survival time.
evaluations = (
    ("CoxPHModel", best_model, model, " (1=perfect model)"),
    ("LinearCoxPH", best_model2, model2, ""),
)
for heading, restored, trained, c_index_note in evaluations:
    print(heading)
    # Get risk groups
    risk_scores, times, events, high_risk = get_risk_groups(restored, train_loader, device)
    # Plot Kaplan-Meier curves
    fig = plot_risk_stratification(times, events, high_risk,
                                   title=trained.modelname)
    # Add additional metrics
    c_index = concordance_index(times, -risk_scores.flatten(), events)
    print(f"Concordance Index: {c_index:.3f}{c_index_note}")

## Comparison: Cox PH from lifelines

In [None]:
# Reference fit: lifelines' Cox PH on the same training frame. 'T' holds the
# durations and 'E' the event indicators; all remaining columns of df_train
# enter the model as covariates.
from lifelines import CoxPHFitter
cph = CoxPHFitter()
cph.fit(df_train, duration_col='T', event_col='E')

# Print the fitted model summary for comparison with the torch models
cph.print_summary()

In [None]:
# Partial hazards from the lifelines fit, split at their median so the
# grouping mirrors the one used for the torch models.
print(df_train.shape)
risk = cph.predict_partial_hazard(df_train)
median_risk = np.median(risk)
print (risk)
print (median_risk)
high_risk = risk >= median_risk

In [None]:
# Kaplan-Meier curves of the lifelines-based high/low risk groups, drawn
# with the same helper used for the torch models.
fig = plot_risk_stratification(df_train['T'], df_train['E'], high_risk, 
                             title="CPH from lifelines")

In [None]:
# Initialize list to store results
results = []

# Repeat training from fresh random initialisations to see how much the
# fitted betas vary from run to run.
n_runs = 15
# One column per coefficient: beta1 ... beta<n_features>
beta_cols = [f'beta{j + 1}' for j in range(n_features)]

for i in range(n_runs):
    # Initialize model
    model = LinearCoxPH(n_features=n_features)

    # Train model with existing loaders
    history = train_cox_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=100
    )

    # Load best model (checkpoint written during training) and get beta coefficients
    model.load_state_dict(torch.load(f'{model.modelname}.pt'))
    beta = model.beta.weight.detach().cpu().numpy().flatten()

    # Store results
    results.append({'run': i + 1, **dict(zip(beta_cols, beta))})


# Create results dataframe
results_df = pd.DataFrame(results)
display(results_df)
print("\nMean betas:")
print(results_df[beta_cols].mean())
print("\nStd betas:")
print(results_df[beta_cols].std())


In [None]:
# Compare the β coefficients of the lifelines Cox regression (with
# confidence intervals) against those of the torch linear model.
beta_cph = cph.params_.values   # β coefficients from Cox regression
ci_cph = cph.confidence_intervals_.values  # Confidence intervals for β coefficients
beta_linear = best_model2.beta.weight.detach().cpu().numpy().flatten()  # β coefficients from linear model
print("Cox PH β:", beta_cph)
print("Linear Model β:", beta_linear)

# Plot β coefficients; the lifelines error bars are half the CI width
plt.figure(figsize=(10, 3))
ci_half_width = (ci_cph[:, 1] - ci_cph[:, 0]) / 2
plot1 = plt.errorbar(beta_cph, range(len(beta_cph)), xerr=ci_half_width,
                     fmt='o', label='Cox PH', capsize=5)
plot2 = plt.errorbar(beta_linear, range(len(beta_linear)),
                     fmt='o', label='Linear Model', capsize=5)
# Overlay the betas of every repeated training run in orange
run_beta_cols = ['beta1', 'beta2', 'beta3']
for row in results_df.index:
    plt.errorbar(results_df.loc[row, run_beta_cols], range(0, 3),
                 fmt='o', label='Linear Model', capsize=5, c='orange')
plt.ylabel('Feature Index')
plt.xlabel('β Coefficient')
plt.legend([plot1, plot2], ['Cox PH', best_model2.modelname])
plt.grid(True)
plt.show()


