In [1]:
import survwrap as tosa
import numpy as np

In [2]:
test_X, test_y = tosa.load_test_data()

In [3]:
test_X.shape, test_y.shape

((198, 84), (198,))

## Wrapper per DSM

In [60]:
from dataclasses import dataclass, field
from collections.abc import Sequence # abc: Abstract Base Class
from survwrap import SurvivalEstimator
import numpy as np
from sklearn.utils import check_X_y, check_array
from sksurv.metrics import concordance_index_censored

#from tosa import SurvivalEstimator

@dataclass
class DeepSurvivalMachines(SurvivalEstimator):
    import numpy as np
    
    n_distr: int = 2
    distr_kind: str = 'Weibull'
    batch_size: int = 10
    layer_sizes: Sequence[int] = field(default_factory= lambda: [10,10]) #haskellismo da farsi spiegare
    learning_rate: float = 0.001
    validation_size: float = 0.1
    max_epochs: int = 10
    elbo: bool = False # what is this?
    
    def fit(self, X, y):
        from auton_survival.models.dsm import DeepSurvivalMachines

        check_X_y(X,y)
        self.median_time_ = np.median(X)
        self.model_ = DeepSurvivalMachines(
            k=self.n_distr,
            distribution=self.distr_kind,
            layers=self.layer_sizes,
        ).fit(
            X, tosa.get_time(y), tosa.get_indicator(y),
            learning_rate=self.learning_rate,
            vsize = self.validation_size,
            batch_size = self.batch_size,
            iters = self.max_epochs, # this should be the maximum number of epochs
            elbo = self.elbo,
        )
        return self
        

    def predict(self, X, eval_times=None):
        """predict probabilites of event up to given times for each event"""
        #global dbg
        #dbg = self.model_, X, eval_times, numpy.swapaxes([self.model_.predict_risk(X, t)[:, 0] for t in eval_times], 0, 1)
        check_array(X)
        if eval_times is None: self.single_event = True 
        eval_times = [self.median_time_] if self.single_event else eval_times 
        #print(self.median_time_, eval_times)
        _preds = np.swapaxes([self.model_.predict_risk(X, t)[:, 0] for t in eval_times], 0, 1)
        return _preds.flatten() if self.single_event else _preds
    #return numpy.nan_to_num(numpy.swapaxes([self.model_.predict_risk(X, t)[:, 0] fo
        #return numpy.nan_to_num(numpy.swapaxes([self.model_.predict_risk(X, t)[:, 0] for t in eval_times], 0, 1), nan=0.5, posinf=1, neginf=0)

    def harrell_score(self, y_true, y_pred, *args, **kwargs):
        "return Harrell's C-index for a prediction"

        return concordance_index_censored(
            event_indicator=y_true[y_true.dtype.names[0]],
            event_time=tosa.get_time(y_true),
            estimate=y_pred,
            *args,
            **kwargs,
        )

    def score(self, X, y): 
        "return the Harrell's c-index as a sklearn score"
        X, y = check_X_y(X, y)
        return self.harrell_score(y, self.predict(X))[0]

    
        

In [77]:
dsm_model = DeepSurvivalMachines(rng_seed=2307, max_epochs=20, layer_sizes=[10,10,10]).fit(test_X, test_y)
dsm_pred = dsm_model.predict(test_X)

 41%|████████████████████▉                              | 4098/10000 [00:03<00:04, 1194.35it/s]
 15%|████████▊                                                  | 3/20 [00:00<00:01, 10.25it/s]


In [78]:
dsm_pred.shape, dsm_pred[:10]

((198,),
 array([0.00294256, 0.0027127 , 0.00260665, 0.00255611, 0.00264647,
        0.00257948, 0.00298992, 0.00257948, 0.00257948, 0.00262632]))

In [79]:
dsm_model.harrell_score(test_y, dsm_pred), dsm_model.score(test_X, test_y)

((0.6256785759373816, 4273, 2282, 1366, 0), 0.6256785759373816)

## Wrapped DSM CV test

In [43]:
from sklearn.model_selection import cross_val_score

In [80]:
cv_score = cross_val_score(dsm_model, test_X, test_y, cv=5)

 48%|████████████████████████▋                          | 4830/10000 [00:03<00:04, 1209.53it/s]
100%|██████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 37.28it/s]
 45%|███████████████████████                            | 4517/10000 [00:03<00:04, 1221.31it/s]
 20%|███████████▊                                               | 4/20 [00:00<00:01, 13.18it/s]
 35%|█████████████████▉                                 | 3518/10000 [00:02<00:05, 1219.41it/s]
100%|██████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 36.17it/s]
 38%|███████████████████▌                               | 3845/10000 [00:03<00:05, 1222.46it/s]
 75%|███████████████████████████████████████████▌              | 15/20 [00:00<00:00, 24.58it/s]
 33%|████████████████▊                                  | 3293/10000 [00:02<00:05, 1223.55it/s]
 15%|████████▊                                                  | 3/20 [00:00<00:01, 11.35it/s]


In [81]:
print(cv_score.mean().round(3), " ± ", cv_score.std().round(3))
cv_score

0.549  ±  0.124


array([0.60135135, 0.68108108, 0.56914894, 0.57983193, 0.31428571])

In [None]:
# plussminus : unicode 00b1