[ci] [python-package] check for untyped definitions with mypy (#6339)
jameslamb committed Mar 4, 2024
1 parent 1a292f8 commit b27d81e
Showing 6 changed files with 38 additions and 33 deletions.
1 change: 1 addition & 0 deletions .ci/test.sh
@@ -74,6 +74,7 @@ if [[ $TASK == "lint" ]]; then
${CONDA_PYTHON_REQUIREMENT} \
cmakelint \
cpplint \
'matplotlib>=3.8.3' \
mypy \
'pre-commit>=3.6.0' \
'pyarrow>=14.0' \
22 changes: 11 additions & 11 deletions python-package/lightgbm/basic.py
@@ -13,7 +13,7 @@
from os.path import getsize
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union

import numpy as np
import scipy.sparse
@@ -537,13 +537,13 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
class _TempFile:
"""Proxy class to workaround errors on Windows."""

def __enter__(self):
def __enter__(self) -> "_TempFile":
with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f:
self.name = f.name
self.path = Path(self.name)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.path.is_file():
self.path.unlink()
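
The `_TempFile` change above is the standard way to annotate a context manager once untyped definitions are disallowed: `__enter__` returns the instance and `__exit__` returns `None`. A minimal standalone sketch (not LightGBM code) of the same pattern; it uses stricter exception-info types than the `Any` annotations in the diff, which mypy also accepts:

```python
# Hypothetical example: a fully annotated context manager.
from types import TracebackType
from typing import Optional, Type


class ManagedResource:
    def __enter__(self) -> "ManagedResource":
        # Returning the instance lets `with ManagedResource() as r:` bind `r` to this type.
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None (rather than True) means any exception propagates.
        pass
```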

@@ -595,7 +595,7 @@ def _get_all_param_aliases() -> Dict[str, List[str]]:
)

@classmethod
def get(cls, *args) -> Set[str]:
def get(cls, *args: str) -> Set[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
ret = set()
@@ -610,7 +610,7 @@ def get_sorted(cls, name: str) -> List[str]:
return cls.aliases.get(name, [name])

@classmethod
def get_by_alias(cls, *args) -> Set[str]:
def get_by_alias(cls, *args: str) -> Set[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
ret = set(args)
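
For the `get()` and `get_by_alias()` changes, annotating `*args: str` types each positional argument as `str`; inside the function `args` is a `Tuple[str, ...]`. A small illustrative sketch (the function name is made up):

```python
from typing import Set


def collect_names(*args: str) -> Set[str]:
    # `args` is a Tuple[str, ...]; each element is typed as str.
    ret: Set[str] = set()
    for name in args:
        ret.add(name.lower())
    return ret


collect_names("objective", "metric")   # OK
# collect_names(42)                    # mypy error: expected "str", got "int"
```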
@@ -1563,7 +1563,7 @@ def __inner_predict_sparse_csc(
start_iteration: int,
num_iteration: int,
predict_type: int,
):
) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
csc_indices = csc.indices.astype(np.int32, copy=False)
@@ -1813,7 +1813,7 @@ def __init__(
self._need_slice = True
self._predictor: Optional[_InnerPredictor] = None
self.pandas_categorical: Optional[List[List]] = None
self._params_back_up = None
self._params_back_up: Optional[Dict[str, Any]] = None
self.version = 0
self._start_row = 0 # Used when pushing rows one by one.
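
The `_params_back_up` change shows the usual fix for attributes initialized to `None`: without an explicit `Optional[...]` annotation, mypy infers the attribute's type from the `None` assignment in `__init__`, and a later assignment of a dict is rejected. A standalone sketch under that assumption (the class and attribute names are hypothetical):

```python
from typing import Any, Dict, Optional


class Config:
    def __init__(self) -> None:
        # Annotate the attribute so a later dict assignment type-checks.
        self.params_backup: Optional[Dict[str, Any]] = None

    def backup(self, params: Dict[str, Any]) -> None:
        self.params_backup = params  # OK: compatible with Optional[Dict[str, Any]]
```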

@@ -2195,7 +2195,7 @@ def _lazy_init(
return self.set_feature_name(feature_name)

@staticmethod
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]):
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]:
offset = 0
seq_id = 0
seq = seqs[seq_id]
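
A generator function's return annotation names the yielded type, so `Iterator[np.ndarray]` (equivalently `Generator[np.ndarray, None, None]`) is the right shape for `_yield_row_from_seqlist`. A self-contained sketch of the same idea, unrelated to the LightGBM internals:

```python
from typing import Iterable, Iterator

import numpy as np


def yield_rows(matrix: np.ndarray, indices: Iterable[int]) -> Iterator[np.ndarray]:
    # The annotation describes what is yielded; calling the function returns a generator.
    for i in indices:
        yield matrix[i]


rows = list(yield_rows(np.eye(3), [0, 2]))  # each element is an np.ndarray
```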
@@ -2697,7 +2697,7 @@ def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset":
return self
params = deepcopy(params)

def update():
def update() -> None:
if not self.params:
self.params = params
else:
@@ -3704,7 +3704,7 @@ def __del__(self) -> None:
def __copy__(self) -> "Booster":
return self.__deepcopy__(None)

def __deepcopy__(self, _) -> "Booster":
def __deepcopy__(self, *args: Any, **kwargs: Any) -> "Booster":
model_str = self.model_to_string(num_iteration=-1)
return Booster(model_str=model_str)
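
`copy.deepcopy(booster)` calls `__deepcopy__(self, memo)`, so the annotated `*args: Any, **kwargs: Any` form above accepts (and ignores) the memo dict while satisfying the untyped-definitions check. A hedged sketch of the same pattern on a toy class:

```python
import copy
from typing import Any


class Model:
    def __init__(self, state: str) -> None:
        self.state = state

    def __deepcopy__(self, *args: Any, **kwargs: Any) -> "Model":
        # The memo dict arrives via *args; rebuilding from serialized state ignores it.
        return Model(self.state)


clone = copy.deepcopy(Model("trained"))
```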

@@ -4757,7 +4757,7 @@ def refit(
dataset_params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True,
validate_features: bool = False,
**kwargs,
**kwargs: Any,
) -> "Booster":
"""Refit the existing Booster by new data.
32 changes: 16 additions & 16 deletions python-package/lightgbm/compat.py
@@ -1,7 +1,7 @@
# coding: utf-8
"""Compatibility library."""

from typing import List
from typing import Any, List

"""pandas"""
try:
@@ -20,19 +20,19 @@
class pd_Series: # type: ignore
"""Dummy class for pandas.Series."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class pd_DataFrame: # type: ignore
"""Dummy class for pandas.DataFrame."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class pd_CategoricalDtype: # type: ignore
"""Dummy class for pandas.CategoricalDtype."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

concat = None
@@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs):
class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass


@@ -80,7 +80,7 @@ def __init__(self, *args, **kwargs):
class dt_DataTable: # type: ignore
"""Dummy class for datatable.DataTable."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass


@@ -104,7 +104,7 @@ def __init__(self, *args, **kwargs):
from sklearn.utils.validation import check_consistent_length

# dummy function to support older version of scikit-learn
def _check_sample_weight(sample_weight, X, dtype=None):
def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
check_consistent_length(sample_weight, X)
return sample_weight

@@ -176,31 +176,31 @@ class _LGBMRegressorBase: # type: ignore
class Client: # type: ignore
"""Dummy class for dask.distributed.Client."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class Future: # type: ignore
"""Dummy class for dask.distributed.Future."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class dask_DataFrame: # type: ignore
"""Dummy class for dask.dataframe.DataFrame."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class dask_Series: # type: ignore
"""Dummy class for dask.dataframe.Series."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass


@@ -222,19 +222,19 @@ def __init__(self, *args, **kwargs):
class pa_Array: # type: ignore
"""Dummy class for pa.Array."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class pa_ChunkedArray: # type: ignore
"""Dummy class for pa.ChunkedArray."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class pa_Table: # type: ignore
"""Dummy class for pa.Table."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class arrow_cffi: # type: ignore
@@ -245,7 +245,7 @@ class arrow_cffi: # type: ignore
cast = None
new = None

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
pass

class pa_compute: # type: ignore
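
The repeated `def __init__(self, *args: Any, **kwargs: Any)` edits in compat.py all follow the same optional-dependency pattern: when an import fails, a placeholder class stands in so annotations and `isinstance` checks still resolve, and its constructor now carries explicit `Any` annotations to satisfy the untyped-definitions check. A minimal sketch of that pattern, using a made-up module name:

```python
from typing import Any

try:
    from some_optional_dep import Frame  # hypothetical optional dependency
    FRAME_INSTALLED = True
except ImportError:
    FRAME_INSTALLED = False

    class Frame:  # type: ignore
        """Dummy class used when the optional dependency is unavailable."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass
```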
13 changes: 8 additions & 5 deletions python-package/lightgbm/plotting.py
@@ -3,7 +3,7 @@
import math
from copy import deepcopy
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np

@@ -19,6 +19,9 @@
"plot_tree",
]

if TYPE_CHECKING:
import matplotlib


def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
"""Check object is not tuple or does not have 2 elements."""
@@ -32,7 +35,7 @@ def _float2str(value: float, precision: Optional[int]) -> str:

def plot_importance(
booster: Union[Booster, LGBMModel],
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
height: float = 0.2,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
@@ -168,7 +171,7 @@ def plot_split_value_histogram(
booster: Union[Booster, LGBMModel],
feature: Union[int, str],
bins: Union[int, str, None] = None,
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
width_coef: float = 0.8,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
@@ -284,7 +287,7 @@ def plot_metric(
booster: Union[Dict, LGBMModel],
metric: Optional[str] = None,
dataset_names: Optional[List[str]] = None,
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = "Metric during training",
@@ -735,7 +738,7 @@ def create_tree_digraph(

def plot_tree(
booster: Union[Booster, LGBMModel],
ax=None,
ax: "Optional[matplotlib.axes.Axes]" = None,
tree_index: int = 0,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
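
The plotting.py changes combine two standard tricks for annotating against an optional dependency: import matplotlib only under `typing.TYPE_CHECKING` (so the runtime import handling is untouched) and quote the annotation so it is evaluated only by the type checker. This is also why the lint job in .ci/test.sh now installs matplotlib: mypy can only resolve `matplotlib.axes.Axes` if the package is importable. A standalone sketch of the same idea (the function is invented for illustration):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported for annotations only; never executed at runtime.
    import matplotlib


def plot_something(ax: "Optional[matplotlib.axes.Axes]" = None) -> "matplotlib.axes.Axes":
    import matplotlib.pyplot as plt  # real import deferred until the function is called

    if ax is None:
        _, ax = plt.subplots(1, 1)
    return ax
```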
2 changes: 1 addition & 1 deletion python-package/lightgbm/sklearn.py
@@ -478,7 +478,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
**kwargs,
**kwargs: Any,
):
r"""Construct a gradient boosting model.
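
As with `*args`, the `**kwargs: Any` annotation in `refit()` and the scikit-learn estimator constructor describes each value: arbitrary keyword arguments are accepted and their values are typed as `Any` (inside the function, `kwargs` is a `Dict[str, Any]`). A toy sketch, not LightGBM's API:

```python
from typing import Any, Dict


def build_params(objective: str = "regression", **kwargs: Any) -> Dict[str, Any]:
    # kwargs is a Dict[str, Any]; extra keyword arguments pass through untouched.
    params: Dict[str, Any] = {"objective": objective}
    params.update(kwargs)
    return params


build_params(learning_rate=0.1, num_leaves=31)
```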
1 change: 1 addition & 0 deletions python-package/pyproject.toml
@@ -92,6 +92,7 @@ skip_glob = [
]

[tool.mypy]
disallow_untyped_defs = true
exclude = 'build/*|compile/*|docs/*|examples/*|external_libs/*|lightgbm-python/*|tests/*'
ignore_missing_imports = true

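
The `disallow_untyped_defs = true` setting is what drives every signature change in this commit: mypy reports an error for any function or method definition that lacks annotations. A hedged before/after sketch (hypothetical functions, not from the repository):

```python
from typing import List


# Rejected under disallow_untyped_defs:
#   error: Function is missing a type annotation
def scale(values, factor=2.0):
    return [v * factor for v in values]


# Accepted: every parameter and the return value are annotated.
def scale_typed(values: List[float], factor: float = 2.0) -> List[float]:
    return [v * factor for v in values]
```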
