From b27d81ea411d04d8d071d4d4e75c19ffa15c5795 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 4 Mar 2024 12:06:08 -0600 Subject: [PATCH] [ci] [python-package] check for untyped definitions with mypy (#6339) --- .ci/test.sh | 1 + python-package/lightgbm/basic.py | 22 ++++++++++---------- python-package/lightgbm/compat.py | 32 ++++++++++++++--------------- python-package/lightgbm/plotting.py | 13 +++++++----- python-package/lightgbm/sklearn.py | 2 +- python-package/pyproject.toml | 1 + 6 files changed, 38 insertions(+), 33 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 1df7d53205f..79b2748e41f 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -74,6 +74,7 @@ if [[ $TASK == "lint" ]]; then ${CONDA_PYTHON_REQUIREMENT} \ cmakelint \ cpplint \ + 'matplotlib>=3.8.3' \ mypy \ 'pre-commit>=3.6.0' \ 'pyarrow>=14.0' \ diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index f78d8c35216..bb7dfb3b73e 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -13,7 +13,7 @@ from os.path import getsize from pathlib import Path from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union import numpy as np import scipy.sparse @@ -537,13 +537,13 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str: class _TempFile: """Proxy class to workaround errors on Windows.""" - def __enter__(self): + def __enter__(self) -> "_TempFile": with NamedTemporaryFile(prefix="lightgbm_tmp_", delete=True) as f: self.name = f.name self.path = Path(self.name) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.path.is_file(): self.path.unlink() @@ -595,7 +595,7 @@ def _get_all_param_aliases() -> Dict[str, List[str]]: ) @classmethod - def get(cls, *args) -> Set[str]: + def get(cls, *args: str) -> Set[str]: if cls.aliases is None: cls.aliases = cls._get_all_param_aliases() ret = set() @@ -610,7 +610,7 @@ def get_sorted(cls, name: str) -> List[str]: return cls.aliases.get(name, [name]) @classmethod - def get_by_alias(cls, *args) -> Set[str]: + def get_by_alias(cls, *args: str) -> Set[str]: if cls.aliases is None: cls.aliases = cls._get_all_param_aliases() ret = set(args) @@ -1563,7 +1563,7 @@ def __inner_predict_sparse_csc( start_iteration: int, num_iteration: int, predict_type: int, - ): + ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]: ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr) ptr_data, type_ptr_data, _ = _c_float_array(csc.data) csc_indices = csc.indices.astype(np.int32, copy=False) @@ -1813,7 +1813,7 @@ def __init__( self._need_slice = True self._predictor: Optional[_InnerPredictor] = None self.pandas_categorical: Optional[List[List]] = None - self._params_back_up = None + self._params_back_up: Optional[Dict[str, Any]] = None self.version = 0 self._start_row = 0 # Used when pushing rows one by one. @@ -2195,7 +2195,7 @@ def _lazy_init( return self.set_feature_name(feature_name) @staticmethod - def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]): + def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]) -> Iterator[np.ndarray]: offset = 0 seq_id = 0 seq = seqs[seq_id] @@ -2697,7 +2697,7 @@ def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset": return self params = deepcopy(params) - def update(): + def update() -> None: if not self.params: self.params = params else: @@ -3704,7 +3704,7 @@ def __del__(self) -> None: def __copy__(self) -> "Booster": return self.__deepcopy__(None) - def __deepcopy__(self, _) -> "Booster": + def __deepcopy__(self, *args: Any, **kwargs: Any) -> "Booster": model_str = self.model_to_string(num_iteration=-1) return Booster(model_str=model_str) @@ -4757,7 +4757,7 @@ def refit( dataset_params: Optional[Dict[str, Any]] = None, free_raw_data: bool = True, validate_features: bool = False, - **kwargs, + **kwargs: Any, ) -> "Booster": """Refit the existing Booster by new data. diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 086c6a199ff..965dd332525 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -1,7 +1,7 @@ # coding: utf-8 """Compatibility library.""" -from typing import List +from typing import Any, List """pandas""" try: @@ -20,19 +20,19 @@ class pd_Series: # type: ignore """Dummy class for pandas.Series.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class pd_DataFrame: # type: ignore """Dummy class for pandas.DataFrame.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class pd_CategoricalDtype: # type: ignore """Dummy class for pandas.CategoricalDtype.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass concat = None @@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs): class np_random_Generator: # type: ignore """Dummy class for np.random.Generator.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass @@ -80,7 +80,7 @@ def __init__(self, *args, **kwargs): class dt_DataTable: # type: ignore """Dummy class for datatable.DataTable.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass @@ -104,7 +104,7 @@ def __init__(self, *args, **kwargs): from sklearn.utils.validation import check_consistent_length # dummy function to support older version of scikit-learn - def _check_sample_weight(sample_weight, X, dtype=None): + def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any: check_consistent_length(sample_weight, X) return sample_weight @@ -176,31 +176,31 @@ class _LGBMRegressorBase: # type: ignore class Client: # type: ignore """Dummy class for dask.distributed.Client.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class Future: # type: ignore """Dummy class for dask.distributed.Future.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class dask_Array: # type: ignore """Dummy class for dask.array.Array.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class dask_DataFrame: # type: ignore """Dummy class for dask.dataframe.DataFrame.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class dask_Series: # type: ignore """Dummy class for dask.dataframe.Series.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass @@ -222,19 +222,19 @@ def __init__(self, *args, **kwargs): class pa_Array: # type: ignore """Dummy class for pa.Array.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class pa_ChunkedArray: # type: ignore """Dummy class for pa.ChunkedArray.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class pa_Table: # type: ignore """Dummy class for pa.Table.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class arrow_cffi: # type: ignore @@ -245,7 +245,7 @@ class arrow_cffi: # type: ignore cast = None new = None - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): pass class pa_compute: # type: ignore diff --git a/python-package/lightgbm/plotting.py b/python-package/lightgbm/plotting.py index 76f4e0deef4..9bcc1b928ff 100644 --- a/python-package/lightgbm/plotting.py +++ b/python-package/lightgbm/plotting.py @@ -3,7 +3,7 @@ import math from copy import deepcopy from io import BytesIO -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -19,6 +19,9 @@ "plot_tree", ] +if TYPE_CHECKING: + import matplotlib + def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None: """Check object is not tuple or does not have 2 elements.""" @@ -32,7 +35,7 @@ def _float2str(value: float, precision: Optional[int]) -> str: def plot_importance( booster: Union[Booster, LGBMModel], - ax=None, + ax: "Optional[matplotlib.axes.Axes]" = None, height: float = 0.2, xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, @@ -168,7 +171,7 @@ def plot_split_value_histogram( booster: Union[Booster, LGBMModel], feature: Union[int, str], bins: Union[int, str, None] = None, - ax=None, + ax: "Optional[matplotlib.axes.Axes]" = None, width_coef: float = 0.8, xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, @@ -284,7 +287,7 @@ def plot_metric( booster: Union[Dict, LGBMModel], metric: Optional[str] = None, dataset_names: Optional[List[str]] = None, - ax=None, + ax: "Optional[matplotlib.axes.Axes]" = None, xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, title: Optional[str] = "Metric during training", @@ -735,7 +738,7 @@ def create_tree_digraph( def plot_tree( booster: Union[Booster, LGBMModel], - ax=None, + ax: "Optional[matplotlib.axes.Axes]" = None, tree_index: int = 0, figsize: Optional[Tuple[float, float]] = None, dpi: Optional[int] = None, diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 5e0d51f4546..0b4c9993365 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -478,7 +478,7 @@ def __init__( random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None, n_jobs: Optional[int] = None, importance_type: str = "split", - **kwargs, + **kwargs: Any, ): r"""Construct a gradient boosting model. diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 5e75edb9005..0f07c897853 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -92,6 +92,7 @@ skip_glob = [ ] [tool.mypy] +disallow_untyped_defs = true exclude = 'build/*|compile/*|docs/*|examples/*|external_libs/*|lightgbm-python/*|tests/*' ignore_missing_imports = true