
Fix R dart prediction and add test. #5204

Merged
merged 2 commits into dmlc:master on Jan 16, 2020

Conversation

trivialfis
Member

Closes #5203 .

@trivialfis trivialfis requested a review from hcho3 January 14, 2020 17:32
@Kodiologist
Contributor

I tested this PR with my code at the issue and I'm afraid it might not suffice. First, with training = T added, I am seeing different predictions with ntreelimit set or not, as expected, but neither matches the predictions from 7b65698. Second, with training not provided, the predictions should match the predictions from 7b65698 with ntreelimit set, shouldn't they? But they don't. They're unaltered from master.

@trivialfis
Member Author

trivialfis commented Jan 14, 2020

@Kodiologist Try using xgb.train instead of xgboost when running on the commit before the fix. I believe this is because, before this PR, the evaluation step added by xgboost caused an extra internal prediction pass, which performed tree dropping.

Will keep looking into it.
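A minimal sketch of the difference being described (hypothetical, based on the pre-fix behavior above): the xgboost() convenience wrapper implicitly evaluates on the training data every round, while xgb.train() performs no evaluation unless a watchlist is supplied.

```r
library(xgboost)

set.seed(15)
d <- matrix(rnorm(300), ncol = 3)
y <- rnorm(100)
params <- list(booster = "dart", objective = "reg:squarederror",
               rate_drop = .5, one_drop = TRUE, nthread = 1)

# Convenience wrapper: evaluates on the training data each round, which
# (before this PR) triggered an extra internal prediction pass that
# performed DART tree dropping.
fit1 <- xgboost(data = d, label = y, params = params,
                nrounds = 30, verbose = 0)

# Low-level API: no watchlist supplied, so no evaluation, no extra
# prediction pass, and hence no unintended tree dropping before the fix.
dtrain <- xgb.DMatrix(data = d, label = y)
fit2 <- xgb.train(params, dtrain, nrounds = 30)
```

Before the fix, these two models could yield different predict() results on the same data for a DART booster; after the fix they should agree.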

@trivialfis
Member Author

@Kodiologist Indeed, using xgb.train instead of xgboost(...) on the commit before the fix:

library(xgboost)

set.seed(15)

d <- cbind(
  x1 = rnorm(100),
  x2 = rnorm(100),
  x3 = rnorm(100))

y <- d[,"x1"] + d[,"x2"]^2 +
  ifelse(d[,"x3"] > .5, d[,"x3"]^2, 2^d[,"x3"]) +
  rnorm(100)
nrounds <- 30

dtrain <- xgb.DMatrix(data=d, info = list(label=y))
params <- list(
  verbose = 0,
  booster = "dart",
  objective = "reg:squarederror",
  eval_metric = "rmse",
  rate_drop = .5,
  nthread = 1,
  one_drop = TRUE
)

fit <- xgb.train(
  params,
  data = dtrain,
  nrounds = nrounds)

xgb.save(fit, 'model_after.json')

pr <- function(...) {
  predict(fit, newdata = d, ...)
}

cat("w/o ntreelimit\n")
print(pr())
cat("with ntreelimit\n")
print(pr(ntreelimit = nrounds))
cat("with training\n")
print(pr(training = TRUE))

generates:

with ntreelimit
  [1]  0.9003429  2.5287831  0.8874675  4.4098668  1.9609053  1.0055976
  [7]  0.3943650  1.7076734  6.3382840  0.2433042  4.1564407  1.3798609
 [13]  0.8874675 -0.5890894  2.9998991  2.0813766  3.1467826  1.2559140
 [19]  1.2184663  1.6247368  2.8361428  3.0794213  0.3370866  0.8874675
 [25]  3.5600305  0.6429486  1.6469479  1.1856699  0.3368608  4.3451409
 [31]  6.5794563  0.3370866  4.6248546  1.6578684  5.0521989  1.0926565
 [37]  0.8874675  1.5971428  3.1831484  1.1504483  1.5336471  1.0496976
 [43]  6.5794563  6.4258838  0.3370866  3.1878321  2.2314143  3.0794213
 [49]  5.5349374  1.6474503  3.0835233  1.7872251  1.2559140  3.3080211
 [55]  5.5349374  1.6369976  5.4840751  6.5794563  1.2103307  0.3370866
 [61]  0.8154036  0.3649135  2.2150474 -0.6399471 -0.3155120 -0.6399471
 [67]  0.3370866  6.5794563  1.4331715  2.6804214 -0.5307322 -0.6230295
 [73]  1.0664585  2.3703270  0.3943650  1.6101770  5.0521989  1.1136512
 [79]  1.1178694  4.6676064  0.8874675  3.2310014  1.7612056  1.6548371
 [85]  1.2863905  2.3352165  1.7612056  1.3605072  0.3943650  1.0489140
 [91]  0.6624517  2.3703270  2.9735966  5.5349374  1.5185604  3.3080211
 [97]  0.8874675  2.6156211  0.3625812 -0.5307322

which is consistent with the current PR.

@trivialfis
Member Author

Also, predictions without ntreelimit before the fix are consistent with predictions with training = TRUE after the fix.

@trivialfis
Member Author

@Kodiologist Would you like to take another look?

@Kodiologist
Contributor

I see; thanks for explaining the difference between xgboost and xgb.train. This means that I've still been dropping some trees in past analyses, even when I thought I hadn't, which is depressing, but good to know. The new results look good.

@Kodiologist
Contributor

That xgboost and xgb.train now give the same predictions is an important improvement this PR has made, so I think you should add a test for it.
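A sketch of what such a regression test could look like (hypothetical; the test actually added in this PR may differ), asserting that the two entry points now give identical DART predictions:

```r
library(xgboost)

set.seed(15)
d <- matrix(rnorm(300), ncol = 3)
y <- rnorm(100)
params <- list(booster = "dart", objective = "reg:squarederror",
               rate_drop = .5, one_drop = TRUE, nthread = 1)

# Train the same DART model through both entry points.
fit_wrapper <- xgboost(data = d, label = y, params = params,
                       nrounds = 30, verbose = 0)
fit_train <- xgb.train(params, xgb.DMatrix(data = d, label = y),
                       nrounds = 30)

# With the fix, inference-time predictions (no tree dropping) should
# agree regardless of which entry point trained the model.
stopifnot(isTRUE(all.equal(predict(fit_wrapper, d),
                           predict(fit_train, d))))
```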

@trivialfis
Member Author

@Kodiologist Will do. I'm sorry for causing so much trouble.

@trivialfis
Member Author

@hcho3 ping

@hcho3
Collaborator

hcho3 commented Jan 16, 2020

Aside: I'd like to get my feet wet in the R ecosystem. The XGBoost R package could use more love. I've also noticed that a large portion of work in statistics and quantitative science is done in R.

@trivialfis trivialfis merged commit 5199b86 into dmlc:master Jan 16, 2020
@trivialfis trivialfis deleted the fix-R-dart-predict branch January 16, 2020 04:11
@lock lock bot locked as resolved and limited conversation to collaborators Apr 15, 2020