-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Accelerated Failure Time loss for survival analysis task #4763
Conversation
@avinashbarnwal In the PR description, can you add a short one-paragraph description of what survival analysis is? Something like: survival analysis is a new kind of learning task where we would like to predict a time to certain event. The time-to-event labels are often censored, i.e. we only know which intervals the label falls in and do not know its exact value. See https://eng.uber.com/modeling-censored-time-to-event-data-using-pyro/ for a real-world example. |
I would describe it as "censored regression" or more specifically "regression with censored outputs" because the goal is still to learn a (real-valued) regression function; this emphasizes the similarity with usual regression, where all outputs are un-censored. |
Also add a short Python example to the description: dtrain = xgboost.DMatrix(X)
dtrain.set_float_info("label_lower_bound", y_lower)
dtrain.set_float_info("label_upper_bound", y_higher)
dtest = xgboost.DMatrix(X_test)
dtest.set_float_info("label_lower_bound", y_lower_test)
dtest.set_float_info("label_upper_bound", y_higher_test)
bst = xgboost.train(params, dtrain, num_boost_round=100,
evals=[(dtrain,"train"), (dtest,"test")]) |
@tdhock Thanks for your suggestion. Yes, "censored regression" sounds reasonable. |
I'm not familiar with survival models, just skimmed through the survey. Are there other recommended materials concentrating on theoretical part? ;-) |
Hi @trivialfis, Please find the good lecture notes for learning survival modeling - https://www4.stat.ncsu.edu/~dzhang2/st745/index.html. One of the motivating books- https://www.amazon.com/Applied-Survival-Analysis-Time-Event/dp/0471754994. Prof. @tdhock and @hcho3 might give a better reference for understanding theoretical survival modeling. |
would be good if @avinashbarnwal could write a latex/PDF vignette in the xgboost R pkg describing the loss functions that he implemented |
they are the same as in R's survival::survreg, there are some docs on that man page, but the math formulas come from http://members.cbio.mines-paristech.fr/~thocking/survival.pdf |
@avinashbarnwal @tdhock Thanks for the good references. Will try to catch up. |
Hi Prof. @tdhock, Please let me know if it is fine to make the vignette-like this https://cran.r-project.org/web/packages/xgboost/vignettes/xgboostfromJSON.html. |
typically for vignettes with lots of math I prefer writing Rnw source which is rendered to tex / pdf. It is possible to include simple math in Rmd which is rendered on a web page using mathjax, but in my experience complex equations (e.g. optimization problems) do not render well on web pages. Examples of both are here: https://github.com/tdhock/PeakSegDisk/tree/master/vignettes
|
Please find R-vignette below and let me know your thoughts. |
@tdhock Do the datasets follow log-normal AFT distribution? The errors are not decreasing when we choose log-logistic and log-weibull. See http://rpubs.com/avinashbarnwal123/aft |
they are real data sets so we don't know their "true" distribution. However
in previous experience with linear models, I have observed that a loss
function with quadratic tails (like the normal distribution) works better
than linear tails (like the logistic)
…On Mon, Aug 26, 2019 at 9:36 AM Philip Hyunsu Cho ***@***.***> wrote:
@tdhock <https://github.com/tdhock> Are the datasets follow log-normal
AFT distribution? The errors are not decreasing when we choose log-logistic
and log-weibull.
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#4763?email_source=notifications&email_token=AAHDX4SAKQD6MWWKEA6772DQGQBBTA5CNFSM4ILDM5N2YY3PNVWWK3TUL52HS4DFVREXG43VMVBW63LNMVXHJKTDN5WW2ZLOORPWSZGOD5E5E2I#issuecomment-524931689>,
or mute the thread
<https://github.com/notifications/unsubscribe-auth/AAHDX4UO7L4EUGR6N2NAM73QGQBBTANCNFSM4ILDM5NQ>
.
|
I have updated the vignette - http://rpubs.com/avinashbarnwal123/aft. It works for the last dataset - |
Thanks. I will change the code accordingly for our paper. |
@trivialfis I added a demo, as you requested. A tutorial is available. Feel free to try it out. |
{ 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f, | ||
0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f }); | ||
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme", | ||
{ -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@avinashbarnwal FYI, I applied the regularization scheme to the uncensored case as well, and now I'm getting a zero gradient here, where previously we'd get something like -50.0. I'm still looking at ways to avoid INF and NAN (in general) without strange behavior like this. For this example, clamping the gradient to a reasonable quantity like -30.0 would be a lot better than giving 0.0. I'll come back to this soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ve merged this PR for now. I’ll file a follow-up PR to make AFT more robust in edge cases like this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! This is exciting.
Merged. Thanks everyone! |
Hi,
Please find the Accelerated Failure time loss for Survival Modeling.
Survival analysis is a "censored regression" where the goal is to learn time-to-event function. This is similar to the common regression analysis where data-points are uncensored. Time-to-event modeling is critical for understanding users/companies behaviors not limited to credit, cancer, and attrition risks.
Supports
This project is part of the Google Summer of Code - 2019. AFT-Xgboost
Compact summary of AFT loss formula
Relevant Documents
Example in Python to run -
For more details - avinashbarnwal#1.
Note: as part of this PR, the
Metric
class became a subclass of theConfigurable
interface.