In [1]:
import numpy

## Mini example for scikit-survival estimators

Load standard dataset for testing

In [2]:
def load_toy_data():
    """Load the scikit-survival breast-cancer data as a numeric feature matrix.

    Float columns are kept as they are and categorical columns are one-hot
    encoded; the two parts are concatenated column-wise.

    Returns
    -------
    X : numpy.ndarray
        Float columns followed by the one-hot encoded categorical columns.
    y : structured array
        Outcome exactly as returned by ``sksurv.datasets.load_breast_cancer``.
    """
    import sksurv.datasets
    from sklearn.preprocessing import OneHotEncoder

    X, y = sksurv.datasets.load_breast_cancer()
    try:
        # Newer scikit-learn spells the dense-output switch ``sparse_output``.
        encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        # Older scikit-learn only knows the ``sparse`` keyword.
        encoder = OneHotEncoder(sparse=False)
    X = numpy.concatenate([
        X.select_dtypes('float'),
        encoder.fit_transform(X.select_dtypes('category')),
    ], axis=1)
    return X, y

In [3]:
X, y = load_toy_data()
X.shape

(198, 84)

Fit penalized Cox model (from scikit-survival)

In [4]:
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

In [5]:
model = CoxnetSurvivalAnalysis()
model.fit(X, y)
pred = model.predict(X)

Standard functions from scikit-learn can be used with scikit-survival models

In [6]:
from sklearn.model_selection import cross_val_score

In [7]:
numpy.mean(cross_val_score(CoxnetSurvivalAnalysis(), X, y))

0.6191250002013456

## Scikit-learn compatibility
Scikit-learn has a checker for estimators to see if they conform to their specification. 

In [8]:
from sklearn.utils.estimator_checks import check_estimator

Scikit-survival models do not necessarily pass ;-)

In [9]:
# Run the scikit-learn conformance checks and show the first failure
# (exception class name, then its message) instead of letting it propagate.
try:
    check_estimator(CoxnetSurvivalAnalysis())
except Exception as err:
    print(type(err).__name__, err, sep="\n")

AssertionError
Estimator CoxnetSurvivalAnalysis should not set any attribute apart from parameters during init. Found attributes ['_baseline_models'].


This is an example of a minimal (empty) estimator that passes the tests. Dataclasses can be useful to avoid writing long `__init__` functions, and this approach appears to work. `BaseEstimator` provides the `get_params`/`set_params` methods that are required. `check_X_y` and `check_array` implement the input-data checks required by the `check_estimator` function.

In [10]:
from dataclasses import dataclass
from sklearn.base import BaseEstimator
from sklearn.utils import check_X_y, check_array

@dataclass
class TestEstimator(BaseEstimator):
    """Minimal do-nothing estimator used to exercise ``check_estimator``.

    ``predict`` ignores the feature values and returns ``param1`` once per
    row of the input.
    """
    # No trailing comma here: ``= 1,`` would make the default the tuple ``(1,)``.
    param1: int = 1

    def fit(self, X, y):
        """Validate the training data and return ``self``; nothing is learned."""
        X, y = check_X_y(X, y)
        # NOTE(review): private scikit-learn helper, presumably kept so the
        # fitted-estimator bookkeeping expected by check_estimator is
        # recorded -- confirm it exists in the installed scikit-learn version.
        self._validate_data(X, y)
        return self

    def predict(self, X):
        """Return a 1-D array holding one copy of ``param1`` per row of ``X``."""
        X = check_array(X)
        return numpy.full(shape=X.shape[0], fill_value=self.param1)

check_estimator(TestEstimator(param1=33))

## Outcome format

Esempio formato scikit-survival

In [11]:
y[:10]

array([( True,  723.), (False, 6591.), ( True,  524.), (False, 6255.),
       ( True, 3822.), (False, 6507.), (False, 5947.), (False, 5816.),
       (False, 6007.), ( True, 1233.)],
      dtype=[('e.tdm', '?'), ('t.tdm', '<f8')])

