Add Accelerated Failure Time loss for survival analysis task #4763

Open
wants to merge 49 commits into
base: master
from


avinashbarnwal commented Aug 12, 2019

Hi,

This PR adds the Accelerated Failure Time (AFT) loss for survival modeling.

Survival analysis is a form of "censored regression" in which the goal is to learn a time-to-event function. It is similar to ordinary regression analysis, except that in ordinary regression all data points are uncensored. Time-to-event modeling is important for understanding the behavior of users and companies, with applications that include (but are not limited to) credit, cancer, and attrition risk.

Supports

  • 4 kinds of datasets - Left, Right, Interval Censored and Uncensored.
  • Normal, Logistic and Extreme Distributions for underlying error distribution.
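
The four label types listed above can each be written as a pair of lower and upper bounds, which is what the `label_lower_bound` and `label_upper_bound` fields of the DMatrix carry. A minimal sketch, assuming the convention later documented for XGBoost's AFT objective (an infinite upper bound for right-censored labels, a lower bound of 0 for left-censored labels); the helper `aft_bounds` is hypothetical and not part of this PR:

```python
import math

def aft_bounds(kind, a=None, b=None):
    """Return (lower, upper) label bounds for one observation.

    Hypothetical helper for illustration only. `a` is the observed or
    lower time, `b` is the upper time.
    """
    if kind == "uncensored":
        return (a, a)            # event time known exactly
    if kind == "right":
        return (a, math.inf)     # event happened some time after a
    if kind == "left":
        return (0.0, b)          # event happened some time before b
    if kind == "interval":
        return (a, b)            # event happened between a and b
    raise ValueError(f"unknown censoring kind: {kind}")

# One observation of each kind, collected into the two bound vectors.
rows = [
    aft_bounds("uncensored", a=2.0),
    aft_bounds("right", a=5.0),
    aft_bounds("left", b=3.0),
    aft_bounds("interval", a=1.0, b=4.0),
]
y_lower = [lo for lo, _ in rows]
y_higher = [hi for _, hi in rows]
```

The resulting `y_lower` and `y_higher` lists would then be passed to `set_float_info`, presumably after conversion to float arrays.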

This project is part of the Google Summer of Code - 2019. AFT-Xgboost

Relevant Documents

Example in Python to run -


import xgboost

# X, y_lower, y_higher (training) and X_val, y_lower_val, y_higher_val
# (validation) are assumed to be defined already.

# Parameters must be defined before the call to xgboost.train().
params = {'learning_rate': 0.1, 'aft_noise_distribution': 'normal', 'aft_sigma': 1.0,
          'eval_metric': 'aft-nloglik@normal,1.0', 'objective': 'aft:survival'}

# Censored labels are passed as a pair of lower and upper bounds.
dtrain = xgboost.DMatrix(X)
dtrain.set_float_info("label_lower_bound", y_lower)
dtrain.set_float_info("label_upper_bound", y_higher)

dtest = xgboost.DMatrix(X_val)
dtest.set_float_info("label_lower_bound", y_lower_val)
dtest.set_float_info("label_upper_bound", y_higher_val)

res = {}  # filled with the per-round evaluation results
bst = xgboost.train(params, dtrain, num_boost_round=100,
                    evals=[(dtrain, "train"), (dtest, "test")], evals_result=res)

For more details - avinashbarnwal#1.

src/common/survival_util.cc Outdated Show resolved Hide resolved
@hcho3 hcho3 changed the title Survival analysis1 Add Accelerated Failure Time loss for survival analysis task Aug 12, 2019
@hcho3 hcho3 self-assigned this Aug 12, 2019
@hcho3 hcho3 requested review from trivialfis and hcho3 and removed request for trivialfis Aug 12, 2019
src/common/survival_util.cc Outdated Show resolved Hide resolved
Collaborator

hcho3 commented Aug 12, 2019

@avinashbarnwal In the PR description, can you add a short one-paragraph description of what survival analysis is? Something like:

survival analysis is a new kind of learning task where we would like to predict a time to certain event. The time-to-event labels are often censored, i.e. we only know which intervals the label falls in and do not know its exact value. See https://eng.uber.com/modeling-censored-time-to-event-data-using-pyro/ for a real-world example.

src/common/survival_util.cc Outdated Show resolved Hide resolved

tdhock commented Aug 12, 2019

survival analysis is a new kind of learning task where we would like to predict a time to certain event

I would describe it as "censored regression" or more specifically "regression with censored outputs" because the goal is still to learn a (real-valued) regression function; this emphasizes the similarity with usual regression, where all outputs are un-censored.

Collaborator

hcho3 commented Aug 12, 2019

Also add a short Python example to the description:

dtrain = xgboost.DMatrix(X)
dtrain.set_float_info("label_lower_bound", y_lower)
dtrain.set_float_info("label_upper_bound", y_higher)
    
dtest = xgboost.DMatrix(X_test)
dtest.set_float_info("label_lower_bound", y_lower_test)
dtest.set_float_info("label_upper_bound", y_higher_test)
    
bst = xgboost.train(params, dtrain, num_boost_round=100,
                    evals=[(dtrain,"train"), (dtest,"test")])
@hcho3 hcho3 added the status: WIP label Aug 12, 2019
Collaborator

hcho3 commented Aug 12, 2019

@tdhock Thanks for your suggestion. Yes, "censored regression" sounds reasonable.

src/common/survival_util.cc Outdated Show resolved Hide resolved
Member

trivialfis commented Aug 14, 2019

I'm not familiar with survival models and have only skimmed the survey. Are there other recommended materials that concentrate on the theory? ;-)

Author

avinashbarnwal commented Aug 14, 2019

Hi @trivialfis,

Here are some good lecture notes for learning survival modeling: https://www4.stat.ncsu.edu/~dzhang2/st745/index.html.

One of the motivating books: https://www.amazon.com/Applied-Survival-Analysis-Time-Event/dp/0471754994.

Prof. @tdhock and @hcho3 may be able to suggest better references for the theory of survival modeling.


tdhock commented Aug 14, 2019

would be good if @avinashbarnwal could write a latex/PDF vignette in the xgboost R pkg describing the loss functions that he implemented


tdhock commented Aug 14, 2019

they are the same as in R's survival::survreg, there are some docs on that man page, but the math formulas come from http://members.cbio.mines-paristech.fr/~thocking/survival.pdf
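
For readers without the linked PDF at hand, here is a rough sketch of the kind of loss being discussed, written for the normal noise distribution only. It assumes the standard AFT model ln(Y) = y_pred + sigma * Z with Z standard normal. The function names are hypothetical, and the notation may differ from both the PDF and the C++ code in this PR:

```python
import math

def norm_pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def aft_nloglik_normal(y_lower, y_higher, y_pred, sigma=1.0):
    """Negative log-likelihood of one label interval [y_lower, y_higher].

    y_pred is the predicted log survival time; sigma scales the noise.
    """
    if y_lower == y_higher:
        # Uncensored: density of Y at the observed time. The 1/(sigma*y)
        # factor comes from the change of variables z = (ln y - y_pred)/sigma.
        z = (math.log(y_lower) - y_pred) / sigma
        return -math.log(norm_pdf(z) / (sigma * y_lower))
    # Censored: probability that Y falls inside the interval.
    if math.isinf(y_higher):
        cdf_upper = 1.0          # right-censored, no upper bound
    else:
        cdf_upper = norm_cdf((math.log(y_higher) - y_pred) / sigma)
    if y_lower <= 0.0:
        cdf_lower = 0.0          # left-censored, no lower bound
    else:
        cdf_lower = norm_cdf((math.log(y_lower) - y_pred) / sigma)
    return -math.log(cdf_upper - cdf_lower)

# Right-censored at t = 1 with y_pred = 0: half the mass lies above ln(1) = 0.
loss_right = aft_nloglik_normal(1.0, math.inf, 0.0)
```

Swapping `norm_pdf` and `norm_cdf` for the logistic or extreme-value density and CDF would give the other two distributions mentioned in the PR description.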

Member

trivialfis commented Aug 14, 2019

@avinashbarnwal @tdhock Thanks for the good references. Will try to catch up.

Author

avinashbarnwal commented Aug 14, 2019

Hi Prof. @tdhock and @hcho3,

I will start writing up the loss functions in a LaTeX/PDF vignette for the xgboost R pkg.

Author

avinashbarnwal commented Aug 15, 2019

Hi Prof. @tdhock,

Please let me know if it is fine to make the vignette like this one: https://cran.r-project.org/web/packages/xgboost/vignettes/xgboostfromJSON.html.


tdhock commented Aug 15, 2019

typically for vignettes with lots of math I prefer writing Rnw source which is rendered to tex / pdf. It is possible to include simple math in Rmd which is rendered on a web page using mathjax, but in my experience complex equations (e.g. optimization problems) do not render well on web pages.

Examples of both are here: https://github.com/tdhock/PeakSegDisk/tree/master/vignettes

include/xgboost/data.h Outdated Show resolved Hide resolved
Author

avinashbarnwal commented Aug 24, 2019

Hi Prof. @tdhock and @hcho3,

Please find R-vignette below and let me know your thoughts.
http://rpubs.com/avinashbarnwal123/aft

Collaborator

hcho3 commented Aug 26, 2019

@tdhock Do the datasets follow log-normal AFT distribution? The errors are not decreasing when we choose log-logistic and log-weibull. See http://rpubs.com/avinashbarnwal123/aft


tdhock commented Aug 26, 2019

Author

avinashbarnwal commented Aug 26, 2019

Hi Prof. @tdhock and @hcho3,

I have updated the vignette: http://rpubs.com/avinashbarnwal123/aft. It works for the last dataset, H3K36me3_AM_immune; please check the last fold. This might not be clear because of the scale. It works for both the Logistic and Extreme distributions. I think we need more datasets like that one, where it works.

avinashbarnwal and others added 9 commits Aug 13, 2019
@hcho3 hcho3 force-pushed the avinashbarnwal:survival_analysis1 branch from 5b6d707 to cf272f5 Sep 19, 2019
Collaborator

hcho3 commented Sep 19, 2019

Rebased against the latest master. I'll address @RAMitchell's remark soon, by putting y_lower and y_upper in a std::map<std::string, HostDeviceVector>

hcho3 added 2 commits Sep 19, 2019
@hcho3 hcho3 changed the title Add Accelerated Failure Time loss for survival analysis task [WIP] Add Accelerated Failure Time loss for survival analysis task Sep 19, 2019
hcho3 added 3 commits Sep 19, 2019
Member

trivialfis left a comment

Looking forward to it. Still haven't got the time to learn the materials. I may need to work harder ...

}

double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double sigma) {
double z;


trivialfis Sep 25, 2019

Member

Please add a summary comment explaining these symbols, with links to (or the names of) the corresponding papers, books, or notes ...
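
As an illustration of the kind of summary being asked for, here is a sketch for the simplest case only: an uncensored label with normal noise. It assumes z = (ln(y) - y_pred) / sigma, in which case the gradient and Hessian of the negative log-likelihood with respect to y_pred are -z / sigma and 1 / sigma^2. The function names are hypothetical and do not mirror the C++ signatures in this PR; the analytic values are checked against finite differences:

```python
import math

def uncensored_normal_loss(y, y_pred, sigma):
    """Negative log-likelihood for an uncensored label under normal noise."""
    z = (math.log(y) - y_pred) / sigma
    return 0.5 * z * z + math.log(sigma * y * math.sqrt(2.0 * math.pi))

def uncensored_normal_grad_hess(y, y_pred, sigma):
    """Analytic first and second derivatives of the loss w.r.t. y_pred."""
    z = (math.log(y) - y_pred) / sigma
    return -z / sigma, 1.0 / (sigma * sigma)

def finite_difference(f, x, eps=1e-4):
    """Central-difference estimates of f'(x) and f''(x)."""
    grad = (f(x + eps) - f(x - eps)) / (2.0 * eps)
    hess = (f(x + eps) - 2.0 * f(x) + f(x - eps)) / (eps * eps)
    return grad, hess

y, y_pred, sigma = 3.0, 0.5, 2.0
analytic = uncensored_normal_grad_hess(y, y_pred, sigma)
numeric = finite_difference(lambda p: uncensored_normal_loss(y, p, sigma), y_pred)
```

The censored cases involve ratios of densities and CDF differences and are not covered by this sketch.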


avinashbarnwal Sep 25, 2019

Author

Yes. I will add.

src/common/survival_util.h Show resolved Hide resolved

struct AFTParam : public dmlc::Parameter<AFTParam> {
AFTDistributionType aft_noise_distribution;
float aft_sigma;


trivialfis Sep 25, 2019

Member

This is a personal preference: gamma looks good on paper, but not in code. I would go with something more specific like distribution_scale_factor (still not a good name) and document it with the gamma parameter used in xxx paper.


avinashbarnwal Sep 25, 2019

Author

Did you mean aft_sigma?

src/common/survival_util.h Show resolved Hide resolved
hcho3 added 2 commits Oct 3, 2019
…erface
Collaborator

hcho3 commented Oct 4, 2019

@RAMitchell Can you take a look at the change I made to MetaInfo? Now we have a generic key-value store for 1D vectors.

@hcho3 hcho3 changed the title [WIP] Add Accelerated Failure Time loss for survival analysis task Add Accelerated Failure Time loss for survival analysis task Oct 6, 2019
Member

trivialfis commented Dec 12, 2019

Feel free to reach out to me for any help needed here. I spent some time with ngboost, which might overlap with this one. I will dig into these interesting things once the next release is sorted out. ;-)

Collaborator

hcho3 commented Dec 12, 2019

@trivialfis This is blocked by the binary MetaInfo format. I need to get to it.

5 participants