# `auton-survival` Cross Validation Survival Regression

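`auton-survival` offers a simple API for training survival regression models with cross-validated model selection, choosing the hyperparameters that minimize a Brier-score-based metric. In this notebook we demonstrate how to use `auton-survival` to train survival models on the *SUPPORT* dataset with cross validation. The experiment below passes `metric='brs'`, the Brier score evaluated at the chosen time horizons; `metric='ibs'` selects by the integrated Brier score instead.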

In [None]:
# Suppress all warnings for cleaner output
import warnings
warnings.filterwarnings('ignore')

import sys

sys.path.append('../')
from auton_survival import datasets
outcomes, features = datasets.load_support()

In [2]:
# Preprocess the features so that every column is numeric, as the PyTorch models require.
from auton_survival.preprocessing import Preprocessor
import pandas as pd
import numpy as np

cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 
             'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 
             'glucose', 'bun', 'urine', 'adlp', 'adls']

# When cross-validating, the preprocessor should be fit separately on each training fold
# so that no statistics leak from the held-out fold.
# For simplicity, this demo fits it once on the full dataset.
# Missing categorical values are imputed with the mode, missing numeric values with the mean.
preprocessor = Preprocessor(cat_feat_strat='mode', num_feat_strat='mean') 
x = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats,
                                one_hot=True, fill_value=-1)

# Convert any remaining object columns to numeric; values that cannot be parsed become NaN.
for col in x.columns:
    if x[col].dtype == 'object':
        x[col] = pd.to_numeric(x[col], errors='coerce')

# Replace any remaining NaNs with 0 and cast every column to float64 so the model
# receives a single floating-point dtype.
x = x.fillna(0).astype('float64')
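The comment in the cell above notes that preprocessing should be fit per fold. A minimal sketch of that idea follows, using plain NumPy mean imputation on a toy array rather than `Preprocessor`. The helper names `fit_means` and `apply_means`, the toy data, and the fold labels are all hypothetical and only illustrate the pattern.

```python
import numpy as np

def fit_means(x_train):
    """Column means of the training rows, ignoring missing values (NaN)."""
    return np.nanmean(x_train, axis=0)

def apply_means(x, means):
    """Return a copy of x with NaNs replaced by the supplied per-column means."""
    x = x.copy()
    rows, cols = np.where(np.isnan(x))
    x[rows, cols] = means[cols]
    return x

# Toy feature matrix with missing entries; fold labels assign each row to a fold.
x_toy = np.array([[1.0, 10.0],
                  [np.nan, 20.0],
                  [3.0, np.nan],
                  [5.0, 40.0]])
folds_toy = np.array([0, 0, 1, 1])

imputed = {}
for fold in np.unique(folds_toy):
    train, held_out = folds_toy != fold, folds_toy == fold
    means = fit_means(x_toy[train])  # statistics come from training rows only
    imputed[int(fold)] = apply_means(x_toy[held_out], means)
```

Each held-out fold is filled in with means computed from the other folds, so nothing about the held-out rows influences their own imputation.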

In [3]:
import numpy as np
# Evaluation horizons: the 25th, 50th and 75th percentiles of the observed event times.
times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()

In [4]:
from auton_survival.experiments import SurvivalRegressionCV

param_grid = {'k' : [3],
              'distribution' : ['Weibull'],
              'learning_rate' : [1e-4, 1e-3],
              'layers' : [[100]]}

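# Run 3-fold cross validation over the grid, scoring each setting by Brier score ('brs') at `times`.
experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)
model = experiment.fit(x, outcomes, times, metric='brs')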

At hyper-param {'distribution': 'Weibull', 'k': 3, 'layers': [100], 'learning_rate': 0.0001}
At fold: 0


 13%|█▎        | 1281/10000 [00:01<00:07, 1126.88it/s]
100%|██████████| 50/50 [00:03<00:00, 16.17it/s]


At fold: 1


 18%|█▊        | 1767/10000 [00:01<00:06, 1212.79it/s]
100%|██████████| 50/50 [00:02<00:00, 16.98it/s]


At fold: 2


 14%|█▍        | 1429/10000 [00:01<00:06, 1237.97it/s]
100%|██████████| 50/50 [00:02<00:00, 16.88it/s]


At hyper-param {'distribution': 'Weibull', 'k': 3, 'layers': [100], 'learning_rate': 0.001}
At fold: 0


 13%|█▎        | 1281/10000 [00:01<00:07, 1221.32it/s]
100%|██████████| 50/50 [00:03<00:00, 16.47it/s]


At fold: 1


 18%|█▊        | 1767/10000 [00:01<00:06, 1229.09it/s]
100%|██████████| 50/50 [00:03<00:00, 15.86it/s]


At fold: 2


 14%|█▍        | 1429/10000 [00:01<00:07, 1166.15it/s]
100%|██████████| 50/50 [00:03<00:00, 16.19it/s]
 19%|█▉        | 1886/10000 [00:01<00:07, 1059.32it/s]
100%|██████████| 50/50 [00:04<00:00, 10.80it/s]
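
The log above visits two hyperparameter settings, each over three folds. The sketch below shows how such a grid expands into individual settings, using only `itertools`. It is an illustration of the idea and not necessarily how `SurvivalRegressionCV` enumerates the grid internally; sorting the keys is an assumption made here so the printed dictionaries resemble the log.

```python
from itertools import product

param_grid = {'k': [3],
              'distribution': ['Weibull'],
              'learning_rate': [1e-4, 1e-3],
              'layers': [[100]]}
num_folds = 3

# Cartesian product of the candidate values, with keys taken in sorted order.
keys = sorted(param_grid)
combinations = [dict(zip(keys, values))
                for values in product(*(param_grid[key] for key in keys))]

for params in combinations:
    print(params)

# During the search, each combination is fit once per fold.
num_search_fits = len(combinations) * num_folds
print(num_search_fits)
```

Only `learning_rate` has more than one candidate value, so the grid yields two settings and the search performs six fits.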


In [5]:
print(experiment.folds)
model

[2 2 0 ... 0 0 0]


<auton_survival.estimators.SurvivalModel at 0x1119461d0>

In [6]:
out_risk = model.predict_risk(x, times)
out_survival = model.predict_survival(x, times)

In [7]:
from auton_survival.metrics import survival_regression_metric

for fold in set(experiment.folds):
    print(survival_regression_metric('brs', outcomes[experiment.folds==fold], 
                                     out_survival[experiment.folds==fold], 
                                     times=times))

[0.12904334 0.19417136 0.20622495]
[0.12688411 0.19247997 0.20633734]
[0.12123986 0.1915304  0.20921462]
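
For intuition about what the `'brs'` values measure: ignoring censoring, the Brier score at a horizon is the mean squared difference between the predicted survival probability and the indicator that the subject survived past that horizon. The sketch below implements only this simplified version on toy numbers. `naive_brier_score` is a hypothetical helper, and a censoring-aware metric would additionally reweight the observations, which this sketch does not attempt.

```python
import numpy as np

def naive_brier_score(event_times, predicted_survival, horizon):
    """Mean squared error between predicted S(horizon) and 1{T > horizon}.

    Assumes every event time is fully observed (no censoring).
    """
    event_times = np.asarray(event_times, dtype=float)
    predicted_survival = np.asarray(predicted_survival, dtype=float)
    survived = (event_times > horizon).astype(float)
    return float(np.mean((survived - predicted_survival) ** 2))

# Toy data: two subjects have an event before the horizon, two survive past it.
toy_event_times = [5, 15, 25, 35]
toy_predicted_survival = [0.2, 0.4, 0.7, 0.9]
toy_score = naive_brier_score(toy_event_times, toy_predicted_survival, horizon=20)
print(toy_score)
```

Lower is better: a model that predicts survival probability 0 for every subject with an event before the horizon and 1 for every survivor scores exactly 0.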


In [8]:
from auton_survival.metrics import survival_regression_metric

for fold in set(experiment.folds):
    print(survival_regression_metric('ctd', outcomes[experiment.folds==fold], 
                                     out_survival[experiment.folds==fold], 
                                     times=times))

[0.7651762964588759, 0.7223050931427001, 0.6886754271838222]
[0.7818161784601481, 0.728035088083212, 0.688930205374894]
[0.7620162089027601, 0.7173320330689428, 0.6806088154866267]
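
Per-fold numbers are easier to compare once they are summarized across folds. As a small sketch, the values printed above can be stacked into an array and reduced to a mean and standard deviation per evaluation time. The variable names are hypothetical, and the evaluation times 14, 58 and 252 are the contents of `times` as printed in the last cell of this notebook.

```python
import numpy as np

# Per-fold concordance values copied from the output above
# (rows: folds, columns: the three evaluation times).
ctd_by_fold = np.array([
    [0.7651762964588759, 0.7223050931427001, 0.6886754271838222],
    [0.7818161784601481, 0.728035088083212, 0.688930205374894],
    [0.7620162089027601, 0.7173320330689428, 0.6806088154866267],
])
eval_times = [14.0, 58.0, 252.0]

# Reduce over the fold axis to get one summary value per evaluation time.
mean_ctd = ctd_by_fold.mean(axis=0)
std_ctd = ctd_by_fold.std(axis=0)

for t, m, s in zip(eval_times, mean_ctd, std_ctd):
    print(f"t={t:g}: mean={m:.4f}, std={s:.4f}")
```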


In [9]:
for fold in set(experiment.folds):
    for time in times:
        print(time)

14.0
58.0
252.0
14.0
58.0
252.0
14.0
58.0
252.0
