From 98541310569a8a54c1671faa76df4bce066b488d Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 8 Sep 2022 08:26:11 +0800
Subject: [PATCH] Repartition training dataset on monotonically_increasing_id
 to avoid empty partitions

Signed-off-by: Weichen Xu
---
 python-package/xgboost/spark/core.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index edff40349676..302bd709ff45 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -20,7 +20,9 @@
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
-from pyspark.sql.functions import col, countDistinct, pandas_udf, struct
+from pyspark.sql.functions import (
+    col, countDistinct, pandas_udf, struct, monotonically_increasing_id
+)
 from pyspark.sql.types import (
     ArrayType,
     DoubleType,
@@ -270,15 +272,6 @@ def _validate_params(self):
                 f"It cannot be less than 1 [Default is 1]"
             )
 
-        if (
-            self.getOrDefault(self.force_repartition)
-            and self.getOrDefault(self.num_workers) == 1
-        ):
-            get_logger(self.__class__.__name__).warning(
-                "You set force_repartition to true when there is no need for a repartition."
-                "Therefore, that parameter will be ignored."
-            )
-
         if self.getOrDefault(self.features_cols):
             if not self.getOrDefault(self.use_gpu):
                 raise ValueError("features_cols param requires enabling use_gpu.")
@@ -691,7 +684,10 @@ def _fit(self, dataset):
             )
 
         if self._repartition_needed(dataset):
-            dataset = dataset.repartition(num_workers)
+            # Repartition on a `monotonically_increasing_id` column to avoid an
+            # unbalanced repartition result. Calling `.repartition(N)` directly
+            # might leave some partitions empty.
+            dataset = dataset.repartition(num_workers, monotonically_increasing_id())
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
            train_params
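
Not part of the patch: a minimal, self-contained PySpark sketch of the
behaviour the new comment describes. The local session, the toy DataFrame,
and the partition_sizes helper are assumptions made for illustration only;
the exact sizes printed depend on Spark's internal round-robin seeding and
hash function and may differ across Spark versions.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import monotonically_increasing_id

    spark = SparkSession.builder.master("local[4]").getOrCreate()

    # Toy input: 8 rows spread over 8 input partitions (one row each), the
    # kind of layout where plain round-robin repartitioning can leave some
    # target partitions empty.
    df = spark.range(0, 8, 1, 8)

    def partition_sizes(d):
        # Number of rows that landed in each partition.
        return d.rdd.glom().map(len).collect()

    # Plain `.repartition(N)`: each input partition starts its round-robin
    # at a pseudo-random target partition, so single-row input partitions
    # can collide on a few targets and leave others empty.
    print(partition_sizes(df.repartition(4)))

    # Repartitioning on a per-row key hash-partitions rows individually,
    # which tends to spread them across all target partitions.
    print(partition_sizes(df.repartition(4, monotonically_increasing_id())))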