Skip to content

Commit

Permalink
[python-package] fix mypy errors in Dataset construction (#6106)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Sep 22, 2023
1 parent fe7f8fe commit 7c9a985
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions python-package/lightgbm/basic.py
Expand Up @@ -24,6 +24,13 @@
if TYPE_CHECKING:
from typing import Literal

# typing.TypeGuard was only introduced in Python 3.10
try:
from typing import TypeGuard
except ImportError:
from typing_extensions import TypeGuard


__all__ = [
'Booster',
'Dataset',
Expand Down Expand Up @@ -279,6 +286,20 @@ def _is_1d_list(data: Any) -> bool:
return isinstance(data, list) and (not data or _is_numeric(data[0]))


def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
return (
isinstance(data, list)
and all(isinstance(x, np.ndarray) for x in data)
)


def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
    """Check whether data is a list in which every element is a ``Sequence``.

    The return annotation is a ``TypeGuard`` so that static type checkers
    narrow ``data`` to ``List[Sequence]`` in branches where this returns
    ``True``.

    Parameters
    ----------
    data : Any
        Object to inspect.

    Returns
    -------
    result : bool
        ``True`` if ``data`` is a ``list`` and all of its elements pass
        ``isinstance(x, Sequence)``, ``False`` otherwise. An empty list yields
        ``True`` because ``all()`` of an empty iterable is ``True``; callers
        that need a non-empty list must check the length themselves.
    """
    # NOTE(review): ``Sequence`` is whatever this module has in scope under
    # that name -- confirm which class that is before relying on this check.
    if not isinstance(data, list):
        return False
    return all(isinstance(item, Sequence) for item in data)


def _is_1d_collection(data: Any) -> bool:
"""Check whether data is a 1-D collection."""
return (
Expand Down Expand Up @@ -1918,9 +1939,9 @@ def _lazy_init(
elif isinstance(data, np.ndarray):
self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0:
if all(isinstance(x, np.ndarray) for x in data):
if _is_list_of_numpy_arrays(data):
self.__init_from_list_np2d(data, params_str, ref_dataset)
elif all(isinstance(x, Sequence) for x in data):
elif _is_list_of_sequences(data):
self.__init_from_seqs(data, ref_dataset)
else:
raise TypeError('Data list can only be of ndarray or Sequence')
Expand Down Expand Up @@ -2870,7 +2891,7 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data):
elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
else:
_log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
Expand Down

0 comments on commit 7c9a985

Please sign in to comment.