# Survival Analysis with scikit-survival

This notebook implements survival analysis using the **scikit-survival** library exclusively.

scikit-survival is a Python library for survival analysis built on top of scikit-learn. It provides:
- Non-parametric estimators (Kaplan-Meier, Nelson-Aalen)
- Cox Proportional Hazards models
- Machine learning survival models (Random Survival Forests, Gradient Boosting)
- Model evaluation metrics (concordance index, Brier score)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# scikit-survival imports
from sksurv.nonparametric import kaplan_meier_estimator, nelson_aalen_estimator
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, brier_score, integrated_brier_score
from sksurv.functions import StepFunction
from sksurv.preprocessing import OneHotEncoder

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

sns.set_style('whitegrid')
%matplotlib inline

## Load and Prepare Data

scikit-survival requires the target variable as a **structured numpy array** with dtype `[("event", bool), ("time", float)]`.

In [None]:
# Load the preprocessed survival data.
# The parquet file is tried first; the CSV copy is read only when
# pd.read_parquet raises ImportError (presumably: no parquet engine installed).
parquet_path = '../data/processed/survival_data.parquet'
csv_path = '../data/processed/survival_data.csv'

try:
    df = pd.read_parquet(parquet_path)
except ImportError:
    df = pd.read_csv(csv_path)

column_names = df.columns.tolist()
print(f"Dataset shape: {df.shape}")
print(f"\nColumns: {column_names}")

# Last expression of the cell: rendered as the cell output
df.head()

In [None]:
def make_survival_array(event: np.ndarray, time: np.ndarray) -> np.ndarray:
    """Create structured array for scikit-survival.

    Parameters
    ----------
    event : array-like
        One-dimensional sequence indicating whether the event occurred
        (truthy) or the observation was censored (falsy). Any array-like
        (list, ndarray, pandas Series) is accepted.
    time : array-like
        One-dimensional sequence of survival/censoring times, same length
        as ``event``.

    Returns
    -------
    y : structured array
        Structured array with dtype [('event', bool), ('time', float)]

    Raises
    ------
    ValueError
        If the inputs are not one-dimensional or differ in length.
    """
    # np.asarray makes plain lists / Series work, matching the documented
    # "array-like" contract (calling .astype directly fails on a list).
    event_arr = np.asarray(event).astype(bool)
    time_arr = np.asarray(time).astype(float)

    # Pairing element by element would silently truncate to the shorter
    # input; a length mismatch is reported instead.
    if event_arr.ndim != 1 or event_arr.shape != time_arr.shape:
        raise ValueError(
            "event and time must be one-dimensional and of equal length, "
            f"got shapes {event_arr.shape} and {time_arr.shape}"
        )

    # Fill the fields directly instead of building a Python list of tuples.
    y = np.empty(event_arr.shape[0], dtype=[('event', bool), ('time', float)])
    y['event'] = event_arr
    y['time'] = time_arr
    return y

# Build the survival target for the whole dataset from the
# 'event' and 'duration' columns of df.
event_flags = df['event'].values
durations = df['duration'].values
y = make_survival_array(event_flags, durations)

# Quick sanity check of the structured array
print(f"Survival array dtype: {y.dtype}")
print(f"Survival array shape: {y.shape}")
print(f"\nFirst 5 records:")
first_records = y[:5]
print(first_records)

In [None]:
# Event distribution: counts per event type, then the share of
# observed events versus censored records in y.
event_type_counts = df['event_type'].value_counts()
event_rate = y['event'].mean()
censoring_rate = (~y['event']).mean()

print("=== Event Distribution ===")
print(event_type_counts)
print(f"\nEvent rate: {event_rate:.2%}")
print(f"Censoring rate: {censoring_rate:.2%}")

## Kaplan-Meier Estimator

The Kaplan-Meier estimator is a non-parametric statistic used to estimate the survival function from lifetime data.

In [None]:
# Overall Kaplan-Meier survival curve
# time_km holds the unique observed times in ascending order and surv_prob
# the estimated survival probability at each of them.
time_km, surv_prob = kaplan_meier_estimator(y['event'], y['time'])

