[pyspark] make the model saved by pyspark compatible #8219
Conversation
@trivialfis Could you help check a failed Python test?
[2022-09-05T12:59:12.999Z] =================================== FAILURES ===================================
[2022-09-05T12:59:12.999Z] ____________________________ test_gpu_data_iterator ____________________________
[2022-09-05T12:59:12.999Z]
[2022-09-05T12:59:12.999Z] cls = <class '_pytest.runner.CallInfo'>
@WeichenXu123 @trivialfis Could you help to review this PR?
Will look into it tomorrow.
@WeichenXu123 @trivialfis Any feedback for this PR?
Hi @WeichenXu123 @trivialfis, could you help to review it?
Can we document the function?
Yeah, we can, but the issue will be the same as the one the JVM package previously encountered. Most users dump the model the Spark way; they may not want to do another
@@ -21,34 +21,28 @@ def _get_or_create_tmp_dir():
     return xgb_tmp_dir


-def serialize_xgb_model(model):
+def dump_model_to_json_file(model) -> str:
Please use the term `save`. "Dump" has a specific meaning in XGBoost's code base.
Done
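(For readers skimming the diff, a hedged sketch of what such a renamed helper could look like; `_get_or_create_tmp_dir` comes from the surrounding file, and the body is an assumption, not the PR's exact code:)

import os
import uuid

def save_model_to_json_file(booster) -> str:
    # Serialize the booster to its JSON form and write it to a local
    # temp file, returning the file path. A sketch only; the PR's
    # actual implementation may differ.
    json_str = booster.save_raw("json").decode("utf-8")
    path = os.path.join(_get_or_create_tmp_dir(), f"{uuid.uuid4()}.json")
    with open(path, "w", encoding="utf-8") as f:
        f.write(json_str)
    return path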
-    ).write.parquet(model_save_path)
+    model_save_path = os.path.join(path, "model")
+    xgb_model_file = dump_model_to_json_file(xgb_model)
+    # The JSON file written by Spark based on `booster.save_raw("json").decode("utf-8")`
Why not?
There are some escaped quotes (`\"`) in the JSON file which can't be loaded by XGBoost. Do you want to check more?
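(For illustration, one way this kind of escaping arises: writing a JSON string through a Spark DataFrame JSON writer re-escapes the inner quotes. A minimal sketch, assuming an active `spark` session and a hypothetical output path:)

# The raw model JSON becomes a quoted, escaped string field, which
# xgboost's load_model cannot parse as the model document.
df = spark.createDataFrame([('{"learner": {}}',)], ["model"])
df.write.json("/tmp/escaping_demo")  # hypothetical path
# Resulting file content: {"model":"{\"learner\": {}}"}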
I will take a look tomorrow
@trivialfis No need anymore, I just found another way to do it.
    xgb_model_file = save_model_to_json_file(xgb_model)
    # The JSON file written by Spark based on `booster.save_raw("json").decode("utf-8")`
    # can't be loaded by XGBoost directly.
    _get_spark_session().read.text(xgb_model_file).write.text(model_save_path)
`_get_spark_session().read.text(xgb_model_file)` — this line is not correct. For `spark.read.text(path)`, the path must be a distributed file system path which all Spark executors can access.
You can use a distributed FS API to copy the local file `xgb_model_file` into the model save path (a Hadoop FS path).
Wow, right, you're correct, @WeichenXu123. Good finding. Could you point me to what the "distributed FS API" is? Really appreciate it.
You can use this:
https://arrow.apache.org/docs/python/generated/pyarrow.fs.HadoopFileSystem.html
But this does not support DBFS (Databricks filesystem); we need to support the Databricks case as well. Databricks mounts `dbfs:/xxx/xxx` to the local file system path `/dbfs/xxx/xxx`.
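(A minimal sketch of that suggestion, assuming a reachable HDFS configured as the Hadoop default and hypothetical file paths; not the PR's actual code:)

import pyarrow.fs as pafs

# Connect using the Hadoop configuration's default filesystem.
hdfs = pafs.HadoopFileSystem(host="default")
local = pafs.LocalFileSystem()

# Stream the locally saved model JSON into the distributed model path.
with local.open_input_stream("/tmp/xgb_model.json") as src, \
        hdfs.open_output_stream("/user/me/model/xgb_model.json") as dst:
    dst.write(src.read())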
The example code in the PR description

import xgboost as xgb
bst = xgb.Booster()
# Basically, YOUR_MODEL_PATH should be like "xxxx/model/xxx.txt"
YOUR_MODEL_PATH = "xxx"
bst.load_model(YOUR_MODEL_PATH)

seems to not work if the path is a distributed FS path?
@WeichenXu123, I use an RDD to save the text file; it should work with all kinds of Hadoop-compatible FS.
Do we really need this PR?
Yeah, consider the scenario: a data scientist who does not know Spark gets a model saved by xgboost-spark and wants to load it with the XGBoost Python package. What can he or she do? Although we can document it, trust me, not everyone reads the whole doc carefully. Previously, XGBoost-JVM had the same issue, so I changed that.
Hi @hcho3, what does "Pending" mean for pipelines like xgboost-ci/pr?
    _get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
        model_save_path
Interesting idea, but how do we control the saved file name?
Does the `booster` string contain a "\n" character? If yes, when loading it back (by `sparkContext.textFile(model_load_path)`), each line will become one RDD element, and these lines might be split into multiple RDD partitions.
I tested, and it is always `part-00000`; it seems the generated file name follows a pattern based on the task ID. Since we only have 1 partition, the ID should be 00000.
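(For context, a self-contained sketch of the resulting save/load round trip; the local session, the tiny training data, and the output path are assumptions, not the PR's code:)

import numpy as np
import xgboost as xgb
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Train a tiny booster so the example stands alone.
X = np.random.rand(16, 4)
y = np.random.randint(2, size=16)
bst = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y), num_boost_round=2)

# Save: the JSON string is the single element of a one-partition RDD, so
# Spark writes it as <path>/part-00000 on any Hadoop-compatible filesystem.
booster_json = bst.save_raw("json").decode("utf-8")
spark.sparkContext.parallelize([booster_json], 1).saveAsTextFile("/tmp/xgb_model")  # hypothetical path

# Load back with the plain Python package.
loaded = xgb.Booster()
loaded.load_model("/tmp/xgb_model/part-00000")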
Let's document that the file named "part-00000" is the model JSON file, and please add a test to ensure the model JSON file does not contain a \n character, and document the reason.
Just checked the code; the file name is defined by https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala#L225.

override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = {
  val numfmt = NumberFormat.getInstance(Locale.US)
  numfmt.setMinimumIntegerDigits(5)
  numfmt.setGroupingUsed(false)

  val outputName = "part-" + numfmt.format(splitId)
  val path = FileOutputFormat.getOutputPath(getConf)
  val fs: FileSystem = {
    if (path != null) {
      path.getFileSystem(getConf)
    } else {
      // scalastyle:off FileSystemGet
      FileSystem.get(getConf)
      // scalastyle:on FileSystemGet
    }
  }
  ...

Here the `splitId` is `TaskContext.partitionId()`. In our case, there is only 1 partition, so the file name is "part-00000".
Yes, I know that. My point is: can we customize the file name to make it more user-friendly? Not a must though.
That's the internal behavior of PySpark; I'm not sure it's a good idea to rely on it.
Yeah, if you guys insist, I can use the FileSystem Java API via Py4J to achieve it.
> Yeah, if you guys insist, I can use the FileSystem Java API via Py4J to achieve it.

No need to do that; it makes the code hard to maintain. Your current code is fine.
The CI pipeline doesn't run until one of the admins (like me) gives approval. We do this to save CI costs.
    bst = xgb.Booster()
    path = glob.glob(f"{model_path}/**/model/part-00000", recursive=True)[0]
    bst.load_model(path)
    self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))
Add a test to assert the model file does not include a \n char.
Yeah, per my understanding, we don't need to do this: if there were a "\n", either the assertion `self.assertEqual(model.get_booster().save_raw("json"), bst.save_raw("json"))` or `bst.load_model(path)` would fail.
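(If an explicit check were still wanted, a minimal sketch reusing the `path` from the test snippet above; the trailing newline that `saveAsTextFile` appends to the single record is the only one expected:)

with open(path, encoding="utf-8") as f:
    content = f.read()
# Only the record terminator may be a newline; an embedded "\n" would
# mean the model JSON was split across lines.
assert "\n" not in content.rstrip("\n")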
I will leave the approval to @WeichenXu123. Could you please add documentation as well, about `get_booster` and your workaround for the model serialization?
Sure, I will add the docs in a follow-up PR, along with how to leverage RAPIDS to accelerate XGBoost PySpark. @trivialfis @hcho3, could you trigger the CI for this PR?
Users can't directly load a model trained by PySpark with the XGBoost Python package; it requires much effort to do that, see #8186. This PR first saves the model in JSON format and then writes it to a text file. Then the user can easily load the model by:
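(The loading snippet from the PR description, as quoted earlier in the thread:)

import xgboost as xgb

bst = xgb.Booster()
# Basically, YOUR_MODEL_PATH should be like "xxxx/model/xxx.txt"
YOUR_MODEL_PATH = "xxx"
bst.load_model(YOUR_MODEL_PATH)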