# Using the Brier score to evaluate the predictive ability of a time-to-event model

References:
- [scikit-survival docs](https://scikit-survival.readthedocs.io/en/stable/user_guide/evaluating-survival-models.html#Time-dependent-Brier-Score)
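
For orientation, the time-dependent Brier score under right censoring is usually written roughly as below. This is a sketch of the standard inverse-probability-of-censoring-weighted form, not a quotation from this notebook. The symbols are assumptions introduced here: $\hat{\pi}(t \mid \mathbf{x}_i)$ is the predicted probability that case $i$ is still event-free at time $t$, $y_i$ is the observed time, $\delta_i$ is the event indicator, and $\hat{G}$ is the Kaplan-Meier estimate of the censoring survival function.

$$
\mathrm{BS}^c(t) = \frac{1}{n} \sum_{i=1}^{n} \left[ \frac{\left(0 - \hat{\pi}(t \mid \mathbf{x}_i)\right)^2}{\hat{G}(y_i)} \, I(y_i \le t \wedge \delta_i = 1) + \frac{\left(1 - \hat{\pi}(t \mid \mathbf{x}_i)\right)^2}{\hat{G}(t)} \, I(y_i > t) \right]
$$

With no censoring, $\hat{G} \equiv 1$ and the expression reduces to the mean squared difference between the indicator $I(y_i > t)$ and the predicted survival probability.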


Todo:
- Models:
  - M1: null model that takes in X and returns 0.5 for every case
  - M2: perfect model that takes in X and passes it to the known Weibull survival
      function to get true survival probabilities (or 1-surv_prob if we want a risk
      score)
  - M3: intermediate model that passes X to a Weibull survival function with similar,
      but not exact, parameters
  - M4: Kaplan-Meier (KM) model "learned" from the training data
- Set up functions for generating the data and splitting it into train/test sets
- Evaluate models M1-M4 on test data
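
The plan above could be sketched as follows. This is a minimal illustration under stated assumptions, not the notebook's eventual implementation. The Weibull parameters, the single covariate, the evaluation time `t_eval`, and every function name are hypothetical. The simulated data is uncensored, so the Brier score needs no censoring weights and M4's Kaplan-Meier estimate reduces to the empirical survival function.

```python
import numpy as np

# Hypothetical "true" data-generating parameters; the notebook does not specify any.
TRUE_SHAPE, TRUE_SCALE, TRUE_BETA = 1.5, 10.0, 0.5


def weibull_survival(t, x, shape, scale, beta):
    """S(t | x) = exp(-(t / (scale * exp(beta * x))) ** shape)."""
    return np.exp(-((t / (scale * np.exp(beta * x))) ** shape))


def generate_data(n, seed):
    """Draw one covariate per case and an uncensored Weibull event time."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)
    # Scaling a standard Weibull draw gives a Weibull with that scale.
    times = TRUE_SCALE * np.exp(TRUE_BETA * x) * rng.weibull(TRUE_SHAPE, size=n)
    return x, times


def m1_null(x, t):
    """M1: return 0.5 for every case."""
    return np.full(len(x), 0.5)


def m2_perfect(x, t):
    """M2: true survival probabilities from the known Weibull model."""
    return weibull_survival(t, x, TRUE_SHAPE, TRUE_SCALE, TRUE_BETA)


def m3_intermediate(x, t):
    """M3: Weibull model with similar, but not exact, parameters."""
    return weibull_survival(t, x, shape=1.4, scale=11.0, beta=0.4)


def make_m4_km(train_times):
    """M4: without censoring, KM is the fraction of training cases surviving past t."""
    def m4_km(x, t):
        return np.full(len(x), np.mean(train_times > t))
    return m4_km


def brier_score_uncensored(times, surv_prob, t):
    """Mean squared difference between 1{T > t} and the predicted S(t | x)."""
    return float(np.mean(((times > t).astype(float) - surv_prob) ** 2))


# Separate seeds stand in for a train/test split of one simulated population.
x_train, t_train = generate_data(2000, seed=0)
x_test, t_test = generate_data(2000, seed=1)

t_eval = 5.0  # arbitrary evaluation time
models = {
    "M1": m1_null,
    "M2": m2_perfect,
    "M3": m3_intermediate,
    "M4": make_m4_km(t_train),
}
scores = {
    name: brier_score_uncensored(t_test, model(x_test, t_eval), t_eval)
    for name, model in models.items()
}
for name, score in scores.items():
    print(f"{name}: {score:.4f}")
```

Lower scores are better. M1 always scores exactly 0.25, because a constant prediction of 0.5 is off by 0.5 for every case, which makes it a convenient reference point. With censored data, the censoring-weighted estimator described in the scikit-survival docs would replace `brier_score_uncensored`.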


In [1]:
!pip install scikit-survival

In [2]:
from sksurv.datasets import load_gbsg2

In [3]:
load_gbsg2()

(      age  estrec horTh menostat  pnodes  progrec tgrade  tsize
 0    70.0    66.0    no     Post     3.0     48.0     II   21.0
 1    56.0    77.0   yes     Post     7.0     61.0     II   12.0
 2    58.0   271.0   yes     Post     9.0     52.0     II   35.0
 3    59.0    29.0   yes     Post     4.0     60.0     II   17.0
 4    73.0    65.0    no     Post     1.0     26.0     II   35.0
 ..    ...     ...   ...      ...     ...      ...    ...    ...
 681  49.0    84.0    no      Pre     3.0      1.0    III   30.0
 682  53.0     0.0   yes     Post    17.0      0.0    III   25.0
 683  51.0     0.0    no      Pre     5.0     43.0    III   25.0
 684  52.0    34.0    no     Post     3.0     15.0     II   23.0
 685  55.0    15.0    no     Post     9.0    116.0     II   23.0
 
 [686 rows x 8 columns],
 array([( True, 1814.), ( True, 2018.), ( True,  712.), ( True, 1807.),
        ( True,  772.), ( True,  448.), (False, 2172.), (False, 2161.),
        ( True,  471.), (False, 2014.), ( True,  