# Median survival time: the EARLIEST time at which S(t) has dropped to 0.5
# or below. If the curve never gets that low, the median is not reached.
at_or_below_half = np.flatnonzero(surv_prob <= 0.5)
if at_or_below_half.size > 0:
    median_survival = time_km[at_or_below_half[0]]
    print(f"Median survival time: {median_survival:.1f} months")
else:
    print("Median survival time: Not reached")

# Print survival probabilities at key time points
print(f"\nSurvival probabilities:")
for t in [12, 24, 36, 60, 120]:
    # Beyond the last observed time the estimate is undefined: skip.
    if t > time_km[-1]:
        continue
    # The KM curve is a right-continuous step function, so S(t) is the value
    # at the last observed time <= t (not at the next time after t).
    idx = np.searchsorted(time_km, t, side='right') - 1
    # Before the first observed time nothing has happened yet: S(t) = 1.
    surv_at_t = surv_prob[idx] if idx >= 0 else 1.0
    print(f"  At {t:3d} months: {surv_at_t:.1%}")

In [None]:
# Overall Kaplan-Meier curve as a step plot, saved to the reports folder
fig, ax = plt.subplots(figsize=(10, 6))
ax.step(time_km, surv_prob, where='post', linewidth=2, label='All Loans')

# Labels, title and y-range in a single call
ax.set(
    xlabel='Time (months)',
    ylabel='Survival Probability',
    title='Kaplan-Meier Survival Curve - All Loans (scikit-survival)',
    ylim=(0, 1),
)
ax.grid(True, alpha=0.3)
ax.legend()

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_km_overall.png', dpi=150)
plt.show()

In [None]:
# Kaplan-Meier curves per vintage-year bucket (inclusive year ranges)
vintage_groups = [(1999, 2005), (2006, 2008), (2009, 2015), (2016, 2020), (2021, 2025)]
colors = plt.cm.viridis(np.linspace(0, 1, len(vintage_groups)))
fig, ax = plt.subplots(figsize=(12, 7))

for (start, end), color in zip(vintage_groups, colors):
    vintage_year = df['vintage_year']
    in_bucket = (vintage_year >= start) & (vintage_year <= end)
    # Buckets without any loans are skipped
    if not in_bucket.any():
        continue

    bucket = df.loc[in_bucket]
    y_bucket = make_survival_array(bucket['event'].values, bucket['duration'].values)
    time_grp, surv_grp = kaplan_meier_estimator(y_bucket['event'], y_bucket['time'])
    ax.step(time_grp, surv_grp, where='post', color=color, linewidth=1.5,
            label=f'{start}-{end}')

ax.set(
    xlabel='Time (months)',
    ylabel='Survival Probability',
    title='Kaplan-Meier Survival Curves by Vintage Year (scikit-survival)',
    ylim=(0, 1),
)
ax.legend(title='Vintage')
ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_km_by_vintage.png', dpi=150)
plt.show()

In [None]:
# Kaplan-Meier curves per FICO band; bands with 100 loans or fewer are skipped
fico_bands = ['<620', '620-679', '680-739', '740-779', '780+']
colors = plt.cm.RdYlGn(np.linspace(0.1, 0.9, len(fico_bands)))
fig, ax = plt.subplots(figsize=(12, 7))

for band, color in zip(fico_bands, colors):
    in_band = df['fico_band'] == band
    if in_band.sum() <= 100:
        continue

    band_loans = df.loc[in_band]
    y_band = make_survival_array(band_loans['event'].values, band_loans['duration'].values)
    time_grp, surv_grp = kaplan_meier_estimator(y_band['event'], y_band['time'])
    ax.step(time_grp, surv_grp, where='post', color=color, linewidth=1.5,
            label=f'FICO {band}')

ax.set(
    xlabel='Time (months)',
    ylabel='Survival Probability',
    title='Kaplan-Meier Survival Curves by FICO Score (scikit-survival)',
    ylim=(0, 1),
)
ax.legend(title='FICO Band')
ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_km_by_fico.png', dpi=150)
plt.show()

