## Survival Analysis Based on Gradient Boosting Machine (GBM)

This tutorial walks through a typical workflow for survival analysis (SA) with gradient boosted decision trees, using the R package `gbm`. It covers data preprocessing, model selection, and training & validation.

The steps are:
1. Data Preprocessing
  - convert variables
  - create the training and test sets
2. Model Selection
  - cross-validation
  - tune parameters
3. Training & Validation
  - train the gbm model
  - measure the CI (concordance index) on the test set
  - estimate survival rates at times of interest

The best reference for using `gbm` is its official documentation, **"Generalized Boosted Models: A guide to the gbm package"**. Here I summarize the points most relevant to SA, together with the `gbm` options that most users will need to change or tune.

### 1. Loss Function

The `distribution` argument selects the loss function, and therefore the type of problem being solved. For SA, choose `distribution = "coxph"`, whose loss is based on the Cox proportional hazards partial likelihood.
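
To make that loss concrete, here is a minimal plain-R sketch of the negative Cox log partial likelihood of predictions $f(x_i)$, namely $-\sum_{i:\delta_i=1}\big[f(x_i)-\log\sum_{j:t_j\ge t_i}e^{f(x_j)}\big]$, assuming no tied event times. `cox_neg_loglik` is a hypothetical helper for illustration, not a `gbm` function, and `gbm`'s reported `train.error` may be scaled differently.

```r
# Negative Cox log partial likelihood for predictions f (log relative hazards).
# Assumes no tied event times; cox_neg_loglik is a hypothetical helper.
cox_neg_loglik <- function(f, time, status) {
  loss <- 0
  for (i in which(status == 1)) {
    at.risk <- time >= time[i]                       # risk set R(t_i)
    loss <- loss - (f[i] - log(sum(exp(f[at.risk]))))
  }
  loss
}

# With f = 0 for everyone the risk sets have sizes 3, 2, 1
cox_neg_loglik(c(0, 0, 0), time = c(1, 2, 3), status = c(1, 1, 1))  # log(6) = 1.791759
```

Adding the same constant to every $f(x_i)$ cancels inside each term, so the Cox loss has no intercept: only the relative ordering and spacing of the predictions matter.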

### 2. Model Fitting

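`shrinkage` (the learning rate applied to each tree) and `n.trees` (the number of boosting iterations) have the largest effect on performance. They interact: the smaller the shrinkage, the more trees are needed, so they should be tuned together.

The author of `gbm` recommends small shrinkage values, such as 0.001, combined with a correspondingly large number of trees: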
>  My rule of thumb is to set shrinkage as small as possible while still being able to fit the model in a reasonable amount of time and storage. I usually aim for 3,000 to 10,000 iterations with shrinkage rates between 0.01 and 0.001.
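
A toy calculation shows why small shrinkage demands thousands of iterations. Assume, purely for illustration, that every tree fits the current residual exactly; each boosting step then removes a fraction `shrinkage` of it, so after $M$ steps the remaining residual is $(1-\text{shrinkage})^M$. Real trees do not fit residuals exactly, so actual counts differ, but the inverse scaling is the point. `iters_to_reach` is a hypothetical helper.

```r
# Iterations needed until only `remaining` of the initial residual is left,
# under the idealized assumption that each tree fits the residual exactly.
iters_to_reach <- function(shrinkage, remaining = 0.01) {
  ceiling(log(remaining) / log(1 - shrinkage))
}

sapply(c(0.1, 0.01, 0.001), iters_to_reach)  # 44 459 4603
```

Cutting the shrinkage by a factor of 10 multiplies the required iterations by roughly 10, which is consistent with pairing shrinkage of 0.01 to 0.001 with 3,000 to 10,000 iterations.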

### 3. The Optimal Number of Iterations

`gbm` offers three methods for estimating the optimal number of iterations after the model has been fit: an independent test set (`test`, which requires `train.fraction < 1`), out-of-bag estimation (`OOB`, which requires `bag.fraction < 1`), and v-fold cross-validation (`cv`, which requires `cv.folds > 1`). The function `gbm.perf` computes the estimate.

Of these methods for estimating `n.trees`, the author recommends **v-fold cross-validation**:
> My recommendation is to use 5- or 10-fold cross validation if you can afford the computing time. Otherwise you may choose among the other options, knowing that OOB is conservative.
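
A minimal sketch of comparing the CV and OOB estimates with `gbm.perf`. It assumes `survival::veteran` as a stand-in for the `randomForestSRC` copy used below (the `survival` copy stores `celltype` as a factor), and the values of `n.trees`, `shrinkage` and `interaction.depth` are illustrative, not tuned.

```r
library(survival)
library(gbm)
set.seed(0)

veteran <- survival::veteran   # stand-in dataset (assumption)

# cv.folds > 1 enables the "cv" estimate; bag.fraction < 1 enables "OOB"
fit <- gbm(Surv(time, status) ~ ., distribution = "coxph", data = veteran,
           n.trees = 2000, shrinkage = 0.005, interaction.depth = 2,
           bag.fraction = 0.5, cv.folds = 5, n.cores = 1)

best.cv  <- gbm.perf(fit, method = "cv",  plot.it = FALSE)
best.oob <- gbm.perf(fit, method = "OOB", plot.it = FALSE)
print(c(cv = best.cv, oob = best.oob))
```

The OOB call should print a warning that OOB tends to underestimate the optimal number of iterations, in line with the advice quoted above.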

### Step0 - Load Libraries and Data

In [1]:
library('survival')
library('gbm')
# set random state
set.seed(0)

"package 'gbm' was built under R version 3.5.1"
Loaded gbm 2.1.4


In [2]:
data(veteran, package = "randomForestSRC")
cat("Number of samples:", nrow(veteran), "\n")
cat("Columns of dataset:", colnames(veteran), "\n")
veteran[c(1:5), ]

Number of samples: 137 
Columns of dataset: trt celltype time status karno diagtime age prior 


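| trt | celltype | time | status | karno | diagtime | age | prior |
|-----|----------|------|--------|-------|----------|-----|-------|
| 1 | 1 | 72 | 1 | 60 | 7 | 69 | 0 |
| 1 | 1 | 411 | 1 | 70 | 5 | 64 | 10 |
| 1 | 1 | 228 | 1 | 60 | 3 | 38 | 0 |
| 1 | 1 | 126 | 1 | 60 | 9 | 63 | 10 |
| 1 | 1 | 118 | 1 | 70 | 11 | 65 | 10 |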


### Step1 - Data Preprocessing

In [3]:
# Sample the data and create a training subset.
train <- sample(1:nrow(veteran), round(nrow(veteran) * 0.80))
data_train <- veteran[train, ]
data_test <- veteran[-train, ]
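
The outline for Step 1 also mentions converting variables. In the `randomForestSRC` copy, `trt`, `celltype` and `prior` appear as numeric codes (see the rows printed above). Since `gbm` treats factor columns as categorical splits rather than ordered numeric thresholds, one might convert them first. A minimal sketch, assuming those three columns are the categorical ones, with a hypothetical helper `to_factors` applied to the five rows printed in Step 0:

```r
# Hypothetical helper: turn integer-coded columns into factors so that gbm
# splits on them as categories instead of numeric thresholds.
to_factors <- function(df, cols) {
  for (v in cols) df[[v]] <- factor(df[[v]])
  df
}

# The first five rows of the randomForestSRC veteran data, as printed above
head5 <- data.frame(trt = 1, celltype = 1, time = c(72, 411, 228, 126, 118),
                    status = 1, karno = c(60, 70, 60, 60, 70),
                    diagtime = c(7, 5, 3, 9, 11), age = c(69, 64, 38, 63, 65),
                    prior = c(0, 10, 0, 10, 10))
head5 <- to_factors(head5, c("trt", "celltype", "prior"))
str(head5)
```

On the full data this would be `veteran <- to_factors(veteran, c("trt", "celltype", "prior"))`, run before the split so that `data_train` and `data_test` share the same factor levels.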

### Step2 - Model Selection
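
A minimal sketch of this step, assuming a small grid over `shrinkage` and `interaction.depth`, each candidate scored by its 5-fold cross-validated loss (`fit$cv.error`), and `survival::veteran` as a stand-in for the data above. The grid values are illustrative only.

```r
library(survival)
library(gbm)
set.seed(0)

veteran <- survival::veteran   # stand-in dataset (assumption)
train <- sample(seq_len(nrow(veteran)), round(nrow(veteran) * 0.80))
data_train <- veteran[train, ]

# Candidate hyperparameters (illustrative values, not recommendations)
grid <- expand.grid(shrinkage = c(0.01, 0.005), interaction.depth = c(1, 2))
grid$best.iter <- NA_integer_
grid$cv.error  <- NA_real_

for (k in seq_len(nrow(grid))) {
  fit <- gbm(Surv(time, status) ~ ., distribution = "coxph", data = data_train,
             n.trees = 2000, shrinkage = grid$shrinkage[k],
             interaction.depth = grid$interaction.depth[k],
             cv.folds = 5, n.cores = 1)
  # cv.error holds the cross-validated loss after each iteration
  grid$best.iter[k] <- which.min(fit$cv.error)
  grid$cv.error[k]  <- min(fit$cv.error)
}

print(grid)
best <- grid[which.min(grid$cv.error), ]
print(best)
```

`best$shrinkage`, `best$interaction.depth` and `best$best.iter` (as `n.trees`) are the values one would then pass to `gbm()` in Step 3.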

### Step3 - Model Training & Evaluation

After tuning the hyperparameters in Step 2, we pass them to `gbm()` to train the final model, and then validate the fitted model on the test set.

This section covers:

- calculating the CI
- calculating the survival function at specified times
- saving the results to a file

#### 3.0 - Model Training & Prediction

In [5]:
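# fit the gbm model on the training set; n.trees and the other tuning
# parameters are left at their defaults here (100 trees), so substitute
# the values tuned in Step 2
model <- gbm(Surv(time, status) ~ .,
             distribution='coxph',
             data=data_train)
# training-set loss after each iteration (one value per tree)
print(model$train.error)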

  [1] 7.452676 7.411549 7.323226 7.283217 7.242676 7.213281 7.155838 7.141298
  [9] 7.117086 7.100367 7.079169 7.065628 7.046587 7.039139 7.035529 7.025156
 [17] 7.013004 6.998314 6.994975 6.980497 6.971504 6.964180 6.957446 6.955998
 [25] 6.943020 6.936833 6.931849 6.927928 6.918406 6.912547 6.909226 6.903453
 [33] 6.894153 6.899430 6.891607 6.885760 6.880229 6.875630 6.871242 6.867023
 [41] 6.859555 6.854401 6.852603 6.848217 6.843824 6.841187 6.837437 6.829515
 [49] 6.828534 6.825406 6.823117 6.821197 6.818626 6.818113 6.814419 6.812023
 [57] 6.807420 6.803212 6.800789 6.797621 6.798241 6.797777 6.796752 6.795991
 [65] 6.794354 6.790968 6.789161 6.788353 6.782757 6.778768 6.774355 6.768779
 [73] 6.766592 6.764387 6.762682 6.754721 6.753915 6.751372 6.746385 6.746444
 [81] 6.744177 6.741354 6.739005 6.736487 6.733526 6.735000 6.732901 6.728538
 [89] 6.726291 6.726717 6.728572 6.725986 6.725578 6.721763 6.726369 6.724885
 [97] 6.720713 6.720718 6.720603 6.716355


In [6]:
# predict() returns f(x) on the log-hazard scale, using the first n.trees trees
pred.train <- predict(model, data_train, n.trees = 100)
pred.test <- predict(model, data_test, n.trees = 100)

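#### 3.1 - CI (Concordance Index)
We can compute the $\text{CI}$ (concordance index) with the function `rcorr.cens` from the `Hmisc` package. `rcorr.cens` assumes that larger values of its first argument predict longer survival, while a larger $f(x)$ means a higher hazard, so the predictions are negated.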

In [8]:
Hmisc::rcorr.cens(-pred.train, Surv(data_train$time, data_train$status))

In [7]:
Hmisc::rcorr.cens(-pred.test, Surv(data_test$time, data_test$status))
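
To make the metric concrete, here is a plain-R sketch of Harrell's C. The helper `harrell_c` is hypothetical, not part of `Hmisc` or `gbm`. It takes the risk score directly, so no negation is needed.

```r
# Harrell's C: among comparable pairs (subject i has an observed event and a
# strictly shorter time than subject j), the fraction in which i received the
# higher risk score; equal risk scores count as 1/2.
harrell_c <- function(risk, time, status) {
  num <- 0
  den <- 0
  for (i in which(status == 1)) {
    for (j in seq_along(time)) {
      if (time[i] < time[j]) {
        den <- den + 1
        if (risk[i] > risk[j]) {
          num <- num + 1
        } else if (risk[i] == risk[j]) {
          num <- num + 0.5
        }
      }
    }
  }
  num / den
}

# Toy check: earlier events with higher risk give perfect concordance
harrell_c(c(3, 2, 1), time = c(1, 2, 3), status = c(1, 1, 0))  # 1
harrell_c(c(1, 2, 3), time = c(1, 2, 3), status = c(1, 1, 0))  # 0
```

Called as `harrell_c(pred.test, data_test$time, data_test$status)`, it should be close to the `C Index` reported by `rcorr.cens`, although the handling of tied times may differ slightly.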

#### 3.2 - Survival function

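`gbm` provides `basehaz.gbm` to estimate the cumulative baseline hazard $\Lambda_0(t)=\int_0^{t}\lambda_0(z)\,dz$. Under the proportional hazards model, the survival function of an individual with covariates $X$ is

$$
S(t\mid X)=\exp\left\{-e^{f(X)}\int_0^{t}\lambda_0(z)\,dz\right\}=\exp\left\{-e^{f(X)}\Lambda_0(t)\right\}
$$

where $f(X)$ is the `gbm` prediction, i.e. the log hazard ratio relative to the baseline.
So the survival function of each individual follows directly from $f(X)$ and the estimated $\Lambda_0(t)$.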
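
The `gbm` help page describes `basehaz.gbm` as a Breslow estimator. As a plain-R sketch of that estimator, $\hat\Lambda_0(t)=\sum_{t_k\le t} d_k \big/ \sum_{j:t_j\ge t_k} e^{f(x_j)}$ with $d_k$ events at time $t_k$, followed by the survival formula above (helper names are hypothetical; the optional smoothing in `basehaz.gbm` is not reproduced):

```r
# Breslow estimate of the cumulative baseline hazard Lambda0(t) at t.eval
breslow_cumhaz <- function(time, status, f, t.eval) {
  event.times <- sort(unique(time[status == 1]))
  # increment at each event time s: events at s / sum of exp(f) over the risk set
  incr <- vapply(event.times, function(s) {
    sum(time == s & status == 1) / sum(exp(f[time >= s]))
  }, numeric(1))
  cum <- c(0, cumsum(incr))
  # step function: findInterval counts the event times <= t
  cum[findInterval(t.eval, event.times) + 1]
}

# S(t | X) = exp(-exp(f(X)) * Lambda0(t))
surv_prob <- function(f.new, cumhaz) exp(-exp(f.new) * cumhaz)

# Toy data: three uncensored subjects, all with f = 0
H <- breslow_cumhaz(time = c(1, 2, 3), status = c(1, 1, 1), f = c(0, 0, 0),
                    t.eval = c(0.5, 1, 2.5, 3))
H                 # 0, 1/3, 5/6, 11/6
surv_prob(0, H)   # exp(-H)
```

With `f = pred.train` and `t.eval = time.interest`, this mirrors the `basehaz.gbm(..., cumulative = TRUE)` call that follows, though the two may not match exactly.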

In [9]:
# Times of interest: the distinct event times in the training set
time.interest <- sort(unique(data_train$time[data_train$status==1]))

In [10]:
# Estimate the cumulative baseline hazard function using training data
basehaz.cum <- basehaz.gbm(data_train$time, data_train$status, pred.train, t.eval = time.interest, cumulative = TRUE)

For an individual $i$ in the test set (here the first one), the estimated survival function at the times of interest is:

In [11]:
surf.i <- exp(-exp(pred.test[1])*basehaz.cum)

In [12]:
print(surf.i)

 [1] 9.944968e-01 9.917216e-01 9.888655e-01 9.801236e-01 9.677340e-01
 [6] 9.613012e-01 9.579700e-01 9.546404e-01 9.479611e-01 9.444907e-01
[11] 9.409351e-01 9.300652e-01 9.259765e-01 9.176874e-01 9.131956e-01
[16] 9.040865e-01 8.945661e-01 8.894424e-01 8.842690e-01 8.733314e-01
[21] 8.623517e-01 8.568161e-01 8.511813e-01 8.453023e-01 8.393870e-01
[26] 8.334773e-01 8.275762e-01 8.214120e-01 8.146979e-01 7.933158e-01
[31] 7.776145e-01 7.694756e-01 7.532388e-01 7.449990e-01 7.367171e-01
[36] 7.277190e-01 7.186409e-01 7.091510e-01 6.899998e-01 6.792196e-01
[41] 6.680370e-01 6.567616e-01 6.451084e-01 6.333093e-01 6.083990e-01
[46] 5.953123e-01 5.815242e-01 5.676897e-01 5.540396e-01 5.273443e-01
[51] 5.117445e-01 4.962266e-01 4.807583e-01 4.652502e-01 4.496305e-01
[56] 4.339541e-01 4.168827e-01 3.988288e-01 3.802071e-01 3.616559e-01
[61] 3.429259e-01 3.226654e-01 3.020986e-01 2.815642e-01 2.610981e-01
[66] 2.394407e-01 2.176524e-01 1.962988e-01 1.740657e-01 1.517649e-01
[71] 1.314447e-01 1.

The estimated survival rate of every test individual at one specified time (here the 10th time of interest) is:

In [13]:
specif.time <- time.interest[10]
cat("Survival Rate of all at time", specif.time, "\n")
surv.rate <- exp(-exp(pred.test)*basehaz.cum[10])
print(surv.rate)

Survival Rate of all at time 15 
 [1] 0.9444907 0.9675000 0.8151207 0.8625493 0.9496626 0.9587409 0.6214935
 [8] 0.9418253 0.9042441 0.9435269 0.9338808 0.7847206 0.9073920 0.9709306
[15] 0.9841972 0.9329900 0.4770597 0.5935586 0.9068123 0.9495933 0.9199242
[22] 0.9396928 0.5367361 0.8560934 0.9410605 0.9146667 0.9396928
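
Both results above are slices of a single matrix of survival probabilities, one row per individual and one column per time of interest. A minimal sketch with `outer`, using hypothetical stand-in values for `pred.test` and `basehaz.cum`:

```r
# Hypothetical stand-ins for pred.test (log relative hazards f(X)) and
# basehaz.cum (cumulative baseline hazard at the times of interest)
f.new  <- c(-0.5, 0, 0.8)
cumhaz <- c(0.05, 0.12, 0.30, 0.55)

# surv.mat[i, k] = exp(-exp(f_i) * Lambda0(t_k))
surv.mat <- exp(-outer(exp(f.new), cumhaz))
dim(surv.mat)   # 3 4
surv.mat[, 2]   # every individual at the 2nd time, like surv.rate
surv.mat[1, ]   # one individual over all times, like surf.i
```

With the real objects, `exp(-outer(exp(pred.test), basehaz.cum))` would give `surf.i` as its first row and `surv.rate` as its 10th column.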


#### 3.3 - Saving as file

Here we append the predictions and survival rates to the test data and write the result to a CSV file.

In [14]:
res_test <- data_test
# predicted outcome for test set
res_test$pred <- pred.test
res_test$survival_rate <- surv.rate
# Save data
write.csv(res_test, file = "result_gbm.csv")