diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 6fd1353ccc3..045b3dae39b 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -19,6 +19,7 @@ import re import sys from collections.abc import Iterable, Mapping +from collections.abc import Sequence as SequenceABC from dataclasses import InitVar, dataclass, field, fields from functools import reduce, wraps from operator import mul @@ -944,6 +945,8 @@ def __post_init__(self, num_classes, names_file): self.names = [str(i) for i in range(self.num_classes)] else: raise ValueError("Please provide either num_classes, names or names_file.") + elif not isinstance(self.names, SequenceABC): + raise TypeError(f"Please provide names as a list, is {type(self.names)}") # Set self.num_classes if self.num_classes is None: self.num_classes = len(self.names) diff --git a/tests/features/test_features.py b/tests/features/test_features.py index e0803949032..d036e3295c7 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -287,6 +287,8 @@ def test_classlabel_init(tmp_path_factory): classlabel = ClassLabel(names=names, names_file=names_file) with pytest.raises(ValueError): classlabel = ClassLabel() + with pytest.raises(TypeError): + classlabel = ClassLabel(names=np.array(names)) def test_classlabel_str2int():