In [12]:
def get_indicator(y):
    """Return the first field of the structured outcome array ``y``.

    In the scikit-survival outcome format shown in this notebook the first
    field holds the event indicator.
    """
    indicator_field = y.dtype.names[0]
    return y[indicator_field]
def get_time(y):
    """Return the second field of the structured outcome array ``y``.

    In the scikit-survival outcome format shown in this notebook the second
    field holds the observed time.
    """
    time_field = y.dtype.names[1]
    return y[time_field]

In [13]:
get_indicator(y[:10])

array([ True, False,  True, False,  True, False, False, False, False,
        True])

In [14]:
def check_surv_y(y):
    """Sanity-check a structured survival outcome array via assertions.

    Checks performed:
    * the second field must have the ``float`` dtype;
    * if the first field is ``bool`` (single event), at least one entry
      must be true;
    * if the first field is ``int`` (competing events), at least one entry
      must equal 1.

    Any other dtype for the first field is accepted without further checks.
    Returns ``None``; a failed check raises ``AssertionError``.
    """
    time_dtype = y.dtype[1]
    assert time_dtype == float

    event_dtype = y.dtype[0]
    if event_dtype == bool:
        # single event
        assert any(get_indicator(y))
        return
    if event_dtype == int:
        # competing events
        events = get_indicator(y)
        assert any(events == 1)
check_surv_y(y)

## Wrapper for the DeepHit method from the pycox module

Molto preliminare, sto recuperando e adattando il codice dal notebook ALS

In [15]:
import pycox

Ho fatto questa classe perché la maggior parte dei modelli deep learning ha questi parametri...

In [27]:
from dataclasses import dataclass, field
from sklearn.base import BaseEstimator
from collections.abc import Sequence

@dataclass
class BaseSurvival(BaseEstimator):
    """Shared hyper-parameters and fit/predict plumbing for survival wrappers.

    Subclasses are expected to provide two hooks:

    * ``_fit(X, y)`` -- train on the validated feature array;
    * ``_predict(X, times)`` -- return predictions that ``predict`` indexes
      as ``[:, 0]`` when a single time is requested.
    """
    # Sizes of the network layers; a fresh list is created per instance.
    layer_sizes: Sequence[int] = field(default_factory=lambda: [10, 10])
    # Number of training epochs (early stopping is not implemented yet).
    epochs: int = 10
    batch_size: int = 16
    learning_rate: float = 0.001
    device: str = 'cpu'

    def fit(self, X, y):
        """Store the median observed time, then delegate training to ``_fit``.

        Returns ``self`` so that calls can be chained.
        """
        # Remembered as the default evaluation time used by ``predict``.
        self.median_time_ = numpy.median(get_time(y))
        X_checked = check_array(X)
        self._fit(X_checked, y)
        return self

    def predict(self, X, times=None):
        """Predict through the ``_predict`` hook.

        With ``times=None`` the hook is evaluated at the median time stored
        by ``fit`` and only the first column of its result is returned;
        otherwise the hook's result for ``times`` is returned unchanged.
        """
        X_checked = check_array(X)
        if times is None:
            return self._predict(X_checked, [self.median_time_])[:, 0]
        return self._predict(X_checked, times)


