In [1]:
import numpy

## Mini example for scikit-survival estimators

Load standard dataset for testing

In [2]:
def load_toy_data():
    import sksurv.datasets
    from sklearn.preprocessing import OneHotEncoder
    
    X, y = sksurv.datasets.load_breast_cancer()
    X = numpy.concatenate([
        X.select_dtypes('float'), 
        OneHotEncoder(sparse=False).fit_transform(X.select_dtypes('category'))
        #OneHotEncoder(sparse_output=False).fit_transform(X.select_dtypes('category'))
        ], axis=1)
    return X, y

In [3]:
X, y = load_toy_data()
X.shape

(198, 84)

Fit penalized Cox model (from scikit-survival)

In [4]:
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.metrics import concordance_index_censored

In [5]:
model = CoxnetSurvivalAnalysis()
model.fit(X, y)
pred = model.predict(X)

Standard functions from scikit-learn can be used with scikit-survival models

In [6]:
from sklearn.model_selection import cross_val_score

In [7]:
cross_val_score(CoxnetSurvivalAnalysis(), X, y)

array([0.61711712, 0.53243243, 0.58865248, 0.60504202, 0.75238095])

## Scikit-learn compatibility
Scikit-learn has a checker for estimators to see if they conform to their specification. 

In [8]:
from sklearn.utils.estimator_checks import check_estimator

Scikit-survival models do not necessarily pass ;-)

In [9]:
try:
    check_estimator(CoxnetSurvivalAnalysis())
except Exception as e:
    print(e.__class__.__name__)
    print(e)

AssertionError
Estimator CoxnetSurvivalAnalysis should not set any attribute apart from parameters during init. Found attributes ['_baseline_models'].


This is an example of a minimal (empty) estimator that passes the tests. Dataclasses can be useful to avoid long __init__ functions and it appears to work. BaseEstimator include the get/set_params methods that are required. check_X_y and check_array implement checks (required by the check_estimator function) on the input data.

In [10]:
from dataclasses import dataclass
from sklearn.base import BaseEstimator
from sklearn.utils import check_X_y, check_array

@dataclass
class TestEstimator(BaseEstimator):
    param1: int = 1,

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self._validate_data(X, y)
        return self

    def predict(self, X):
        X = check_array(X)
        return numpy.full(shape=X.shape[0], fill_value=self.param1)

check_estimator(TestEstimator(param1=33))

## Outcome format

Esempio formato scikit-survival

In [11]:
y[:10]

array([( True,  723.), (False, 6591.), ( True,  524.), (False, 6255.),
       ( True, 3822.), (False, 6507.), (False, 5947.), (False, 5816.),
       (False, 6007.), ( True, 1233.)],
      dtype=[('e.tdm', '?'), ('t.tdm', '<f8')])

In [12]:
def get_indicator(y):
    return y[y.dtype.names[0]]
def get_time(y):
    return y[y.dtype.names[1]]

In [13]:
get_indicator(y[:10])

array([ True, False,  True, False,  True, False, False, False, False,
        True])

## Wrapper for the DeepHit method from the pycox module

Molto preliminare, sto recuperando e adattando il codice dal notebook ALS

In [14]:
import pycox

In [15]:
from dataclasses import dataclass
from sklearn.base import BaseEstimator

class DeepHitPycox(BaseEstimator):
    num_durations: int = 10,
    # qui mettiamo i parametri per la forma della rete, cercherei di fare qualcosa che rispetti il paper originale
    layer_sizes: list = [10, 10],
    epochs: int = 10 # maybe implement also early stopping
    batch_size: int = 16

# Pasticci

In [16]:
from sklearn.tree import DecisionTreeClassifier

In [17]:
cross_val_score(DecisionTreeClassifier(), X, get_indicator(y))

array([0.675     , 0.575     , 0.625     , 0.69230769, 0.66666667])

In [18]:
class DeepSurvivalMachines:
    def fit(self, X, y):
        import sys
        sys.path.append('./auton-survival')
        from auton_survival.models.dsm import DeepSurvivalMachines
        try:
            self.model_ = DeepSurvivalMachines(
                **self.model_params['mod']).fit(
                X, times, events, **self.model_params['fit'])
        except RuntimeError as e:
            raise FailedModel(f'{self.short_name} model fit failed: {e}')
        return self

    def predict(self, X, eval_times):
        """predict probabilites of event up to given times for each event"""
        #global dbg
        #dbg = self.model_, X, eval_times, numpy.swapaxes([self.model_.predict_risk(X, t)[:, 0] for t in eval_times], 0, 1)
        return numpy.swapaxes([self.model_.predict_risk(X, t)[:, 0] for t in eval_times], 0, 1)
        #return numpy.nan_to_num(numpy.swapaxes([self.model_.predict_risk(X, t)[:, 0] for t in eval_times], 0, 1), nan=0.5, posinf=1, neginf=0)

DeepSurvivalMachines().fit(X, y).predict(X)

ModuleNotFoundError: No module named 'auton_survival'

In [None]:
def concordance_index_score(y_true, y_pred, *args, **kwargs):
    from sksurv.metrics import concordance_index_censored
    return concordance_index_censored(
        event_indicator=get_indicator(y_true),
        event_time=get_time(y_true),
        estimate=y_pred,
        *args, **kwargs,
    )


In [None]:
cross_val_score(CoxnetSurvivalAnalysis(), X, y, scoring=concordance_index_score)

Traceback (most recent call last):
  File "/home/gbirolo/survwrap/venv/lib/python3.10/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/tmp/ipykernel_770125/2810917926.py", line 4, in concordance_index_score
    event_indicator=get_indicator(y_true),
  File "/tmp/ipykernel_770125/2231177257.py", line 2, in get_indicator
    return y[y.dtype.names[0]]
AttributeError: 'CoxnetSurvivalAnalysis' object has no attribute 'dtype'

Traceback (most recent call last):
  File "/home/gbirolo/survwrap/venv/lib/python3.10/site-packages/sklearn/metrics/_scorer.py", line 117, in __call__
    score = scorer(estimator, *args, **kwargs)
  File "/tmp/ipykernel_770125/2810917926.py", line 4, in concordance_index_score
    event_indicator=get_indicator(y_true),
  File "/tmp/ipykernel_770125/2231177257.py", line 2, in get_indicator
    return y[y.dtype.names[0]]
AttributeError: 'CoxnetSurvivalAnalysis' object has no attribute 'dtype'

Trac

array([nan, nan, nan, nan, nan])

In [None]:
import survwrap