# DeepHit for Competing Risks

This notebook implements **DeepHit** (Lee et al., 2018) for competing risks analysis using **pure PyTorch**. DeepHit is a deep learning approach that directly models the probability mass function of survival times without assuming any particular form for the underlying stochastic process.

## Methodology

| Aspect | Description |
|--------|-------------|
| **Model** | DeepHit with competing risks |
| **Implementation** | Pure PyTorch (custom network and loss) |
| **Features** | Blumenstock et al. (2022) - 21 variables |
| **Evaluation** | Time-dependent C-index at 24, 48, 72 months |

## Key Advantages of DeepHit

| Feature | Description |
|---------|-------------|
| **No PH assumption** | Does not assume proportional hazards |
| **Competing risks** | Native support for multiple event types |
| **Flexible** | Can capture complex non-linear relationships |
| **Direct modeling** | Models PMF directly, not hazard function |
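
To make "models the PMF directly" concrete, the sketch below starts from a hand-made joint PMF for a single loan over two causes and four time bins and derives the cumulative incidence functions and overall survival from it. The numbers are illustrative assumptions, not model output.

```python
import numpy as np

# Illustrative joint PMF P(T=t, K=k) for one loan: rows = causes, cols = time bins.
# The values are made up for this sketch; all entries together sum to 1.
pmf = np.array([
    [0.10, 0.20, 0.15, 0.05],   # cause 0 (prepay)
    [0.05, 0.05, 0.10, 0.30],   # cause 1 (default)
])

# Cumulative incidence: CIF_k(t) = sum_{s <= t} P(T=s, K=k)
cif = np.cumsum(pmf, axis=1)

# Overall survival: S(t) = 1 - sum_k CIF_k(t)
survival = 1.0 - cif.sum(axis=0)

print("CIF prepay :", cif[0])
print("CIF default:", cif[1])
print("Survival   :", survival)
```

Survival reaches zero in the last bin here because a softmax over all (cause, time) cells places the entire probability mass inside the modeled horizon.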

## DeepHit Loss Function

The DeepHit loss combines two components:
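1. **Negative log-likelihood ($\mathcal{L}_1$)**: maximizes the probability of the observed event at the observed time; for censored loans, it maximizes the probability of surviving past the censoring time.
2. **Ranking loss ($\mathcal{L}_2$)**: penalizes pairs in which a loan with an earlier observed event is assigned a lower cumulative incidence at that time than a loan that survived longer. Its weight is $\alpha$.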

$$\mathcal{L} = \mathcal{L}_1 + \alpha \cdot \mathcal{L}_2$$
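
The two terms can be written out as follows. This is a sketch of the variant coded in this notebook's `DeepHitLoss` class, not a restatement of the paper: Lee et al. (2018) additionally allow cause-specific weights in the ranking term, whereas the notebook averages over sampled pairs. Here $k_i = 0$ denotes censoring, $F_k(t \mid x) = \sum_{s \le t} P(T = s, K = k \mid x)$ is the cumulative incidence function, and $|\mathcal{A}| = \sum_k |\mathcal{A}_k|$ is the number of pairs compared.

$$\mathcal{L}_1 = -\frac{1}{N}\sum_{i=1}^{N}\Big[\mathbb{1}(k_i \neq 0)\,\log P(T = t_i, K = k_i \mid x_i) + \mathbb{1}(k_i = 0)\,\log S(t_i \mid x_i)\Big], \qquad S(t \mid x) = 1 - \sum_{k} F_k(t \mid x)$$

$$\mathcal{L}_2 = \frac{1}{|\mathcal{A}|}\sum_{k}\sum_{(i,j)\in\mathcal{A}_k} \exp\!\left(-\frac{F_k(t_i \mid x_i) - F_k(t_i \mid x_j)}{\sigma}\right), \qquad \mathcal{A}_k = \{(i,j) : k_i = k,\ t_i < t_j\}$$

A pair contributes little to $\mathcal{L}_2$ when loan $i$, which experienced cause $k$ first, has the larger cumulative incidence at $t_i$; smaller values of $\sigma$ sharpen the penalty for mis-ordered pairs.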

## References

- Lee, C., Zame, W., Yoon, J., & van der Schaar, M. (2018). DeepHit: A Deep Learning Approach to Survival Analysis with Competing Risks. AAAI.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# pycox for evaluation metrics only
from pycox.evaluation import EvalSurv

# Survival analysis (for comparison metrics)
from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw

# Preprocessing
from sklearn.preprocessing import StandardScaler
import pickle

sns.set_style('whitegrid')
%matplotlib inline

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Device setup (MPS for macOS, CUDA for Linux/Windows, CPU fallback)
if torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
    print("Using MPS (Apple Silicon GPU)")
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    print("Using CUDA GPU")
else:
    DEVICE = torch.device('cpu')
    print("Using CPU")

# Time horizons for evaluation (matching previous notebooks)
TIME_HORIZONS = [24, 48, 72]

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {DEVICE}")
print(f"Time horizons for C-index evaluation: {TIME_HORIZONS} months")

In [None]:
# === CONFIGURATION ===
DATA_DIR = Path('../data/processed')
FIGURES_DIR = Path('../reports/figures')
MODELS_DIR = Path('../models')

FIGURES_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Cross-validation folds (Blumenstock methodology)
TRAIN_FOLDS = list(range(10))  # Folds 0-9; validation folds are excluded from training below
VAL_FOLDS = [9]                 # Fold 9 is held out for validation (early stopping)
TEST_FOLD = 10                  # Fold 10 for testing

# Number of discrete time points
NUM_DURATIONS = 200

# DeepHit hyperparameters (Blumenstock et al. 2022)
DEEPHIT_PARAMS = {
    # Training
    'batch_size': 256,
    'epochs': 100,
    'learning_rate': 0.01,
    # Loss function
    'alpha': 0.2,               # Weight for ranking loss
    'sigma': 0.1,               # Smoothing parameter for ranking loss
    # Regularization
    'dropout': 0.2,
    'batch_norm': True,
}

# Architecture is defined in DeepHitNetwork class:
# - Shared: 3 layers x 300 nodes
# - Heads: 5 layers x 200 nodes each

print(f"Training folds: {[f for f in TRAIN_FOLDS if f not in VAL_FOLDS]}")
print(f"Validation fold: {VAL_FOLDS}")
print(f"Test fold: {TEST_FOLD}")
print(f"\nTime discretization: {NUM_DURATIONS} bins")
print(f"\nDeepHit parameters:")
for k, v in DEEPHIT_PARAMS.items():
    print(f"  {k}: {v}")

---

## Load Loan-Month Panel Data

In [None]:
# Load the loan-month panel data
print("Loading loan-month panel data...")
panel_df = pd.read_parquet(DATA_DIR / 'loan_month_panel.parquet')

print(f"Loaded {len(panel_df):,} loan-months")
print(f"Unique loans: {panel_df['loan_sequence_number'].nunique():,}")
print(f"Folds: {sorted(panel_df['fold'].unique())}")
print(f"Vintages: {panel_df['vintage_year'].min()} - {panel_df['vintage_year'].max()}")

print("\nEvent distribution (terminal observations):")
event_names = {0: 'Censored', 1: 'Prepay', 2: 'Default'}
terminal_events = panel_df[panel_df['event'] == 1].groupby('event_code').size()
for code, count in terminal_events.items():
    print(f"  {event_names.get(code, 'Other')} (k={code}): {count:,}")

---

## Define Features (Blumenstock et al. 2022)

In [None]:
# Define feature groups (Blumenstock et al. 2022, Table 2)
# Matching previous notebooks exactly

# Static covariates (fixed at origination) - 5 variables
STATIC_FEATURES = [
    'int_rate',      # Initial interest rate
    'orig_upb',      # Original unpaid balance
    'fico_score',    # Initial FICO score
    'dti_r',         # Initial debt-to-income ratio
    'ltv_r',         # Initial loan-to-value ratio
]

# Behavioral covariates (time-varying) - 4 variables
BEHAVIORAL_FEATURES = [
    'bal_repaid',      # Current repaid balance in percent
    't_act_12m',       # No. of times not being delinquent in last 12 months
    't_del_30d_12m',   # No. of times being 30 days delinquent in last 12 months
    't_del_60d_12m',   # No. of times being 60 days delinquent in last 12 months
]

# Macro covariates (time-varying) - 12 variables
MACRO_FEATURES = [
    'hpi_st_d_t_o',    # HPI difference between origination and today (state)
    'ppi_c_FRMA',      # Current prepayment incentive
    'TB10Y_d_t_o',     # Treasury rate difference
    'FRMA30Y_d_t_o',   # 30Y FRM difference
    'ppi_o_FRMA',      # Prepayment incentive at origination
    'hpi_st_log12m',   # HPI 12-month log return (state)
    'hpi_r_st_us',     # Ratio of state HPI to national HPI
    'st_unemp_r12m',   # Unemployment 12-month log return (state)
    'st_unemp_r3m',    # Unemployment 3-month log return (state)
    'TB10Y_r12m',      # Treasury rate 12-month return
    'T10Y3MM',         # Yield spread (10Y - 3M)
    'T10Y3MM_r12m',    # Yield spread 12-month return
]

ALL_FEATURES = STATIC_FEATURES + BEHAVIORAL_FEATURES + MACRO_FEATURES

# Filter to available features
feature_cols = [f for f in ALL_FEATURES if f in panel_df.columns]
missing_features = [f for f in ALL_FEATURES if f not in panel_df.columns]

print("=== Feature Groups (Blumenstock et al. 2022) ===")
print(f"Static features: {len([f for f in STATIC_FEATURES if f in feature_cols])}/5")
print(f"Behavioral features: {len([f for f in BEHAVIORAL_FEATURES if f in feature_cols])}/4")
print(f"Macro features: {len([f for f in MACRO_FEATURES if f in feature_cols])}/12")
print(f"\nTotal available: {len(feature_cols)}/21")

if missing_features:
    print(f"\nMissing features ({len(missing_features)}):")
    for f in missing_features:
        print(f"  - {f}")

---

## Prepare Data for DeepHit

DeepHit requires terminal observations (one per loan) with discretized time.
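
Discretization maps each loan's terminal duration to the index of the time bin it falls into. The sketch below is a minimal NumPy illustration, assuming evenly spaced bin edges and right-closed bins; the toy durations and `num_bins` are made-up values.

```python
import numpy as np

# Toy terminal durations in months (illustrative values, one per loan)
durations = np.array([1.0, 7.0, 12.0, 30.0, 60.0])
num_bins = 6

# Evenly spaced bin edges spanning the observed range: num_bins + 1 edges
edges = np.linspace(durations.min(), durations.max(), num_bins + 1)

# Right-closed bins: a duration d lands in bin i when edges[i] < d <= edges[i + 1].
# Searching against the upper edges only yields that index directly.
bin_idx = np.searchsorted(edges[1:], durations, side="left")

# Guard against any value falling past the last upper edge
bin_idx = np.clip(bin_idx, 0, num_bins - 1)

print("edges  :", np.round(edges, 2))
print("bin_idx:", bin_idx)
```

The smallest duration sits exactly on the first edge and is assigned to bin 0, so every loan receives a valid index.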

In [None]:
# For DeepHit, we need terminal observations (one per loan)
# Get the last observation for each loan
print("=== Preparing Terminal Observations ===")

time_col = 'loan_age'
event_col = 'event_code'

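# Sort and get last observation per loan
# Note: GroupBy.last() takes the last non-null value in each column, so a missing
# value in a loan's final month is filled from an earlier month rather than kept as NaN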
panel_df = panel_df.sort_values(['loan_sequence_number', time_col])
terminal_df = panel_df.groupby('loan_sequence_number').last().reset_index()

print(f"Terminal observations: {len(terminal_df):,} loans")

# Lag bal_repaid to avoid data leakage (matching previous notebooks)
if 'bal_repaid' in feature_cols:
    print("\nLagging bal_repaid to avoid data leakage...")
    
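    def get_lagged_bal_repaid(series):
        # Second-to-last value when available, otherwise the only value
        if len(series) >= 2:
            return series.iloc[-2]
        else:
            return series.iloc[-1]
    
    # Select the column before apply so only bal_repaid is passed to the function
    bal_repaid_lag = panel_df.groupby('loan_sequence_number')['bal_repaid'].apply(get_lagged_bal_repaid)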
    terminal_df['bal_repaid_lag1'] = terminal_df['loan_sequence_number'].map(bal_repaid_lag)
    
    feature_cols = [f if f != 'bal_repaid' else 'bal_repaid_lag1' for f in feature_cols]
    print("  Created bal_repaid_lag1")

# Log transform UPB
if 'orig_upb' in terminal_df.columns:
    terminal_df['log_upb'] = np.log(terminal_df['orig_upb'])
    feature_cols = [f if f != 'orig_upb' else 'log_upb' for f in feature_cols]
    print("  Created log_upb")

# Drop rows with missing features
n_before = len(terminal_df)
terminal_df = terminal_df.dropna(subset=feature_cols)
n_after = len(terminal_df)
print(f"\nAfter dropping NaN: {n_after:,} loans (dropped {n_before - n_after:,})")

print(f"\nFeatures ({len(feature_cols)}): {feature_cols}")

In [None]:
# Split by folds (matching previous notebooks)
print("=== Splitting Data by Folds ===")

# Training uses folds 0-8, validation uses fold 9, test uses fold 10
train_folds_actual = [f for f in TRAIN_FOLDS if f not in VAL_FOLDS]

train_df = terminal_df[terminal_df['fold'].isin(train_folds_actual)].copy()
val_df = terminal_df[terminal_df['fold'].isin(VAL_FOLDS)].copy()
test_df = terminal_df[terminal_df['fold'] == TEST_FOLD].copy()

print(f"Training set (folds {train_folds_actual}): {len(train_df):,} loans")
print(f"Validation set (fold {VAL_FOLDS}): {len(val_df):,} loans")
print(f"Test set (fold {TEST_FOLD}): {len(test_df):,} loans")

# Event distribution
print("\nTraining set event distribution:")
for code, name in event_names.items():
    count = (train_df[event_col] == code).sum()
    print(f"  {name}: {count:,}")

In [None]:
# Standardize features (important for neural networks)
print("=== Standardizing Features ===")

scaler = StandardScaler()

X_train = scaler.fit_transform(train_df[feature_cols]).astype('float32')
X_val = scaler.transform(val_df[feature_cols]).astype('float32')
X_test = scaler.transform(test_df[feature_cols]).astype('float32')

print(f"X_train shape: {X_train.shape}")
print(f"X_val shape: {X_val.shape}")
print(f"X_test shape: {X_test.shape}")

# Durations
duration_train = train_df[time_col].values
duration_val = val_df[time_col].values
duration_test = test_df[time_col].values

print(f"\nDuration range: {duration_train.min()} - {duration_train.max()} months")

---

## Pure PyTorch DeepHit Implementation

Custom implementation of DeepHit (Lee et al., 2018) without using pycox wrappers.
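
The key design point of the network is that the softmax is taken jointly over all (cause, time) cells rather than separately per cause. The framework-agnostic NumPy sketch below contrasts the two; the random logits stand in for the stacked head outputs, and the helper `softmax` is defined here for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the stacked head outputs: [batch, num_causes, num_time_bins]
batch, num_causes, num_bins = 3, 2, 5
logits = rng.normal(size=(batch, num_causes, num_bins))

def softmax(z, axis):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Joint softmax: flatten (cause, time), normalize once, reshape back
joint = softmax(logits.reshape(batch, -1), axis=1).reshape(batch, num_causes, num_bins)

# Per-cause softmax: each head normalized separately (not a joint PMF)
per_cause = softmax(logits, axis=2)

print("joint mass per sample    :", joint.sum(axis=(1, 2)))
print("per-cause mass per sample:", per_cause.sum(axis=(1, 2)))
```

Only the joint version yields a valid distribution over which event happens and when: its total mass is 1 per sample, whereas per-cause normalization gives a total of `num_causes`.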

In [None]:
# === DeepHit Model Architecture (Blumenstock et al. 2022) ===

class DeepHitNetwork(torch.nn.Module):
    """
    DeepHit neural network for competing risks (Blumenstock et al. 2022).
    
    Architecture:
    - Shared FFN: 3 layers, 300 nodes each
    - Residual connection: shared output + input features
    - Cause-specific heads: 2 heads (prepay, default), each 5 layers, 200 nodes
    - Joint softmax over (time, cause) for PMF
    
    Outputs joint probability mass function P(T=t, K=k) where:
    - T is discrete time
    - K is cause (0=prepay, 1=default)
    """
    def __init__(
        self,
        in_features: int,
        num_time_bins: int,
        num_causes: int = 2,
        shared_layers: int = 3,
        shared_nodes: int = 300,
        head_layers: int = 5,
        head_nodes: int = 200,
        dropout: float = 0.2,
        batch_norm: bool = True,
    ):
        super().__init__()
        
        self.in_features = in_features
        self.num_time_bins = num_time_bins
        self.num_causes = num_causes
        
        # === Shared FFN (3 layers, 300 nodes) ===
        shared = []
        prev_dim = in_features
        for _ in range(shared_layers):
            shared.append(torch.nn.Linear(prev_dim, shared_nodes))
            if batch_norm:
                shared.append(torch.nn.BatchNorm1d(shared_nodes))
            shared.append(torch.nn.ReLU())
            shared.append(torch.nn.Dropout(dropout))
            prev_dim = shared_nodes
        self.shared = torch.nn.Sequential(*shared)
        
        # Residual projection: project input to shared_nodes for addition
        self.residual_proj = torch.nn.Linear(in_features, shared_nodes)
        
        # === Cause-specific heads (5 layers, 200 nodes each) ===
        self.heads = torch.nn.ModuleList()
        for _ in range(num_causes):
            head = []
            prev_dim = shared_nodes  # Input from shared + residual
            for layer_idx in range(head_layers):
                head.append(torch.nn.Linear(prev_dim, head_nodes))
                if batch_norm:
                    head.append(torch.nn.BatchNorm1d(head_nodes))
                head.append(torch.nn.ReLU())
                head.append(torch.nn.Dropout(dropout))
                prev_dim = head_nodes
            # Output layer for this head: num_time_bins
            head.append(torch.nn.Linear(prev_dim, num_time_bins))
            self.heads.append(torch.nn.Sequential(*head))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Input features [batch_size, in_features]
            
        Returns:
            Joint PMF over (time, cause) [batch_size, num_causes, num_time_bins]
        """
        batch_size = x.shape[0]
        
        # Shared layers
        shared_out = self.shared(x)  # [batch, shared_nodes]
        
        # Residual connection: add projected input
        residual = self.residual_proj(x)  # [batch, shared_nodes]
        combined = shared_out + residual  # [batch, shared_nodes]
        
        # Cause-specific heads
        head_outputs = []
        for head in self.heads:
            head_out = head(combined)  # [batch, num_time_bins]
            head_outputs.append(head_out)
        
        # Stack: [batch, num_causes, num_time_bins]
        logits = torch.stack(head_outputs, dim=1)
        
        # Joint softmax over ALL (cause, time) combinations
        # Reshape to [batch, num_causes * num_time_bins], softmax, reshape back
        logits_flat = logits.view(batch_size, -1)
        pmf_flat = torch.softmax(logits_flat, dim=-1)
        pmf = pmf_flat.view(batch_size, self.num_causes, self.num_time_bins)
        
        return pmf
    
    def predict_cif(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict cause-specific cumulative incidence functions (CIF).
        
        CIF_k(t) = P(T <= t, K = k) = sum_{s <= t} P(T=s, K=k)
        
        Args:
            x: Input features [batch_size, in_features]
            
        Returns:
            CIF for each cause [batch_size, num_causes, num_time_bins]
        """
        pmf = self.forward(x)  # [batch, num_causes, num_time_bins]
        cif = torch.cumsum(pmf, dim=-1)
        return cif
    
    def predict_survival(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict overall survival function S(t) = P(T > t).
        
        S(t) = 1 - sum_k CIF_k(t)
        
        Args:
            x: Input features [batch_size, in_features]
            
        Returns:
            Survival probabilities [batch_size, num_time_bins]
        """
        cif = self.predict_cif(x)  # [batch, num_causes, num_time_bins]
        # Sum CIF over all causes
        total_cif = cif.sum(dim=1)  # [batch, num_time_bins]
        survival = 1 - total_cif
        return survival
    
    def predict_cause_specific_hazard(self, x: torch.Tensor, cause: int) -> torch.Tensor:
        """
        Predict cause-specific hazard for a given cause.
        
        h_k(t) = P(T=t, K=k | T >= t)
        
        Args:
            x: Input features [batch_size, in_features]
            cause: Cause index (0=prepay, 1=default)
            
        Returns:
            Hazard probabilities [batch_size, num_time_bins]
        """
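        pmf = self.forward(x)  # [batch, num_causes, num_time_bins]
        # Derive survival from the same forward pass so the PMF and S(t) stay consistent
        # (a second forward pass would differ whenever dropout is active)
        survival = 1 - torch.cumsum(pmf, dim=-1).sum(dim=1)  # [batch, num_time_bins]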
        
        # Shift survival by 1 for S(t-1)
        survival_prev = torch.cat([
            torch.ones(survival.shape[0], 1, device=survival.device),
            survival[:, :-1]
        ], dim=-1)
        
        # h_k(t) = f_k(t) / S(t-1)
        hazard = pmf[:, cause, :] / (survival_prev + 1e-7)
        return hazard


# Print architecture summary
print("DeepHitNetwork (Blumenstock et al. 2022 architecture)")
print("=" * 60)
print(f"Shared FFN: 3 layers x 300 nodes")
print(f"Residual: shared_output + proj(input)")
print(f"Head 1 (Prepay):  5 layers x 200 nodes -> {NUM_DURATIONS} time bins")
print(f"Head 2 (Default): 5 layers x 200 nodes -> {NUM_DURATIONS} time bins")
print(f"Output: Joint softmax over {2 * NUM_DURATIONS} (cause, time) combinations")

In [None]:
# === DeepHit Loss Function (Competing Risks) ===

class DeepHitLoss(torch.nn.Module):
    """
    DeepHit loss function for competing risks.
    
    Combines:
    1. Negative log-likelihood loss (L1) for joint distribution P(T=t, K=k)
    2. Ranking loss (L2) for each cause
    
    Loss = L1 + alpha * L2
    
    Reference: Lee et al. (2018) - DeepHit
    """
    def __init__(self, alpha: float = 0.2, sigma: float = 0.1):
        """
        Args:
            alpha: Weight for ranking loss (default 0.2)
            sigma: Smoothing parameter for ranking loss (default 0.1)
        """
        super().__init__()
        self.alpha = alpha
        self.sigma = sigma
        
    def forward(
        self,
        pmf: torch.Tensor,
        durations: torch.Tensor,
        events: torch.Tensor,
        time_bins: torch.Tensor,
    ) -> tuple:
        """
        Compute DeepHit loss for competing risks.
        
        Args:
            pmf: Predicted joint PMF [batch_size, num_causes, num_time_bins]
            durations: Observed durations [batch_size]
            events: Event codes (0=censored, 1=prepay, 2=default) [batch_size]
            time_bins: Discrete time bin edges [num_time_bins + 1]
            
        Returns:
            Tuple of (total_loss, nll_loss, ranking_loss)
        """
        batch_size = pmf.shape[0]
        num_causes = pmf.shape[1]
        num_bins = pmf.shape[2]
        device = pmf.device
        
        # Map durations to bin indices
        bin_indices = torch.bucketize(durations, time_bins[1:])
        bin_indices = torch.clamp(bin_indices, 0, num_bins - 1)
        
        # === Negative Log-Likelihood Loss (L1) ===
        # For uncensored (event > 0): -log(P(T=t, K=k))
        # For censored (event = 0): -log(S(t)) where S(t) = 1 - sum_k CIF_k(t)
        
        eps = 1e-7
        
        # Compute CIF for survival calculation
        cif = torch.cumsum(pmf, dim=-1)  # [batch, num_causes, num_time_bins]
        total_cif = cif.sum(dim=1)  # [batch, num_time_bins]
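        # Clamp at 0: float32 round-off in the cumulative sum can push 1 - CIF slightly
        # below zero, and the log of a negative value would turn the batch loss into NaN
        survival = torch.clamp(1 - total_cif, min=0.0)  # [batch, num_time_bins]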
        
        # Get survival at observed time for censored samples
        survival_at_time = survival[torch.arange(batch_size, device=device), bin_indices]
        
        # For uncensored: get PMF at (observed_time, observed_cause)
        # Event codes: 0=censored, 1=prepay (cause 0), 2=default (cause 1)
        # Map event codes to cause indices: event - 1 (but only for event > 0)
        cause_indices = (events - 1).long()
        cause_indices = torch.clamp(cause_indices, 0, num_causes - 1)  # Safety clamp
        
        # Get PMF at observed (time, cause)
        pmf_at_event = pmf[
            torch.arange(batch_size, device=device),
            cause_indices,
            bin_indices
        ]
        
        # Masks
        is_censored = (events == 0).float()
        is_uncensored = (events > 0).float()
        
        # NLL components
        nll_uncensored = -torch.log(pmf_at_event + eps) * is_uncensored
        nll_censored = -torch.log(survival_at_time + eps) * is_censored
        nll_loss = (nll_uncensored + nll_censored).mean()
        
        # === Ranking Loss (L2) ===
        if self.alpha > 0 and is_uncensored.sum() > 0:
            ranking_loss = self._compute_ranking_loss(
                pmf, cif, bin_indices, events, num_causes
            )
        else:
            ranking_loss = torch.tensor(0.0, device=device)
        
        total_loss = nll_loss + self.alpha * ranking_loss
        
        return total_loss, nll_loss, ranking_loss
    
    def _compute_ranking_loss(
        self,
        pmf: torch.Tensor,
        cif: torch.Tensor,
        bin_indices: torch.Tensor,
        events: torch.Tensor,
        num_causes: int,
    ) -> torch.Tensor:
        """
        Compute ranking loss for competing risks.
        
        For each cause k, compare pairs where subject i experienced cause k
        before subject j (who either experienced k later, different cause, or censored).
        """
        device = pmf.device
        batch_size = pmf.shape[0]
        
        ranking_loss = torch.tensor(0.0, device=device)
        n_pairs = 0
        
        # For each cause
        for k in range(num_causes):
            event_code = k + 1  # Map cause index to event code
            
            # Find subjects who experienced this cause
            cause_mask = (events == event_code)
            cause_indices = torch.where(cause_mask)[0]
            
            if len(cause_indices) < 1:
                continue
            
            # Sample subjects for efficiency
            sample_size = min(100, len(cause_indices))
            if len(cause_indices) > sample_size:
                perm = torch.randperm(len(cause_indices), device=device)[:sample_size]
                cause_indices = cause_indices[perm]
            
            for idx in cause_indices:
                t_i = bin_indices[idx]
                
                # Compare with subjects who have later event time
                # (regardless of their cause or censoring status)
                later_mask = bin_indices > t_i
                later_indices = torch.where(later_mask)[0]
                
                if len(later_indices) == 0:
                    continue
                
                # Sample if too many
                if len(later_indices) > 10:
                    perm = torch.randperm(len(later_indices), device=device)[:10]
                    later_indices = later_indices[perm]
                
                # CIF_k at time t_i for subject i and subjects j
                cif_i_at_ti = cif[idx, k, t_i]
                cif_j_at_ti = cif[later_indices, k, t_i]
                
                # Penalize if CIF_j > CIF_i (j predicted higher risk at t_i)
                diff = cif_j_at_ti - cif_i_at_ti
                pair_loss = torch.exp(diff / self.sigma)
                
                ranking_loss = ranking_loss + pair_loss.sum()
                n_pairs += len(later_indices)
        
        if n_pairs > 0:
            ranking_loss = ranking_loss / n_pairs
        
        return ranking_loss


print("DeepHitLoss (Competing Risks) defined")
print(f"  - Joint NLL over P(T=t, K=k)")
print(f"  - Cause-specific ranking loss")
print(f"  - alpha={DEEPHIT_PARAMS['alpha']}, sigma={DEEPHIT_PARAMS['sigma']}")

In [None]:
# === Create Time Bins for Discretization ===

print("=== Creating Time Bins ===")

# Get min/max durations
all_durations = np.concatenate([duration_train, duration_val, duration_test])
min_dur, max_dur = all_durations.min(), all_durations.max()

# Create evenly spaced time bins
time_bins = np.linspace(min_dur, max_dur, NUM_DURATIONS + 1)
time_bins_tensor = torch.tensor(time_bins, dtype=torch.float32)

print(f"Duration range: {min_dur:.0f} - {max_dur:.0f} months")
print(f"Number of bins: {NUM_DURATIONS}")
print(f"Bin width: {(max_dur - min_dur) / NUM_DURATIONS:.2f} months")

# Store time points (bin centers) for later use
time_points = (time_bins[:-1] + time_bins[1:]) / 2
print(f"Time points: {time_points[:5]}... {time_points[-5:]}")

In [None]:
# === Training Function (Competing Risks) ===

def train_deephit(
    model: torch.nn.Module,
    criterion: DeepHitLoss,
    X_train: np.ndarray,
    y_train: np.ndarray,  # durations
    e_train: np.ndarray,  # event codes (0=censored, 1=prepay, 2=default)
    X_val: np.ndarray,
    y_val: np.ndarray,
    e_val: np.ndarray,
    time_bins: torch.Tensor,
    batch_size: int = 256,
    epochs: int = 100,
    learning_rate: float = 0.01,
    patience: int = 10,
    device: torch.device = torch.device('cpu'),
) -> dict:
    """
    Train DeepHit competing risks model with early stopping.
    
    Returns:
        Dictionary with training history
    """
    # Move model and time bins to device
    model = model.to(device)
    time_bins = time_bins.to(device)
    
    # Create tensors
    X_train_t = torch.tensor(X_train, dtype=torch.float32)
    y_train_t = torch.tensor(y_train, dtype=torch.float32)
    e_train_t = torch.tensor(e_train, dtype=torch.float32)
    
    X_val_t = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_t = torch.tensor(y_val, dtype=torch.float32).to(device)
    e_val_t = torch.tensor(e_val, dtype=torch.float32).to(device)
    
    # Create DataLoader
    train_dataset = torch.utils.data.TensorDataset(X_train_t, y_train_t, e_train_t)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Note: the `verbose` argument is deprecated in recent PyTorch releases, so it is omitted
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_nll': [],
        'train_rank': [],
        'val_nll': [],
        'val_rank': [],
    }
    
    # Early stopping
    best_val_loss = float('inf')
    best_epoch = 0
    best_state = None
    epochs_no_improve = 0
    
    print(f"Training on {device}")
    print(f"{'Epoch':>6} | {'Train Loss':>10} | {'Val Loss':>10} | {'Val NLL':>8} | {'Val Rank':>8}")
    print("-" * 60)
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        train_nlls = []
        train_ranks = []
        
        for batch_X, batch_y, batch_e in train_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            batch_e = batch_e.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            pmf = model(batch_X)
            loss, nll, rank = criterion(pmf, batch_y, batch_e, time_bins)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_losses.append(loss.item())
            train_nlls.append(nll.item())
            train_ranks.append(rank.item())
        
        avg_train_loss = np.mean(train_losses)
        avg_train_nll = np.mean(train_nlls)
        avg_train_rank = np.mean(train_ranks)
        
        # Validation
        model.eval()
        with torch.no_grad():
            pmf_val = model(X_val_t)
            val_loss, val_nll, val_rank = criterion(pmf_val, y_val_t, e_val_t, time_bins)
            val_loss = val_loss.item()
            val_nll = val_nll.item()
            val_rank = val_rank.item()
        
        # Update scheduler
        scheduler.step(val_loss)
        
        # Save history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(val_loss)
        history['train_nll'].append(avg_train_nll)
        history['train_rank'].append(avg_train_rank)
        history['val_nll'].append(val_nll)
        history['val_rank'].append(val_rank)
        
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        
        # Print progress
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"{epoch:>6} | {avg_train_loss:>10.4f} | {val_loss:>10.4f} | {val_nll:>8.4f} | {val_rank:>8.4f}")
        
        # Check early stopping
        if epochs_no_improve >= patience:
            print(f"\nEarly stopping at epoch {epoch}")
            break
    
    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)
        model.to(device)
    
    print(f"\nBest epoch: {best_epoch} (val_loss={best_val_loss:.4f})")
    
    return history


print("train_deephit function defined (with gradient clipping)")

---

## Train Joint DeepHit Model

Train a single model that jointly predicts prepayment and default risks.

In [None]:
# === Train Joint DeepHit Model ===
print("=== Training Joint DeepHit Model (Competing Risks) ===")

# Get event codes (0=censored, 1=prepay, 2=default)
event_train = train_df[event_col].values.astype('float32')
event_val = val_df[event_col].values.astype('float32')
event_test = test_df[event_col].values.astype('float32')

print(f"\nEvent distribution (training):")
for code, name in event_names.items():
    count = (event_train == code).sum()
    print(f"  {name} (k={code}): {count:,.0f}")

# Create model (Blumenstock architecture)
in_features = X_train.shape[1]
model = DeepHitNetwork(
    in_features=in_features,
    num_time_bins=NUM_DURATIONS,
    num_causes=2,
    shared_layers=3,
    shared_nodes=300,
    head_layers=5,
    head_nodes=200,
    dropout=DEEPHIT_PARAMS['dropout'],
    batch_norm=DEEPHIT_PARAMS['batch_norm'],
)

# Count parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel parameters: {n_params:,}")

# Create loss function
criterion = DeepHitLoss(
    alpha=DEEPHIT_PARAMS['alpha'],
    sigma=DEEPHIT_PARAMS['sigma'],
)

# Train model
history = train_deephit(
    model=model,
    criterion=criterion,
    X_train=X_train,
    y_train=duration_train,
    e_train=event_train,
    X_val=X_val,
    y_val=duration_val,
    e_val=event_val,
    time_bins=time_bins_tensor,
    batch_size=DEEPHIT_PARAMS['batch_size'],
    epochs=DEEPHIT_PARAMS['epochs'],
    learning_rate=DEEPHIT_PARAMS['learning_rate'],
    patience=10,
    device=DEVICE,
)

print("\nJoint model training complete!")

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Total loss
ax = axes[0]
ax.plot(history['train_loss'], label='Train')
ax.plot(history['val_loss'], label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Total Loss')
ax.set_title('DeepHit: Total Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# NLL component
ax = axes[1]
ax.plot(history['train_nll'], label='Train')
ax.plot(history['val_nll'], label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('NLL Loss')
ax.set_title('DeepHit: NLL Component')
ax.legend()
ax.grid(True, alpha=0.3)

# Ranking component
ax = axes[2]
ax.plot(history['train_rank'], label='Train')
ax.plot(history['val_rank'], label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Ranking Loss')
ax.set_title('DeepHit: Ranking Component')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'deephit_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Model Evaluation

Evaluate using time-dependent C-index at 24, 48, and 72 months.
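
As a rough illustration of what a horizon-specific concordance measures, the sketch below computes a truncated C-index without inverse-probability-of-censoring weights. It is a simplified stand-in, not the IPCW estimator from scikit-survival, and the function name `truncated_c_index` and the toy data are hypothetical.

```python
import numpy as np

def truncated_c_index(time, event, risk, tau):
    """Harrell-style concordance restricted to events before tau.

    A pair (i, j) is comparable when subject i has an observed event,
    time[i] < time[j], and time[i] < tau. It is concordant when the
    earlier-failing subject i has the higher risk score.
    """
    time, event, risk = map(np.asarray, (time, event, risk))
    concordant, comparable = 0.0, 0
    for i in range(len(time)):
        if not event[i] or time[i] >= tau:
            continue
        later = time > time[i]              # subjects still at risk after time[i]
        comparable += int(later.sum())
        concordant += float((risk[i] > risk[later]).sum())
        concordant += 0.5 * float((risk[i] == risk[later]).sum())  # ties count half
    return concordant / comparable if comparable else float("nan")

# Toy data: durations in months, event indicator for one cause, CIF-based risk
time = np.array([10, 20, 30, 50, 80])
event = np.array([1, 1, 0, 1, 0], dtype=bool)
risk = np.array([0.9, 0.4, 0.6, 0.3, 0.1])

for tau in (24, 48, 72):
    print(f"tau={tau}: C = {truncated_c_index(time, event, risk, tau):.3f}")
```

For a cause-specific evaluation, `event` marks only the cause of interest, so loans that end with the other cause are treated as censored, as in the notebook's `Surv.from_arrays` calls.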

In [None]:
# === Evaluate Joint Model ===
print("=== DeepHit Joint Model Evaluation ===")

# Get predictions from joint model
model.eval()
with torch.no_grad():
    X_test_t = torch.tensor(X_test, dtype=torch.float32).to(DEVICE)
    
    # Get CIF for each cause
    cif_tensor = model.predict_cif(X_test_t)  # [batch, num_causes, num_time_bins]
    cif_np = cif_tensor.cpu().numpy()
    
    # Get overall survival
    surv_tensor = model.predict_survival(X_test_t)  # [batch, num_time_bins]
    surv_np = surv_tensor.cpu().numpy()

# Extract cause-specific CIFs
cif_prepay = cif_np[:, 0, :]   # Prepay CIF [batch, time]
cif_default = cif_np[:, 1, :]  # Default CIF [batch, time]

print(f"CIF shapes: prepay={cif_prepay.shape}, default={cif_default.shape}")
print(f"Survival shape: {surv_np.shape}")

# Convert survival to DataFrame for EvalSurv
surv_df = pd.DataFrame(surv_np.T, index=time_points)

# Create cause-specific event indicators for evaluation
event_prepay_test = (event_test == 1).astype('float32')
event_default_test = (event_test == 2).astype('float32')

duration_test_np = np.array(duration_test).astype('float64')

In [None]:
# === Compute C-index for both causes ===

# For C-index, we use CIF as risk score (higher CIF = higher risk)
# Risk at a specific time horizon
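def get_risk_at_horizon(cif, time_points, tau):
    """Return the CIF at the first grid point >= tau (clipped to the last point).

    cif: array [n_samples, n_time_points]; time_points: sorted 1-D array.
    """
    idx = np.searchsorted(time_points, tau)
    idx = min(idx, len(time_points) - 1)
    return cif[:, idx]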

# === PREPAYMENT Evaluation ===
print("=== PREPAYMENT C-index ===")
print("-" * 50)

# Create sksurv structured arrays
y_train_prepay_sk = Surv.from_arrays(
    (train_df[event_col] == 1).values.astype(bool),
    duration_train
)
y_test_prepay_sk = Surv.from_arrays(
    event_prepay_test.astype(bool),
    duration_test
)

cindex_prepay_ipcw = {}
for tau in TIME_HORIZONS:
    try:
        risk_prepay = get_risk_at_horizon(cif_prepay, time_points, tau)
        c_tau = concordance_index_ipcw(
            y_train_prepay_sk,
            y_test_prepay_sk,
            risk_prepay,
            tau=tau
        )
        cindex_prepay_ipcw[tau] = c_tau[0]
        print(f"  tau = {tau:3d} months: C-index (IPCW) = {c_tau[0]:.4f}")
    except Exception as e:
        print(f"  tau = {tau:3d} months: Error - {str(e)[:50]}")

# Overall C-index (using mean CIF as risk)
risk_prepay_overall = cif_prepay.mean(axis=1)
c_index_prepay = concordance_index_censored(
    event_prepay_test.astype(bool),
    duration_test,
    risk_prepay_overall
)
print(f"\nOverall C-index (Harrell): {c_index_prepay[0]:.4f}")

# === DEFAULT Evaluation ===
print("\n=== DEFAULT C-index ===")
print("-" * 50)

y_train_default_sk = Surv.from_arrays(
    (train_df[event_col] == 2).values.astype(bool),
    duration_train
)
y_test_default_sk = Surv.from_arrays(
    event_default_test.astype(bool),
    duration_test
)

cindex_default_ipcw = {}
for tau in TIME_HORIZONS:
    try:
        risk_default = get_risk_at_horizon(cif_default, time_points, tau)
        c_tau = concordance_index_ipcw(
            y_train_default_sk,
            y_test_default_sk,
            risk_default,
            tau=tau
        )
        cindex_default_ipcw[tau] = c_tau[0]
        print(f"  tau = {tau:3d} months: C-index (IPCW) = {c_tau[0]:.4f}")
    except Exception as e:
        print(f"  tau = {tau:3d} months: Error - {str(e)[:50]}")

# Overall C-index
risk_default_overall = cif_default.mean(axis=1)
c_index_default = concordance_index_censored(
    event_default_test.astype(bool),
    duration_test,
    risk_default_overall
)
print(f"\nOverall C-index (Harrell): {c_index_default[0]:.4f}")

In [None]:
# Plot time-dependent C-index comparison
fig, ax = plt.subplots(figsize=(10, 6))

# Use IPCW results for plotting
horizons = sorted(set(cindex_prepay_ipcw.keys()) & set(cindex_default_ipcw.keys()))
prepay_cindex = [cindex_prepay_ipcw[h] for h in horizons]
default_cindex = [cindex_default_ipcw[h] for h in horizons]

x = np.arange(len(horizons))
width = 0.35

bars1 = ax.bar(x - width/2, prepay_cindex, width, label='Prepayment', color='steelblue', alpha=0.8)
bars2 = ax.bar(x + width/2, default_cindex, width, label='Default', color='indianred', alpha=0.8)

# Add value labels
for bar, val in zip(bars1, prepay_cindex):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{val:.3f}', ha='center', va='bottom', fontsize=10)
for bar, val in zip(bars2, default_cindex):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{val:.3f}', ha='center', va='bottom', fontsize=10)

ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random (0.5)')

ax.set_xlabel('Time Horizon (months)', fontsize=12)
ax.set_ylabel('C-index (IPCW)', fontsize=12)
ax.set_title('DeepHit: Time-Dependent Concordance Index by Event Type', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels([f'tau = {h}' for h in horizons])
ax.set_ylim(0.4, 1.0)
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'deephit_time_dependent_cindex.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Figure saved to: {FIGURES_DIR / 'deephit_time_dependent_cindex.png'}")

---

## Survival Curves

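Plot the predicted cause-specific cumulative incidence functions (CIFs) and the overall survival curve for five randomly sampled test loans, then the test-set average as a stacked CIF plot.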
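
The plotted quantities are linked. The following is a minimal sketch, assuming that the model's joint probability mass over (cause, time bin) pairs sums to 1, that each CIF is the running sum of its cause's mass over time, and that overall survival is one minus the sum of the CIFs. The random logits, the shapes, and the cause-major layout of the flattened output are made up for illustration; `model.predict_cif` and `model.predict_survival` may differ in detail.

```python
import numpy as np

rng = np.random.default_rng(0)
n_loans, n_causes, n_bins = 3, 2, 4

# Hypothetical network output: one logit per (cause, time bin) pair.
logits = rng.normal(size=(n_loans, n_causes * n_bins))

# Joint softmax over all (cause, time) pairs, computed in a numerically stable way.
z = np.exp(logits - logits.max(axis=1, keepdims=True))
pmf = (z / z.sum(axis=1, keepdims=True)).reshape(n_loans, n_causes, n_bins)

# CIF_k(t): probability that cause k has occurred by time bin t.
cif = pmf.cumsum(axis=2)            # [n_loans, n_causes, n_bins]

# S(t): probability that no event of any cause has occurred by time bin t.
surv = 1.0 - cif.sum(axis=1)        # [n_loans, n_bins]

print(cif.shape, surv.shape)
```

Under these assumptions the bands of a stacked CIF plot add up to 1 at every time point, and in this sketch survival reaches 0 in the last bin because no probability mass is reserved for "no event within the horizon".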

In [None]:
# Plot CIF curves for sample loans
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Sample 5 random loans
np.random.seed(42)
sample_idx = np.random.choice(len(X_test), size=5, replace=False)

# Prepayment CIF
ax = axes[0]
for i, idx in enumerate(sample_idx):
    ax.plot(time_points, cif_prepay[idx], label=f'Loan {idx}', alpha=0.7)
ax.set_xlabel('Time (months)')
ax.set_ylabel('Cumulative Incidence')
ax.set_title('DeepHit: Prepayment CIF')
ax.legend(loc='lower right', fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Default CIF
ax = axes[1]
for i, idx in enumerate(sample_idx):
    ax.plot(time_points, cif_default[idx], label=f'Loan {idx}', alpha=0.7)
ax.set_xlabel('Time (months)')
ax.set_ylabel('Cumulative Incidence')
ax.set_title('DeepHit: Default CIF')
ax.legend(loc='lower right', fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Overall Survival
ax = axes[2]
for i, idx in enumerate(sample_idx):
    ax.plot(time_points, surv_np[idx], label=f'Loan {idx}', alpha=0.7)
ax.set_xlabel('Time (months)')
ax.set_ylabel('Survival Probability')
ax.set_title('DeepHit: Overall Survival S(t)')
ax.legend(loc='lower left', fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'deephit_survival_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Also plot stacked CIF to show competing risks
fig, ax = plt.subplots(figsize=(10, 6))

# Average CIF across all test samples
mean_cif_prepay = cif_prepay.mean(axis=0)
mean_cif_default = cif_default.mean(axis=0)
mean_survival = surv_np.mean(axis=0)

ax.fill_between(time_points, 0, mean_cif_prepay, alpha=0.7, label='Prepay', color='steelblue')
ax.fill_between(time_points, mean_cif_prepay, mean_cif_prepay + mean_cif_default, 
                alpha=0.7, label='Default', color='indianred')
ax.fill_between(time_points, mean_cif_prepay + mean_cif_default, 1, 
                alpha=0.3, label='Survival', color='gray')

ax.set_xlabel('Time (months)', fontsize=12)
ax.set_ylabel('Probability', fontsize=12)
ax.set_title('DeepHit: Average Stacked CIF (Test Set)', fontsize=14)
ax.legend(loc='center right')
ax.set_ylim(0, 1)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Feature Importance (Permutation)

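DeepHit is a neural network and has no directly interpretable coefficients, so we estimate feature importance by permutation: shuffle one feature at a time on the test set and record the resulting drop in the cause-specific C-index.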
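
The idea can be sketched independently of DeepHit. The example below is a minimal, generic version on synthetic data; the function `permutation_importance`, the toy linear predictor, and the correlation-based score are all assumptions made for illustration and are not the notebook's implementation, which scores with Harrell's C-index on the CIF.

```python
import numpy as np


def permutation_importance(predict, score, X, y, n_repeats=5, seed=0):
    """Mean and std of the score drop after shuffling each column of X."""
    rng = np.random.default_rng(seed)
    baseline = score(y, predict(X))
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # Break the link between feature j and the target; keep its marginal.
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - score(y, predict(X_perm))
    return drops.mean(axis=1), drops.std(axis=1)


# Synthetic data: only feature 0 carries signal.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
y = 2.0 * X[:, 0] + 0.1 * rng.normal(size=200)


def predict(X):
    """Toy 'model' that uses feature 0 only."""
    return 2.0 * X[:, 0]


def score(y_true, y_pred):
    """Pearson correlation as a stand-in for a ranking metric."""
    return np.corrcoef(y_true, y_pred)[0, 1]


imp_mean, imp_std = permutation_importance(predict, score, X, y)
print(np.round(imp_mean, 3))
```

Features the predictor ignores get an importance of exactly zero here. With a real model, small positive or negative values for weak features are common and mostly reflect shuffling noise, which is why the standard deviation across repeats is reported alongside the mean.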

In [None]:
# Compute permutation importance for joint model
print("=== Feature Importance (Permutation) ===")
print("Computing permutation importance for each cause...")

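def permutation_importance_deephit_joint(model, X, duration, event, feature_names,
                                         cause_idx, device, n_repeats=5):
    """
    Compute permutation importance for DeepHit joint model.

    Importance is the mean drop in Harrell's C-index (risk = CIF averaged
    over time) after shuffling one feature column.

    Args:
        X: feature array [n_samples, n_features], columns ordered as feature_names
        duration, event: observed times and event codes (1 = prepay, 2 = default)
        cause_idx: 0 for prepay, 1 for default
        n_repeats: number of shuffles per feature
    Returns:
        (importances, importances_std), one entry per feature
    """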
    model.eval()
    event_code = cause_idx + 1  # Map cause index to event code
    event_binary = (event == event_code).astype(bool)
    
    # Baseline score using CIF as risk
    with torch.no_grad():
        X_t = torch.tensor(X, dtype=torch.float32).to(device)
        cif_baseline = model.predict_cif(X_t).cpu().numpy()[:, cause_idx, :]
    risk_baseline = cif_baseline.mean(axis=1)
    
    baseline_cindex = concordance_index_censored(
        event_binary, duration, risk_baseline
    )[0]
    
    importances = []
    importances_std = []
    
    for i, feat in enumerate(feature_names):
        scores = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            np.random.shuffle(X_perm[:, i])
            
            with torch.no_grad():
                X_perm_t = torch.tensor(X_perm, dtype=torch.float32).to(device)
                cif_perm = model.predict_cif(X_perm_t).cpu().numpy()[:, cause_idx, :]
            risk_perm = cif_perm.mean(axis=1)
            
            perm_cindex = concordance_index_censored(
                event_binary, duration, risk_perm
            )[0]
            
            scores.append(baseline_cindex - perm_cindex)
        
        importances.append(np.mean(scores))
        importances_std.append(np.std(scores))
        
    return np.array(importances), np.array(importances_std)

# Prepayment importance
print("\nPrepayment cause...")
imp_prepay, imp_prepay_std = permutation_importance_deephit_joint(
    model, X_test, duration_test, event_test, feature_cols, 
    cause_idx=0, device=DEVICE, n_repeats=5
)

importance_prepay = pd.DataFrame({
    'feature': feature_cols,
    'importance': imp_prepay,
    'std': imp_prepay_std
}).sort_values('importance', ascending=False)

# Default importance
print("Default cause...")
imp_default, imp_default_std = permutation_importance_deephit_joint(
    model, X_test, duration_test, event_test, feature_cols, 
    cause_idx=1, device=DEVICE, n_repeats=5
)

importance_default = pd.DataFrame({
    'feature': feature_cols,
    'importance': imp_default,
    'std': imp_default_std
}).sort_values('importance', ascending=False)

print("\nTop 10 Features - PREPAYMENT:")
print(importance_prepay.head(10).to_string(index=False))

print("\nTop 10 Features - DEFAULT:")
print(importance_default.head(10).to_string(index=False))

In [None]:
# Plot feature importance comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 8))

# Prepayment
ax = axes[0]
top_n = 15
plot_df = importance_prepay.head(top_n).iloc[::-1]
ax.barh(plot_df['feature'], plot_df['importance'], xerr=plot_df['std'],
        color='steelblue', alpha=0.7, capsize=3)
ax.set_xlabel('Importance (decrease in C-index)')
ax.set_title('DeepHit Permutation Importance: Prepayment', fontsize=12)
ax.grid(True, alpha=0.3, axis='x')

# Default
ax = axes[1]
plot_df = importance_default.head(top_n).iloc[::-1]
ax.barh(plot_df['feature'], plot_df['importance'], xerr=plot_df['std'],
        color='indianred', alpha=0.7, capsize=3)
ax.set_xlabel('Importance (decrease in C-index)')
ax.set_title('DeepHit Permutation Importance: Default', fontsize=12)
ax.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'deephit_feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Save Models

In [None]:
import pickle  # used below for the scaler, feature list, and training history

# Save joint model (PyTorch native format)
torch.save({
    'model_state_dict': model.state_dict(),
    'in_features': in_features,
    'num_time_bins': NUM_DURATIONS,
    'num_causes': 2,
    'shared_layers': 3,
    'shared_nodes': 300,
    'head_layers': 5,
    'head_nodes': 200,
    'dropout': DEEPHIT_PARAMS['dropout'],
    'batch_norm': DEEPHIT_PARAMS['batch_norm'],
}, MODELS_DIR / 'deephit_joint.pt')

# Save scaler for inference
with open(MODELS_DIR / 'deephit_scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

# Save time bins
np.save(MODELS_DIR / 'deephit_time_bins.npy', time_bins)

# Save feature importance
importance_prepay.to_csv(MODELS_DIR / 'deephit_importance_prepay.csv', index=False)
importance_default.to_csv(MODELS_DIR / 'deephit_importance_default.csv', index=False)

# Save feature columns
with open(MODELS_DIR / 'deephit_feature_cols.pkl', 'wb') as f:
    pickle.dump(feature_cols, f)

# Save training history
with open(MODELS_DIR / 'deephit_history.pkl', 'wb') as f:
    pickle.dump(history, f)

print(f"Model saved to {MODELS_DIR}:")
print(f"  - deephit_joint.pt (joint competing risks model)")
print(f"  - deephit_scaler.pkl")
print(f"  - deephit_time_bins.npy")
print(f"  - deephit_importance_prepay.csv")
print(f"  - deephit_importance_default.csv")
print(f"  - deephit_feature_cols.pkl")
print(f"  - deephit_history.pkl")

---

## Summary

In [None]:
print("=" * 70)
print("DEEPHIT (BLUMENSTOCK ARCHITECTURE) - SUMMARY")
print("=" * 70)

print(f"\nArchitecture:")
print(f"  Shared FFN: 3 layers x 300 nodes")
print(f"  Residual: shared_output + proj(input)")
print(f"  Prepay head: 5 layers x 200 nodes -> {NUM_DURATIONS} time bins")
print(f"  Default head: 5 layers x 200 nodes -> {NUM_DURATIONS} time bins")
print(f"  Output: Joint softmax over {2 * NUM_DURATIONS} (cause, time)")
print(f"  Total parameters: {n_params:,}")

print(f"\nData:")
print(f"  Training loans: {len(train_df):,}")
print(f"  Validation loans: {len(val_df):,}")
print(f"  Test loans: {len(test_df):,}")

print(f"\nFeatures: {len(feature_cols)}")
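# Per-group counts: the `or` clause counts the whole group whenever the derived
# column (log_upb / bal_repaid_lag1) is present in feature_cols.
print(f"  Static: {len([f for f in STATIC_FEATURES if f in feature_cols or 'log_upb' in feature_cols])}")
print(f"  Behavioral: {len([f for f in BEHAVIORAL_FEATURES if f in feature_cols or 'bal_repaid_lag1' in feature_cols])}")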
print(f"  Macro: {len([f for f in MACRO_FEATURES if f in feature_cols])}")

print(f"\nTraining Parameters:")
print(f"  Batch size: {DEEPHIT_PARAMS['batch_size']}")
print(f"  Learning rate: {DEEPHIT_PARAMS['learning_rate']}")
print(f"  Alpha (ranking): {DEEPHIT_PARAMS['alpha']}")
print(f"  Sigma: {DEEPHIT_PARAMS['sigma']}")
print(f"  Dropout: {DEEPHIT_PARAMS['dropout']}")

print(f"\n{'='*70}")
print("MODEL PERFORMANCE (Test Set)")
print("=" * 70)

print(f"\nPREPAYMENT:")
print(f"  Overall C-index (Harrell): {c_index_prepay[0]:.4f}")
print(f"  Time-Dependent C-index (IPCW):")
for tau, c in cindex_prepay_ipcw.items():
    print(f"    tau = {tau:3d} months: {c:.4f}")

print(f"\nDEFAULT:")
print(f"  Overall C-index (Harrell): {c_index_default[0]:.4f}")
print(f"  Time-Dependent C-index (IPCW):")
for tau, c in cindex_default_ipcw.items():
    print(f"    tau = {tau:3d} months: {c:.4f}")

print(f"\nTop 3 Important Features:")
print(f"  Prepayment: {', '.join(importance_prepay['feature'].head(3).tolist())}")
print(f"  Default: {', '.join(importance_default['feature'].head(3).tolist())}")

---

## Next Steps

**Notebook 09**: Model Comparison

Compare all models:
- Cause-Specific Cox (notebook 05)
- Fine-Gray (notebook 06)
- Random Survival Forest (notebook 07)
- DeepHit (this notebook)

Key comparisons:
- Time-dependent C-index at multiple horizons
- Calibration assessment
- Cumulative incidence predictions
- Computational cost