In [28]:
@dataclass
class DeepHitPycox(BaseSurvival):
    """scikit-learn style wrapper around pycox's ``DeepHitSingle`` model.

    Training discretises the observed times with the model's label transform
    and fits an ``MLPVanilla`` network. Predictions are one minus the
    predicted survival, linearly interpolated over the label-transform cut
    points at the requested times.
    """
    # Passed to ``DeepHitSingle.label_transform`` to build the time grid.
    num_durations: int = 10
    # TODO: the network-shape parameters should follow the original paper.
    # layer_sizes, epochs, batch_size, learning_rate and device are inherited
    # from BaseSurvival with the same defaults, so they are not redeclared.

    def _fit(self, X, y):
        """Build the network and optimizer, then train on the discretised outcome."""
        from pycox.models import DeepHitSingle
        import torchtuples as tt

        optimizer = tt.optim.AdamWR(
            lr=self.learning_rate,
            # decoupled_weight_decay / cycle_eta_multiplier are left at their defaults
        )
        self.labtrans_ = DeepHitSingle.label_transform(self.num_durations)
        y_discrete = self.labtrans_.fit_transform(get_time(y), get_indicator(y))
        net = tt.practical.MLPVanilla(
            in_features=X.shape[1],
            out_features=self.labtrans_.out_features,
            num_nodes=self.layer_sizes,
            # batch_norm and dropout are left at their defaults
        )
        self.model_ = DeepHitSingle(
            net, optimizer,
            device=self.device,
        )

        self.model_.fit(
            X.astype('float32'), y_discrete,
            num_workers=0,
            verbose=False,
            epochs=self.epochs,
            batch_size=self.batch_size,
        )

    def _predict(self, X, eval_times):
        """Return the interpolated event probability at each of ``eval_times``.

        The result has one row per entry of the model's survival prediction
        and one column per evaluation time. Times before the first cut point
        map to 0 and times after the last cut point map to 1.
        """
        event_prob = 1 - self.model_.predict_surv(X.astype('float32'))
        return numpy.array([
            numpy.interp(eval_times, self.labtrans_.cuts, p, left=0, right=1)
            for p in event_prob
        ])

    def predict(self, X, eval_times=None):
        """Predict event probabilities at ``eval_times``.

        Kept with the ``eval_times`` keyword for backward compatibility; the
        work is delegated to ``BaseSurvival.predict``, which validates ``X``
        with ``check_array`` and, when ``eval_times`` is ``None``, evaluates
        at the median training time and returns a 1-D array.
        """
        return super().predict(X, times=eval_times)

DeepHitPycox(epochs=3).fit(X, y).predict(X)

