diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index c0a8764199fd..7dbe290ae3e5 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -14,7 +14,8 @@ from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext from pyspark.sql.session import SparkSession -from xgboost import Booster, XGBModel, collective +from xgboost import Booster, XGBModel +from xgboost.collective import CommunicatorContext as CCtx from xgboost.tracker import RabitTracker @@ -42,22 +43,12 @@ def _get_default_params_from_func( return filtered_params_dict -class CommunicatorContext: - """A context controlling collective communicator initialization and finalization. - This isn't specificially necessary (note Part 3), but it is more understandable - coding-wise. - - """ +class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods + """Context with PySpark specific task ID.""" def __init__(self, context: BarrierTaskContext, **args: Any) -> None: - self.args = args - self.args["DMLC_TASK_ID"] = str(context.partitionId()) - - def __enter__(self) -> None: - collective.init(**self.args) - - def __exit__(self, *args: Any) -> None: - collective.finalize() + args["DMLC_TASK_ID"] = str(context.partitionId()) + super().__init__(**args) def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: