Skip to content

Commit

Permalink
Remove duplication.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 28, 2021
1 parent 73044cc commit 76b8e8b
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,11 +1021,25 @@ def _infer_direct_predict_output(
return test_predt.shape, meta


async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = client.scatter(model["booster"])
elif isinstance(model, distributed.Future):
booster = model
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
return booster


# pylint: disable=too-many-statements
async def _predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
output_margin: bool,
missing: float,
Expand All @@ -1035,14 +1049,7 @@ async def _predict_async(
pred_interactions: bool,
validate_features: bool,
) -> _DaskCollection:
if isinstance(model, Booster):
_booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
_booster = client.scatter(model["booster"])
elif isinstance(model, distributed.Future):
_booster = model
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
_booster = await _get_model_future(client, model)
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))

Expand Down Expand Up @@ -1224,21 +1231,14 @@ def predict( # pylint: disable=unused-argument
async def _inplace_predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
missing: float = numpy.nan
) -> _DaskCollection:
client = _xgb_get_client(client)
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = await client.scatter(model['booster'])
elif isinstance(model, distributed.Future):
booster = model
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
booster = await _get_model_future(client, model)
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

Expand Down

0 comments on commit 76b8e8b

Please sign in to comment.