In [None]:
# Kaplan-Meier curves per LTV band; bands with 100 loans or fewer are skipped
ltv_bands = ['<=60', '61-70', '71-80', '81-90', '91-95', '>95']
colors = plt.cm.coolwarm(np.linspace(0, 1, len(ltv_bands)))
fig, ax = plt.subplots(figsize=(12, 7))

for band, color in zip(ltv_bands, colors):
    in_band = df['ltv_band'] == band
    if in_band.sum() <= 100:
        continue

    band_loans = df.loc[in_band]
    y_band = make_survival_array(band_loans['event'].values, band_loans['duration'].values)
    time_grp, surv_grp = kaplan_meier_estimator(y_band['event'], y_band['time'])
    ax.step(time_grp, surv_grp, where='post', color=color, linewidth=1.5,
            label=f'LTV {band}')

ax.set(
    xlabel='Time (months)',
    ylabel='Survival Probability',
    title='Kaplan-Meier Survival Curves by LTV (scikit-survival)',
    ylim=(0, 1),
)
ax.legend(title='LTV Band')
ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_km_by_ltv.png', dpi=150)
plt.show()

## Nelson-Aalen Cumulative Hazard Estimator

The Nelson-Aalen estimator estimates the cumulative hazard function H(t).

- Higher cumulative hazard = higher accumulated risk over time
- The slope of H(t) represents the instantaneous hazard rate

In [None]:
# Nelson-Aalen cumulative hazard: all loans (left panel) and per
# vintage bucket (right panel). Reuses vintage_groups from the KM section.
time_na, cumhaz = nelson_aalen_estimator(y['event'], y['time'])

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_overall, ax_vintage = axes

# Left panel: overall cumulative hazard
ax_overall.step(time_na, cumhaz, where='post', color='steelblue', linewidth=2)
ax_overall.set(
    xlabel='Time (months)',
    ylabel='Cumulative Hazard H(t)',
    title='Nelson-Aalen Cumulative Hazard - All Loans',
)
ax_overall.grid(True, alpha=0.3)

# Right panel: one curve per vintage bucket
colors = plt.cm.viridis(np.linspace(0, 1, len(vintage_groups)))
for (start, end), color in zip(vintage_groups, colors):
    vintage_year = df['vintage_year']
    in_bucket = (vintage_year >= start) & (vintage_year <= end)
    # Buckets without any loans are skipped
    if not in_bucket.any():
        continue

    bucket = df.loc[in_bucket]
    y_bucket = make_survival_array(bucket['event'].values, bucket['duration'].values)
    time_grp, cumhaz_grp = nelson_aalen_estimator(y_bucket['event'], y_bucket['time'])
    ax_vintage.step(time_grp, cumhaz_grp, where='post',
                    color=color, linewidth=1.5, label=f'{start}-{end}')

ax_vintage.set(
    xlabel='Time (months)',
    ylabel='Cumulative Hazard H(t)',
    title='Nelson-Aalen Cumulative Hazard by Vintage',
)
ax_vintage.legend(title='Vintage')
ax_vintage.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_nelson_aalen.png', dpi=150)
plt.show()

print("Interpretation: Steeper slopes indicate periods of higher hazard (risk of event).")

## Cox Proportional Hazards Model

scikit-survival provides `CoxPHSurvivalAnalysis`, which follows the scikit-learn estimator API (`fit` / `predict`).

In [None]:
# Feature matrix and target for the Cox model
feature_cols = ['credit_score', 'orig_ltv', 'orig_dti', 'orig_interest_rate']

# Keep only rows where none of the features is missing
has_missing = df[feature_cols].isna().any(axis=1)
mask_valid = ~has_missing
df_cox = df.loc[mask_valid].copy()

X = df_cox.loc[:, feature_cols].values
y_cox = make_survival_array(
    df_cox['event'].values,
    df_cox['duration'].values,
)

n_loans = len(df_cox)
print(f"Data for Cox model: {n_loans:,} loans")
print(f"Features: {feature_cols}")

In [None]:
# Hold out 20% of the loans for testing. The split happens BEFORE scaling
# so that test data never influences the scaler (no data leakage).
split_parts = train_test_split(X, y_cox, test_size=0.2, random_state=42)
X_train_raw, X_test_raw, y_train, y_test = split_parts

n_train = len(X_train_raw)
n_test = len(X_test_raw)
print(f"Training set: {n_train:,} loans")
print(f"Test set: {n_test:,} loans")

In [None]:
# Standardize features. The scaler learns mean/std from the TRAINING split
# only; the same statistics are then applied to both splits.
scaler = StandardScaler()
scaler.fit(X_train_raw)
X_train = scaler.transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

print("Features standardized (mean=0, std=1)")
print("Scaler fit on TRAINING data only to prevent data leakage")
print("Coefficients will represent effect of 1 SD change")

# Index order follows feature_cols: credit score, LTV, DTI, interest rate
train_mean = scaler.mean_
train_std = scaler.scale_
print(f"\nFeature statistics (from training data):")
print(f"  FICO: mean={train_mean[0]:.1f}, std={train_std[0]:.1f}")
print(f"  LTV:  mean={train_mean[1]:.1f}, std={train_std[1]:.1f}")
print(f"  DTI:  mean={train_mean[2]:.1f}, std={train_std[2]:.1f}")
print(f"  Rate: mean={train_mean[3]:.2f}, std={train_std[3]:.2f}")

In [None]:
# Fit the Cox proportional hazards model on the standardized training data
cox_model = CoxPHSurvivalAnalysis(alpha=0.01)  # alpha: L2 regularization
cox_model.fit(X_train, y_train)

# Coefficient table: one row per (standardized) feature
coefficients = cox_model.coef_
hazard_ratios = np.exp(coefficients)
feature_labels = [f'{col}_std' for col in feature_cols]

coef_df = pd.DataFrame({
    'feature': feature_labels,
    'coefficient': coefficients,
    'hazard_ratio': hazard_ratios,
})
# Percentage change in hazard for a one-standard-deviation increase
coef_df['interpretation'] = [
    f"{(hr-1)*100:+.1f}% hazard per 1 SD increase"
    for hr in coef_df['hazard_ratio']
]

print("=== Cox PH Model Coefficients ===")
print(coef_df.to_string(index=False))

In [None]:
# Side-by-side bar charts: raw coefficients (left) and hazard ratios (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_coef, ax_hr = axes

# Green for negative coefficients, red otherwise; same colors in both panels
colors = []
for c in cox_model.coef_:
    colors.append('green' if c < 0 else 'red')

# Coefficients, with a reference line at 0 (no effect)
ax_coef.barh(coef_df['feature'], coef_df['coefficient'], color=colors, alpha=0.7)
ax_coef.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
ax_coef.set(xlabel='Coefficient', title='Cox PH Model Coefficients (Standardized)')
ax_coef.grid(True, alpha=0.3)

# Hazard ratios, with a reference line at 1 (no effect)
ax_hr.barh(coef_df['feature'], coef_df['hazard_ratio'], color=colors, alpha=0.7)
ax_hr.axvline(x=1, color='black', linestyle='--', linewidth=1)
ax_hr.set(xlabel='Hazard Ratio', title='Cox PH Hazard Ratios (per 1 SD change)')
ax_hr.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_cox_coefficients.png', dpi=150)
plt.show()

print("\nInterpretation:")
print("- Green bars: Higher values associated with lower hazard (protective)")
print("- Red bars: Higher values associated with higher hazard (risk factor)")

## Model Evaluation

scikit-survival provides several evaluation metrics:
- **Concordance Index (C-index)**: Measures the model's ability to correctly rank survival times
- **Brier Score**: Measures calibration (prediction accuracy at specific time points)
- **Integrated Brier Score**: Overall calibration across all time points

