[jvm-packages] Add the new device parameter. #9385

Merged (7 commits) on Jul 17, 2023

trivialfis
Member

@trivialfis trivialfis commented Jul 14, 2023

Related: #7308.

This PR adds the device parameter to the Spark package. For the core implementation PR, please see #9362.

@trivialfis
Member Author

cc @wbo4958.

@trivialfis changed the title from "[WIP] [jvm-packages] Add the new device parameter." to "[jvm-packages] Add the new device parameter." on Jul 15, 2023
@trivialfis
Member Author

@dotbg Could you please take a look at the PR when you are available? I'm not an expert in Scala/Spark.

@dotbg
Contributor

dotbg commented Jul 16, 2023

@trivialfis At first glance the code looks OK. I wonder whether gpu would be a better name for the device option.

@trivialfis
Member Author

@dotbg gpu is available as well, and it is currently equivalent to cuda. We keep it as a placeholder in case anyone implements support for other GPU devices in the future.

@@ -77,7 +77,8 @@ public void testBooster() throws XGBoostError {
         put("objective", "binary:logistic");
         put("num_round", round);
         put("num_workers", 1);
-        put("tree_method", "gpu_hist");
+        put("tree_method", "hist");
+        put("device", "cuda");
Contributor

will the put("tree_method", "gpu_hist"); also work?

@dotbg
Contributor

dotbg commented Jul 16, 2023

I see. In this case, it may be good (but not a showstopper) to align naming in other places as well.

@trivialfis
Member Author

> will the put("tree_method", "gpu_hist"); also work?

It should work but with a warning. I will add a test tomorrow.
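
As a rough illustration of the behaviour described above (the legacy value is still accepted, but with a warning), the following minimal Java sketch rewrites a legacy parameter map into the new form. The class `DeviceParamSketch`, the helper `migrateLegacyGpuHist`, and the warning text are hypothetical and are not part of XGBoost4J; only the parameter names and values (`tree_method`, `gpu_hist`, `hist`, `device`, `cuda`) come from this PR.

```java
import java.util.HashMap;
import java.util.Map;

final class DeviceParamSketch {
    /**
     * Hypothetical helper, not an XGBoost4J API: returns a copy of the
     * parameters in which the legacy tree_method=gpu_hist setting is
     * expressed as tree_method=hist plus device=cuda.
     */
    static Map<String, Object> migrateLegacyGpuHist(Map<String, Object> params) {
        Map<String, Object> out = new HashMap<>(params);
        if ("gpu_hist".equals(out.get("tree_method"))) {
            // Assumed warning text, for illustration only.
            System.err.println("warning: tree_method=gpu_hist is a legacy setting; "
                + "using tree_method=hist with device=cuda");
            out.put("tree_method", "hist");
            // Keep a device the caller has already chosen.
            out.putIfAbsent("device", "cuda");
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, Object> legacy = new HashMap<>();
        legacy.put("objective", "binary:logistic");
        legacy.put("num_workers", 1);
        legacy.put("tree_method", "gpu_hist");
        System.out.println(migrateLegacyGpuHist(legacy));
    }
}
```

The sketch copies the input map rather than mutating it, so a caller can still inspect the original settings after migration.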

> I see, in this case, it may be good (but not a showstopper) to align naming in other places as well.

Could you please elaborate on this? Apologies for not providing clearer context earlier. I'm currently walking through each interface in #7308 (comment). The syntax is documented here, which was part of #9362.

I think the Flink interface doesn't support GPU yet. I will double-check the native Java interface and the Scala interface, but I think they don't have hard-coded parameters that require changes (feel free to correct me).

The R interface doesn't have hard-coded parameters, and the CRAN package doesn't support GPU.

The Python interface is mostly handled in previous PRs; I will add some more specialized handling for PySpark. The naming of parameters is consistent.

@dotbg
Contributor

dotbg commented Jul 16, 2023

@trivialfis Well, one of my concerns is that the package names will contain gpu, not cuda. It is not a huge deal if documented properly, but it would be great to document what is required if someone decides to make the code OpenCL compliant.

@trivialfis
Member Author

I think this part should be fine. We have documents on how GPU support is achieved for both the general XGBoost packages and the JVM packages, along with notes that CUDA is the only option at the moment.

@@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
     val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
       estimator match {
         case est: XGBoostEstimatorCommon =>
-          require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
-            s"GPU train requires tree_method set to gpu_hist")
+          require(
Member Author

@wbo4958 The check is here.
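
The diff excerpt above shows only the opening `require(` of the new check, so the actual condition is not visible here. Purely as an assumption based on the discussion (`gpu` is currently equivalent to `cuda`, and the legacy `gpu_hist` value should keep working), a precondition of this kind might look like the Java sketch below. The class `GpuTrainCheckSketch`, both method names, and the error message are hypothetical.

```java
import java.util.Map;

final class GpuTrainCheckSketch {
    /**
     * Hypothetical predicate: true when the parameters ask for GPU training,
     * either through the new device parameter or the legacy tree_method value.
     */
    static boolean requestsGpu(Map<String, String> params) {
        String device = params.get("device");
        boolean gpuDevice = "cuda".equals(device) || "gpu".equals(device);
        boolean legacyGpuHist = "gpu_hist".equals(params.get("tree_method"));
        return gpuDevice || legacyGpuHist;
    }

    /** Fails fast with an assumed message when GPU training was not requested. */
    static void requireGpu(Map<String, String> params) {
        if (!requestsGpu(params)) {
            throw new IllegalArgumentException(
                "GPU train requires device set to cuda or gpu");
        }
    }

    public static void main(String[] args) {
        // New-style configuration: accepted.
        requireGpu(Map.of("tree_method", "hist", "device", "cuda"));
        // Legacy configuration: also accepted in this sketch.
        requireGpu(Map.of("tree_method", "gpu_hist"));
        try {
            // No GPU requested: rejected.
            requireGpu(Map.of("tree_method", "hist"));
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```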

@trivialfis trivialfis merged commit f4fb2be into dmlc:master Jul 17, 2023
22 checks passed
@trivialfis trivialfis deleted the device-ord-jvm branch July 17, 2023 10:40
This pull request was closed.

4 participants