From 9a7272cd4222383a5b932b0083a4cc173fda44e8 Mon Sep 17 00:00:00 2001 From: Freddy Heppell Date: Thu, 22 Dec 2022 16:32:49 +0000 Subject: [PATCH] Raise error if ClassLabel names is not python list (#5359) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Raise error if ClassLabel names is not python list * Change to accepting Sequence for names * Replace `ValueError` with `TypeError` Co-authored-by: Mario Šaško --- src/datasets/features/features.py | 3 +++ tests/features/test_features.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 6fd1353ccc3..045b3dae39b 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -19,6 +19,7 @@ import re import sys from collections.abc import Iterable, Mapping +from collections.abc import Sequence as SequenceABC from dataclasses import InitVar, dataclass, field, fields from functools import reduce, wraps from operator import mul @@ -944,6 +945,8 @@ def __post_init__(self, num_classes, names_file): self.names = [str(i) for i in range(self.num_classes)] else: raise ValueError("Please provide either num_classes, names or names_file.") + elif not isinstance(self.names, SequenceABC): + raise TypeError(f"Please provide names as a list, is {type(self.names)}") # Set self.num_classes if self.num_classes is None: self.num_classes = len(self.names) diff --git a/tests/features/test_features.py b/tests/features/test_features.py index e0803949032..d036e3295c7 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -287,6 +287,8 @@ def test_classlabel_init(tmp_path_factory): classlabel = ClassLabel(names=names, names_file=names_file) with pytest.raises(ValueError): classlabel = ClassLabel() + with pytest.raises(TypeError): + classlabel = ClassLabel(names=np.array(names)) def test_classlabel_str2int():