diff --git a/src/datasets/tasks/base.py b/src/datasets/tasks/base.py index df5345912cb..13d3c8ab279 100644 --- a/src/datasets/tasks/base.py +++ b/src/datasets/tasks/base.py @@ -1,10 +1,14 @@ import abc +import dataclasses from dataclasses import dataclass -from typing import ClassVar, Dict +from typing import ClassVar, Dict, Type, TypeVar from ..features import Features +T = TypeVar("T", bound="TaskTemplate") + + @dataclass(frozen=True) class TaskTemplate(abc.ABC): task: ClassVar[str] @@ -18,9 +22,9 @@ def features(self) -> Features: @property @abc.abstractmethod def column_mapping(self) -> Dict[str, str]: - return NotImplemented + raise NotImplementedError @classmethod - @abc.abstractmethod - def from_dict(cls, template_dict: dict) -> "TaskTemplate": - return NotImplemented + def from_dict(cls: Type[T], template_dict: dict) -> T: + field_names = set(f.name for f in dataclasses.fields(cls)) + return cls(**{k: v for k, v in template_dict.items() if k in field_names}) diff --git a/src/datasets/tasks/question_answering.py b/src/datasets/tasks/question_answering.py index d07763dcb0f..145be969633 100644 --- a/src/datasets/tasks/question_answering.py +++ b/src/datasets/tasks/question_answering.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict +from typing import ClassVar, Dict from ..features import Features, Sequence, Value from .base import TaskTemplate @@ -7,9 +7,9 @@ @dataclass(frozen=True) class QuestionAnswering(TaskTemplate): - task = "question-answering" - input_schema = Features({"question": Value("string"), "context": Value("string")}) - label_schema = Features( + task: ClassVar[str] = "question-answering" + input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")}) + label_schema: ClassVar[Features] = Features( { "answers": Sequence( { @@ -23,19 +23,6 @@ class QuestionAnswering(TaskTemplate): context_column: str = "context" answers_column: str = "answers" - def __post_init__(self): - object.__setattr__(self, "question_column", self.question_column) - object.__setattr__(self, "context_column", self.context_column) - object.__setattr__(self, "answers_column", self.answers_column) - @property def column_mapping(self) -> Dict[str, str]: return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"} - - @classmethod - def from_dict(cls, template_dict: dict) -> "QuestionAnswering": - return cls( - question_column=template_dict["question_column"], - context_column=template_dict["context_column"], - answers_column=template_dict["answers_column"], - ) diff --git a/src/datasets/tasks/text_classification.py b/src/datasets/tasks/text_classification.py index 04cc26d78ef..140e7526ec5 100644 --- a/src/datasets/tasks/text_classification.py +++ b/src/datasets/tasks/text_classification.py @@ -1,30 +1,40 @@ from dataclasses import dataclass -from typing import Dict, List +from typing import ClassVar, Dict, List from ..features import ClassLabel, Features, Value from .base import TaskTemplate +class FeaturesWithLazyClassLabel: + def __init__(self, features, label_column="labels"): + assert label_column in features, f"Key '{label_column}' missing in features {features}" + self._features = features + self._label_column = label_column + + def __get__(self, obj, objtype=None): + if obj is None: + return self._features + + assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'" + features = self._features.copy() + features["labels"] = ClassLabel(names=getattr(obj, self._label_column)) + return features + + @dataclass(frozen=True) class TextClassification(TaskTemplate): - task = "text-classification" - input_schema = Features({"text": Value("string")}) - # TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so - # investigate if there's a more elegant approach. - label_schema = Features({"labels": ClassLabel}) + task: ClassVar[str] = "text-classification" + input_schema: ClassVar[Features] = Features({"text": Value("string")}) + # TODO(lewtun): Find a more elegant approach without descriptors. + label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel})) labels: List[str] text_column: str = "text" label_column: str = "labels" def __post_init__(self): - assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique" + assert len(self.labels) == len(set(self.labels)), "Labels must be unique" # Cast labels to tuple to allow hashing - object.__setattr__(self, "labels", tuple(sorted(self.labels))) - object.__setattr__(self, "text_column", self.text_column) - object.__setattr__(self, "label_column", self.label_column) - self.label_schema["labels"] = ClassLabel(names=self.labels) - object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)}) - object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()}) + self.__dict__["labels"] = tuple(sorted(self.labels)) @property def column_mapping(self) -> Dict[str, str]: @@ -33,10 +43,10 @@ def column_mapping(self) -> Dict[str, str]: self.label_column: "labels", } - @classmethod - def from_dict(cls, template_dict: dict) -> "TextClassification": - return cls( - text_column=template_dict["text_column"], - label_column=template_dict["label_column"], - labels=template_dict["labels"], - ) + @property + def label2id(self): + return {label: idx for idx, label in enumerate(self.labels)} + + @property + def id2label(self): + return {idx: label for idx, label in enumerate(self.labels)}