Skip to content

Commit

Permalink
mypy.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 28, 2021
1 parent 4f58abb commit 4d31d58
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def _infer_direct_predict_output(
if isinstance(data, DaskDMatrix):
features = data.num_col()
else:
features = data.shape[1]
features = data.shape[1] # type:ignore
rng = numpy.random.RandomState(1994)
test_sample = rng.randn(1, features)
if inplace:
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:

def predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: Union[DaskDMatrix, _DaskCollection],
output_margin: bool = False,
missing: float = numpy.nan,
Expand Down Expand Up @@ -1283,7 +1283,7 @@ def mapped_predict(

def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
Expand Down Expand Up @@ -1392,7 +1392,7 @@ async def _get_booster_future(self) -> "distributed.Future":
# pylint: disable=access-member-before-definition
if not hasattr(self, "_booster_future") or self._booster_future_id != id(booster):
self._booster_future = await self.client.scatter(booster)
self._booster_future_id = id(booster)
self._booster_future_id: int = id(booster)
return self._booster_future

def __await__(self) -> Awaitable[Any]:
Expand All @@ -1401,7 +1401,7 @@ async def _() -> Awaitable[Any]:
return self
return self.client.sync(_).__await__()

def __getstate__(self):
def __getstate__(self) -> Dict:
this = self.__dict__.copy()
if "_client" in this.keys():
del this["_client"]
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") ->
rows = X.shape[0]
cols = X.shape[1]

def assert_shape(shape):
def assert_shape(shape: Tuple[int, ...]) -> None:
assert shape[0] == rows
if "num_class" in params.keys():
assert shape[1] == params["num_class"]
Expand Down

0 comments on commit 4d31d58

Please sign in to comment.