In [None]:
# Concordance index on both splits.
# cox_model.predict returns risk scores (higher = more risk).
risk_scores_train = cox_model.predict(X_train)
risk_scores_test = cox_model.predict(X_test)

c_train = concordance_index_censored(y_train['event'], y_train['time'], risk_scores_train)
c_test = concordance_index_censored(y_test['event'], y_test['time'], risk_scores_test)

# Result tuple: (c-index, concordant, discordant, tied risk, tied time)
cindex_train = c_train[0]
cindex_test, n_concordant, n_discordant, n_tied_risk, n_tied_time = c_test

print("=== Concordance Index ===")
print(f"Training C-index: {cindex_train:.4f}")
print(f"Test C-index:     {cindex_test:.4f}")
print(f"\nConcordant pairs (test): {n_concordant:,}")
print(f"Discordant pairs (test): {n_discordant:,}")
print(f"Tied pairs (test): {n_tied_risk:,} risk, {n_tied_time:,} time")
print("\nInterpretation: C-index of 0.5 = random, 1.0 = perfect discrimination")

In [None]:
# Predicted survival functions for the test set (reused for the Brier scores)
surv_funcs = cox_model.predict_survival_function(X_test)

# Draw the curves of 10 test loans picked at random with a fixed seed
np.random.seed(42)
sample_idx = np.random.choice(len(surv_funcs), size=10, replace=False)

fig, ax = plt.subplots(figsize=(10, 6))
for idx in sample_idx:
    fn = surv_funcs[idx]
    # Evaluate the step function at its own breakpoints fn.x
    ax.step(fn.x, fn(fn.x), where='post', alpha=0.7, label=f'Individual {idx}')

ax.set(
    xlabel='Time (months)',
    ylabel='Survival Probability',
    title='Predicted Survival Functions (10 Random Individuals)',
    ylim=(0, 1),
)
ax.grid(True, alpha=0.3)
ax.legend(loc='lower left', fontsize=8)

fig.tight_layout()
fig.savefig('../reports/figures/sksurv_predicted_survival.png', dpi=150)
plt.show()

In [None]:
# Brier Score at specific time points
# brier_score only accepts times inside the follow-up range of the test data,
# and that range excludes its upper bound: [min_time, max_time). A single
# time point equal to max_time would make the whole computation fail, so the
# upper comparison is strict.
min_time = y_test['time'].min()
max_time = y_test['time'].max()
time_points = np.array([12, 24, 36, 60, 120])
time_points = time_points[(time_points >= min_time) & (time_points < max_time)]

if time_points.size == 0:
    # np.column_stack([]) would raise on an empty list; report and skip.
    print("Brier score skipped: no evaluation time lies within the test follow-up range.")
else:
    try:
        # Survival probability matrix for the test set at the evaluation
        # times (rows: loans, columns: time points). Built inside the try
        # because evaluating a step function can itself raise — presumably
        # when a time lies outside the function's domain.
        surv_probs = np.column_stack([
            [fn(t) for fn in surv_funcs] for t in time_points
        ])

        times_bs, brier_scores = brier_score(y_train, y_test, surv_probs, time_points)

        print("=== Brier Score at Time Points ===")
        for t, bs in zip(times_bs, brier_scores):
            print(f"  At {t:3.0f} months: {bs:.4f}")
        print("\nInterpretation: Lower Brier score = better calibration (0 = perfect)")
    except Exception as e:
        # Best effort: report the failure and let the notebook continue.
        print(f"Brier score calculation failed: {e}")
        print("This can happen if there are issues with the time range or censoring.")

In [None]:
# Integrated Brier Score (overall calibration)
try:
    # Evaluation grid: 100 evenly spaced times between the 10th and 90th
    # percentile of the observed event times in the test set
    test_event_times = y_test['time'][y_test['event']]
    grid_start, grid_end = np.percentile(test_event_times, [10, 90])
    time_grid = np.linspace(grid_start, grid_end, 100)

    # One column per grid time, one row per test loan
    per_time_columns = []
    for t in time_grid:
        per_time_columns.append([fn(t) for fn in surv_funcs])
    surv_probs_grid = np.column_stack(per_time_columns)

    ibs = integrated_brier_score(y_train, y_test, surv_probs_grid, time_grid)
    print(f"\n=== Integrated Brier Score ===")
    print(f"IBS: {ibs:.4f}")
    print(f"\nInterpretation: IBS averages Brier score across time points.")
    print(f"Typical range: 0.1-0.3 for reasonable models")
except Exception as e:
    # Best effort: report the failure and let the notebook continue
    print(f"Integrated Brier score calculation failed: {e}")

## Competing Risks Analysis

Separate Kaplan-Meier curves for default vs prepayment, treating the other event as censored.

In [None]:
# Cause-specific Kaplan-Meier curves for default and prepayment.
# For each cause, ALL loans stay in the risk set: a loan that ends through
# any other outcome is censored at its observed duration. Dropping those
# loans instead would remove their at-risk time and bias the curve.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

event_type = df['event_type']
all_durations = df['duration'].values

# Default-specific survival (treating prepay/matured/other as censored)
y_default = make_survival_array(
    (event_type == 'default').values,
    all_durations
)
time_def, surv_def = kaplan_meier_estimator(y_default['event'], y_default['time'])

axes[0].step(time_def, surv_def, where='post', color='indianred', linewidth=2)
axes[0].set_xlabel('Time (months)')
axes[0].set_ylabel('Survival Probability (no default)')
axes[0].set_title('Cause-Specific Survival: Default')
axes[0].set_ylim(0, 1)
axes[0].grid(True, alpha=0.3)

# Prepayment-specific survival (treating default/matured/other as censored)
y_prepay = make_survival_array(
    (event_type == 'prepay').values,
    all_durations
)
time_prep, surv_prep = kaplan_meier_estimator(y_prepay['event'], y_prepay['time'])

axes[1].step(time_prep, surv_prep, where='post', color='steelblue', linewidth=2)
axes[1].set_xlabel('Time (months)')
axes[1].set_ylabel('Survival Probability (no prepay)')
axes[1].set_title('Cause-Specific Survival: Prepayment')
axes[1].set_ylim(0, 1)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../reports/figures/sksurv_competing_risks.png', dpi=150)
plt.show()

# Event counts and their share of all loans
print(f"Default events: {y_default['event'].sum():,} ({y_default['event'].mean():.2%})")
print(f"Prepayment events: {y_prepay['event'].sum():,} ({y_prepay['event'].mean():.2%})")

## Summary

This notebook demonstrated survival analysis using **scikit-survival**:

1. **Data Preparation**: Convert pandas data to scikit-survival's structured array format
2. **Kaplan-Meier Estimator**: Non-parametric survival curve estimation
3. **Nelson-Aalen Estimator**: Cumulative hazard function estimation
4. **Cox PH Model**: Semi-parametric regression with standardized covariates
5. **Model Evaluation**: Concordance index and Brier score
6. **Competing Risks**: Cause-specific survival analysis

### Next Steps

scikit-survival also provides:
- **Random Survival Forests**: `sksurv.ensemble.RandomSurvivalForest`
- **Gradient Boosted Survival**: `sksurv.ensemble.GradientBoostingSurvivalAnalysis`
- **Survival SVM**: `sksurv.svm.FastSurvivalSVM`

In [None]:
# Final recap: test C-index plus one line per standardized coefficient
print("=== Model Performance Summary ===")
print(f"Cox PH Test C-index: {c_test[0]:.4f}")
print(f"\nModel coefficients (standardized):")

summary_rows = zip(coef_df['feature'], coef_df['coefficient'], coef_df['hazard_ratio'])
for name, beta, ratio in summary_rows:
    # Arrow shows the sign of the coefficient: up only when strictly positive
    if beta > 0:
        arrow = '↑'
    else:
        arrow = '↓'
    print(f"  {name}: coef={beta:+.4f}, HR={ratio:.4f} {arrow}")