array([0.43400747, 0.43261647, 0.40482927, 0.51853378, 0.48376092,
       0.52510303, 0.48095094, 0.41796547, 0.4811762 , 0.49664181,
       0.51762345, 0.51315589, 0.43757693, 0.46304867, 0.47778717,
       0.46782003, 0.48432079, 0.51733526, 0.45714169, 0.53752853,
       0.53918343, 0.45799163, 0.45954248, 0.5149161 , 0.55116003,
       0.4868627 , 0.61068153, 0.50941051, 0.47016023, 0.48105872,
       0.46622029, 0.49530748, 0.436892  , 0.47476353, 0.53043029,
       0.49867153, 0.45486153, 0.51499362, 0.52545263, 0.47394102,
       0.4404548 , 0.56097785, 0.53205022, 0.60503965, 0.46855244,
       0.42425613, 0.7992805 , 0.70881606, 0.47032489, 0.44825696,
       0.47254311, 0.47980306, 0.51670312, 0.50449704, 0.46059427,
       0.47339197, 0.43157607, 0.54600291, 0.4768812 , 0.42066925,
       0.49484315, 0.51078752, 0.4886064 , 0.44483711, 0.48468494,
       0.45136087, 0.48303066, 0.4528864 , 0.4324962 , 0.53132864,
       0.50725421, 0.4974609 , 0.48598324, 0.42342955, 0.46994

Qui ho wrappato il c-index di scikit-survival nella forma delle metriche di scikit-learn (perché non è proprio compatibile)

In [18]:
def concordance_index_score(y_true, y_pred, *args, **kwargs):
    """Concordance index with a scikit-learn style ``(y_true, y_pred)`` signature.

    Parameters
    ----------
    y_true : structured array
        Survival outcome; its first field is used as the event indicator and
        its second field as the event time.
    y_pred : array-like
        Estimates passed on to ``concordance_index_censored``.
    *args, **kwargs
        Forwarded to ``concordance_index_censored`` after the three required
        inputs.

    Returns
    -------
    The first element of the result of ``concordance_index_censored``.
    """
    from sksurv.metrics import concordance_index_censored
    # The three required inputs are passed positionally: binding them by
    # keyword while also forwarding *args would make any extra positional
    # argument collide with ``event_indicator`` and raise a TypeError.
    return concordance_index_censored(
        get_indicator(y_true),
        get_time(y_true),
        y_pred,
        *args, **kwargs,
    )[0]

In [19]:
from sklearn.metrics import make_scorer
numpy.mean(cross_val_score(DeepHitPycox(), X, y, scoring=make_scorer(concordance_index_score)))

0.4922461474526555

### Wrapper for DeepSurvivalMachines

In [30]:
@dataclass
class DeepSurvivalMachines(BaseSurvival):
    """scikit-learn style wrapper around auton-survival's ``DeepSurvivalMachines``.

    The wrapped package is imported from a local ``./auton-survival-master``
    checkout at fit time.
    """
    # Passed to the wrapped model as ``k``.
    n_distr: int = 2
    # Passed to the wrapped model as ``distribution``.
    distr_kind: str = 'Weibull'
    # Passed to the wrapped model's ``fit`` as ``vsize``.
    validation_size: float = 0.1
    # Forwarded unchanged to the wrapped model's ``fit``; TODO: document its meaning.
    elbo: bool = False

    def _fit(self, X, y):
        """Create the wrapped model and train it on ``X`` and the outcome ``y``."""
        import sys
        # Register the local checkout only once, so that repeated fits
        # (e.g. during cross-validation) do not keep growing sys.path.
        checkout_dir = './auton-survival-master'
        if checkout_dir not in sys.path:
            sys.path.append(checkout_dir)
        # Aliased so the library class is not confused with this wrapper,
        # which carries the same name.
        from auton_survival.models.dsm import DeepSurvivalMachines as AutonDSM

        self.model_ = AutonDSM(
            k=self.n_distr,
            distribution=self.distr_kind,
            layers=self.layer_sizes,
        ).fit(
            X, get_time(y), get_indicator(y),
            learning_rate=self.learning_rate,
            vsize=self.validation_size,
            batch_size=self.batch_size,
            iters=self.epochs,  # this should be the maximum number of epochs
            elbo=self.elbo,
        )

    def _predict(self, X, eval_times):
        """Predict the probability of the event up to each of the given times.

        ``predict_risk`` is called once per time and the first column of each
        result is kept; the axes are then swapped so that the evaluation
        times run along the second axis. NaN or infinite values are passed
        through unchanged.
        """
        risk_per_time = [self.model_.predict_risk(X, t)[:, 0] for t in eval_times]
        return numpy.swapaxes(risk_per_time, 0, 1)

DeepSurvivalMachines().fit(X, y).predict(X)

 41%|████      | 4098/10000 [00:07<00:10, 565.83it/s]
100%|██████████| 10/10 [00:00<00:00, 23.90it/s]


array([0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22931426,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.23071089, 0.22734015, 0.22734015, 0.2273842 ,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22796632, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734015,
       0.22734015, 0.22734015, 0.22734015, 0.22734015, 0.22734

In [32]:
numpy.mean(cross_val_score(DeepSurvivalMachines(), X, y, scoring=make_scorer(concordance_index_score)))

 48%|████▊     | 4830/10000 [00:08<00:08, 602.43it/s]
100%|██████████| 10/10 [00:00<00:00, 21.46it/s]
 45%|████▌     | 4517/10000 [00:07<00:08, 612.28it/s]
 90%|█████████ | 9/10 [00:00<00:00, 15.84it/s]
 35%|███▌      | 3518/10000 [00:05<00:10, 603.20it/s]
 50%|█████     | 5/10 [00:00<00:00,  8.89it/s]
 38%|███▊      | 3845/10000 [00:06<00:10, 609.31it/s]
100%|██████████| 10/10 [00:00<00:00, 30.42it/s]
 33%|███▎      | 3293/10000 [00:05<00:11, 598.87it/s]
100%|██████████| 10/10 [00:00<00:00, 19.71it/s]


0.5631863832364458

# Pasticci

In [None]:
from sklearn.tree import DecisionTreeClassifier

In [None]:
cross_val_score(DecisionTreeClassifier(), X, get_indicator(y))

In [None]:
# ``scoring`` expects a scorer called as scorer(estimator, X, y), not a raw
# (y_true, y_pred) metric, so the metric is wrapped with make_scorer.
cross_val_score(CoxnetSurvivalAnalysis(), X, y, scoring=make_scorer(concordance_index_score))

In [None]:
import survwrap