Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] fix mypy errors in Dataset construction #6106

Merged
merged 4 commits into from Sep 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 24 additions & 3 deletions python-package/lightgbm/basic.py
Expand Up @@ -24,6 +24,13 @@
if TYPE_CHECKING:
from typing import Literal

# typing.TypeGuard was only introduced in Python 3.10
try:
from typing import TypeGuard
except ImportError:
from typing_extensions import TypeGuard


__all__ = [
'Booster',
'Dataset',
Expand Down Expand Up @@ -279,6 +286,20 @@ def _is_1d_list(data: Any) -> bool:
return isinstance(data, list) and (not data or _is_numeric(data[0]))


def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
    """Check whether data is a list whose elements are all numpy arrays.

    Written as a ``TypeGuard`` so that static type checkers narrow ``data``
    to ``List[np.ndarray]`` in the branch where this returns ``True``.

    Parameters
    ----------
    data : Any
        Object to inspect.

    Returns
    -------
    result : bool
        ``True`` if ``data`` is a ``list`` and every element is an
        ``np.ndarray``, ``False`` otherwise.
    """
    # NOTE: an empty list yields True (``all`` over no elements), so callers
    # that need at least one array must check the length separately.
    return (
        isinstance(data, list)
        and all(isinstance(x, np.ndarray) for x in data)
    )


def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
    """Check whether data is a list whose elements are all ``Sequence`` objects.

    ``Sequence`` here is whatever that name is bound to in this module.
    Written as a ``TypeGuard`` so that static type checkers narrow ``data``
    to ``List[Sequence]`` in the branch where this returns ``True``.

    Parameters
    ----------
    data : Any
        Object to inspect.

    Returns
    -------
    result : bool
        ``True`` if ``data`` is a ``list`` and every element is an instance
        of ``Sequence``, ``False`` otherwise.
    """
    # NOTE: an empty list yields True (``all`` over no elements), so callers
    # that need at least one sequence must check the length separately.
    return (
        isinstance(data, list)
        and all(isinstance(x, Sequence) for x in data)
    )


def _is_1d_collection(data: Any) -> bool:
"""Check whether data is a 1-D collection."""
return (
Expand Down Expand Up @@ -1918,9 +1939,9 @@ def _lazy_init(
elif isinstance(data, np.ndarray):
self.__init_from_np2d(data, params_str, ref_dataset)
elif isinstance(data, list) and len(data) > 0:
if all(isinstance(x, np.ndarray) for x in data):
if _is_list_of_numpy_arrays(data):
self.__init_from_list_np2d(data, params_str, ref_dataset)
elif all(isinstance(x, Sequence) for x in data):
elif _is_list_of_sequences(data):
self.__init_from_seqs(data, ref_dataset)
else:
raise TypeError('Data list can only be of ndarray or Sequence')
Expand Down Expand Up @@ -2870,7 +2891,7 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
self.data = self.data[self.used_indices, :]
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data):
elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
else:
_log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
Expand Down