Improve task api code quality (#2376)
* Improve task api code quality

* Add todo deleted by accident

* Lazy initialize label schema in text classification task
mariosasko committed May 25, 2021
1 parent 74751e3 commit 633ddcd
Showing 3 changed files with 43 additions and 42 deletions.
14 changes: 9 additions & 5 deletions src/datasets/tasks/base.py
@@ -1,10 +1,14 @@
 import abc
+import dataclasses
 from dataclasses import dataclass
-from typing import ClassVar, Dict
+from typing import ClassVar, Dict, Type, TypeVar
 
 from ..features import Features
 
 
+T = TypeVar("T", bound="TaskTemplate")
+
+
 @dataclass(frozen=True)
 class TaskTemplate(abc.ABC):
     task: ClassVar[str]
@@ -18,9 +22,9 @@ def features(self) -> Features:
     @property
     @abc.abstractmethod
     def column_mapping(self) -> Dict[str, str]:
-        return NotImplemented
+        raise NotImplementedError
 
     @classmethod
-    @abc.abstractmethod
-    def from_dict(cls, template_dict: dict) -> "TaskTemplate":
-        return NotImplemented
+    def from_dict(cls: Type[T], template_dict: dict) -> T:
+        field_names = set(f.name for f in dataclasses.fields(cls))
+        return cls(**{k: v for k, v in template_dict.items() if k in field_names})
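
The generic `from_dict` above filters the incoming dict down to the subclass's own dataclass fields, so each task template no longer needs a hand-written deserializer. A minimal, self-contained sketch of that pattern follows; `Template` and `Summarization` here are illustrative stand-ins, not the library's own classes.

```python
# Minimal sketch of the generic from_dict pattern; Template and Summarization
# are illustrative stand-ins, not the library's actual classes.
import dataclasses
from dataclasses import dataclass
from typing import ClassVar, Type, TypeVar

T = TypeVar("T", bound="Template")


@dataclass(frozen=True)
class Template:
    task: ClassVar[str]

    @classmethod
    def from_dict(cls: Type[T], template_dict: dict) -> T:
        # Keep only keys that are real dataclass fields of the concrete subclass.
        field_names = set(f.name for f in dataclasses.fields(cls))
        return cls(**{k: v for k, v in template_dict.items() if k in field_names})


@dataclass(frozen=True)
class Summarization(Template):
    task: ClassVar[str] = "summarization"
    text_column: str = "text"
    summary_column: str = "summary"


# ClassVars such as "task" are not dataclass fields, so unknown keys are
# silently dropped instead of raising a TypeError in cls(**...).
t = Summarization.from_dict({"task": "summarization", "text_column": "document"})
print(t)  # Summarization(text_column='document', summary_column='summary')
```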
21 changes: 4 additions & 17 deletions src/datasets/tasks/question_answering.py
@@ -1,15 +1,15 @@
 from dataclasses import dataclass
-from typing import Dict
+from typing import ClassVar, Dict
 
 from ..features import Features, Sequence, Value
 from .base import TaskTemplate
 
 
 @dataclass(frozen=True)
 class QuestionAnswering(TaskTemplate):
-    task = "question-answering"
-    input_schema = Features({"question": Value("string"), "context": Value("string")})
-    label_schema = Features(
+    task: ClassVar[str] = "question-answering"
+    input_schema: ClassVar[Features] = Features({"question": Value("string"), "context": Value("string")})
+    label_schema: ClassVar[Features] = Features(
         {
             "answers": Sequence(
                 {
@@ -23,19 +23,6 @@ class QuestionAnswering(TaskTemplate):
     context_column: str = "context"
     answers_column: str = "answers"
 
-    def __post_init__(self):
-        object.__setattr__(self, "question_column", self.question_column)
-        object.__setattr__(self, "context_column", self.context_column)
-        object.__setattr__(self, "answers_column", self.answers_column)
-
     @property
     def column_mapping(self) -> Dict[str, str]:
         return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"}
-
-    @classmethod
-    def from_dict(cls, template_dict: dict) -> "QuestionAnswering":
-        return cls(
-            question_column=template_dict["question_column"],
-            context_column=template_dict["context_column"],
-            answers_column=template_dict["answers_column"],
-        )
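
With the schemas declared as `ClassVar`s and the generic `from_dict` inherited from `TaskTemplate`, `QuestionAnswering` is reduced to its three column fields plus `column_mapping`. A short usage sketch, assuming the `datasets` package at this commit (the values shown in comments are indicative):

```python
from datasets.tasks import QuestionAnswering

# Column renaming is now driven entirely by the three dataclass fields;
# no custom __post_init__ or from_dict is needed.
template = QuestionAnswering(question_column="q", context_column="ctx", answers_column="ans")
print(template.column_mapping)  # {'q': 'question', 'ctx': 'context', 'ans': 'answers'}

# The inherited from_dict ignores keys that are not dataclass fields (e.g. "task").
restored = QuestionAnswering.from_dict(
    {"task": "question-answering", "question_column": "q", "context_column": "ctx", "answers_column": "ans"}
)
assert restored == template
```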
50 changes: 30 additions & 20 deletions src/datasets/tasks/text_classification.py
@@ -1,30 +1,40 @@
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import ClassVar, Dict, List
 
 from ..features import ClassLabel, Features, Value
 from .base import TaskTemplate
 
 
+class FeaturesWithLazyClassLabel:
+    def __init__(self, features, label_column="labels"):
+        assert label_column in features, f"Key '{label_column}' missing in features {features}"
+        self._features = features
+        self._label_column = label_column
+
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            return self._features
+
+        assert hasattr(obj, self._label_column), f"Object has no attribute '{self._label_column}'"
+        features = self._features.copy()
+        features["labels"] = ClassLabel(names=getattr(obj, self._label_column))
+        return features
+
+
 @dataclass(frozen=True)
 class TextClassification(TaskTemplate):
-    task = "text-classification"
-    input_schema = Features({"text": Value("string")})
-    # TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
-    # investigate if there's a more elegant approach.
-    label_schema = Features({"labels": ClassLabel})
+    task: ClassVar[str] = "text-classification"
+    input_schema: ClassVar[Features] = Features({"text": Value("string")})
+    # TODO(lewtun): Find a more elegant approach without descriptors.
+    label_schema: ClassVar[Features] = FeaturesWithLazyClassLabel(Features({"labels": ClassLabel}))
     labels: List[str]
     text_column: str = "text"
     label_column: str = "labels"
 
     def __post_init__(self):
-        assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique"
+        assert len(self.labels) == len(set(self.labels)), "Labels must be unique"
         # Cast labels to tuple to allow hashing
-        object.__setattr__(self, "labels", tuple(sorted(self.labels)))
-        object.__setattr__(self, "text_column", self.text_column)
-        object.__setattr__(self, "label_column", self.label_column)
-        self.label_schema["labels"] = ClassLabel(names=self.labels)
-        object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)})
-        object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()})
+        self.__dict__["labels"] = tuple(sorted(self.labels))
 
     @property
     def column_mapping(self) -> Dict[str, str]:
@@ -33,10 +43,10 @@ def column_mapping(self) -> Dict[str, str]:
             self.label_column: "labels",
         }
 
-    @classmethod
-    def from_dict(cls, template_dict: dict) -> "TextClassification":
-        return cls(
-            text_column=template_dict["text_column"],
-            label_column=template_dict["label_column"],
-            labels=template_dict["labels"],
-        )
+    @property
+    def label2id(self):
+        return {label: idx for idx, label in enumerate(self.labels)}
+
+    @property
+    def id2label(self):
+        return {idx: label for idx, label in enumerate(self.labels)}
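
The `FeaturesWithLazyClassLabel` descriptor is what makes the label schema lazy: accessed on the class it returns the generic schema, accessed on an instance it builds a `ClassLabel` from that instance's `labels`, and `label2id`/`id2label` are now derived properties rather than attributes set in `__post_init__`. A usage sketch, again assuming the `datasets` package at this commit (values in comments are indicative):

```python
from datasets.tasks import TextClassification

task = TextClassification(labels=["neg", "pos"], text_column="review", label_column="sentiment")

# Instance access goes through the descriptor's __get__, so the ClassLabel is
# built lazily from this template's (sorted) labels.
print(task.label_schema["labels"].names)  # ['neg', 'pos']
print(task.label2id)                      # {'neg': 0, 'pos': 1}
print(task.id2label)                      # {0: 'neg', 1: 'pos'}
print(task.column_mapping)                # {'review': 'text', 'sentiment': 'labels'}

# Class-level access (obj is None in __get__) returns the unparametrized schema.
print(TextClassification.label_schema)
```

One visible benefit of the descriptor: the old `__post_init__` mutated the class-level `label_schema` in place (`self.label_schema["labels"] = ClassLabel(...)`), so every template instance shared one schema object, whereas the lazy version computes a fresh schema per access.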

1 comment on commit 633ddcd

@github-actions

PyArrow==1.0.0


Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
| :--- | :--- |
| read_batch_formatted_as_numpy after write_array2d | 0.018538 / 0.011353 (0.007185) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.012064 / 0.011008 (0.001056) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.040677 / 0.038508 (0.002168) |
| read_batch_unformated after write_array2d | 0.031913 / 0.023109 (0.008803) |
| read_batch_unformated after write_flattened_sequence | 0.297091 / 0.275898 (0.021193) |
| read_batch_unformated after write_nested_sequence | 0.324904 / 0.323480 (0.001425) |
| read_col_formatted_as_numpy after write_array2d | 0.008963 / 0.007986 (0.000978) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004307 / 0.004328 (-0.000021) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.009531 / 0.004250 (0.005281) |
| read_col_unformated after write_array2d | 0.041076 / 0.037052 (0.004024) |
| read_col_unformated after write_flattened_sequence | 0.294736 / 0.258489 (0.036247) |
| read_col_unformated after write_nested_sequence | 0.328684 / 0.293841 (0.034843) |
| read_formatted_as_numpy after write_array2d | 0.116577 / 0.128546 (-0.011969) |
| read_formatted_as_numpy after write_flattened_sequence | 0.090736 / 0.075646 (0.015090) |
| read_formatted_as_numpy after write_nested_sequence | 0.352642 / 0.419271 (-0.066630) |
| read_unformated after write_array2d | 0.504587 / 0.043533 (0.461055) |
| read_unformated after write_flattened_sequence | 0.299338 / 0.255139 (0.044199) |
| read_unformated after write_nested_sequence | 0.317213 / 0.283200 (0.034013) |
| write_array2d | 3.298744 / 0.141683 (3.157061) |
| write_flattened_sequence | 1.415001 / 1.452155 (-0.037154) |
| write_nested_sequence | 1.463159 / 1.492716 (-0.029557) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
| :--- | :--- |
| get_batch_of_1024_random_rows | 0.005438 / 0.018006 (-0.012568) |
| get_batch_of_1024_rows | 0.420157 / 0.000490 (0.419668) |
| get_first_row | 0.000251 / 0.000200 (0.000051) |
| get_last_row | 0.000036 / 0.000054 (-0.000018) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
| :--- | :--- |
| select | 0.035416 / 0.037411 (-0.001996) |
| shard | 0.020931 / 0.014526 (0.006405) |
| shuffle | 0.023356 / 0.176557 (-0.153200) |
| sort | 0.038483 / 0.737135 (-0.698653) |
| train_test_split | 0.025835 / 0.296338 (-0.270503) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
| :--- | :--- |
| read 5000 | 0.311677 / 0.215209 (0.096468) |
| read 50000 | 3.107640 / 2.077655 (1.029986) |
| read_batch 50000 10 | 1.584836 / 1.504120 (0.080716) |
| read_batch 50000 100 | 1.421730 / 1.541195 (-0.119465) |
| read_batch 50000 1000 | 1.448138 / 1.468490 (-0.020352) |
| read_formatted numpy 5000 | 4.895410 / 4.584777 (0.310633) |
| read_formatted pandas 5000 | 4.640127 / 3.745712 (0.894415) |
| read_formatted tensorflow 5000 | 7.222443 / 5.269862 (1.952581) |
| read_formatted torch 5000 | 6.418152 / 4.565676 (1.852475) |
| read_formatted_batch numpy 5000 10 | 0.547634 / 0.424275 (0.123358) |
| read_formatted_batch numpy 5000 1000 | 0.009293 / 0.007607 (0.001686) |
| shuffled read 5000 | 0.446863 / 0.226044 (0.220819) |
| shuffled read 50000 | 4.508920 / 2.268929 (2.239991) |
| shuffled read_batch 50000 10 | 2.207294 / 55.444624 (-53.237331) |
| shuffled read_batch 50000 100 | 1.860150 / 6.876477 (-5.016326) |
| shuffled read_batch 50000 1000 | 1.948853 / 2.142072 (-0.193220) |
| shuffled read_formatted numpy 5000 | 5.133418 / 4.805227 (0.328191) |
| shuffled read_formatted_batch numpy 5000 10 | 3.416723 / 6.500664 (-3.083941) |
| shuffled read_formatted_batch numpy 5000 1000 | 5.600067 / 0.075469 (5.524598) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
| :--- | :--- |
| filter | 9.361021 / 1.841788 (7.519233) |
| map fast-tokenizer batched | 11.006504 / 8.074308 (2.932196) |
| map identity | 26.200079 / 10.191392 (16.008687) |
| map identity batched | 0.652576 / 0.680424 (-0.027847) |
| map no-op batched | 0.465015 / 0.534201 (-0.069186) |
| map no-op batched numpy | 0.558341 / 0.579283 (-0.020942) |
| map no-op batched pandas | 0.438242 / 0.434364 (0.003879) |
| map no-op batched pytorch | 0.514551 / 0.540337 (-0.025786) |
| map no-op batched tensorflow | 1.314226 / 1.386936 (-0.072710) |
PyArrow==latest

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
| :--- | :--- |
| read_batch_formatted_as_numpy after write_array2d | 0.018288 / 0.011353 (0.006936) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.011857 / 0.011008 (0.000849) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.040681 / 0.038508 (0.002173) |
| read_batch_unformated after write_array2d | 0.031684 / 0.023109 (0.008575) |
| read_batch_unformated after write_flattened_sequence | 0.251069 / 0.275898 (-0.024829) |
| read_batch_unformated after write_nested_sequence | 0.277696 / 0.323480 (-0.045784) |
| read_col_formatted_as_numpy after write_array2d | 0.009447 / 0.007986 (0.001461) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004232 / 0.004328 (-0.000096) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.009691 / 0.004250 (0.005440) |
| read_col_unformated after write_array2d | 0.046760 / 0.037052 (0.009707) |
| read_col_unformated after write_flattened_sequence | 0.250153 / 0.258489 (-0.008336) |
| read_col_unformated after write_nested_sequence | 0.281852 / 0.293841 (-0.011989) |
| read_formatted_as_numpy after write_array2d | 0.114788 / 0.128546 (-0.013758) |
| read_formatted_as_numpy after write_flattened_sequence | 0.088393 / 0.075646 (0.012747) |
| read_formatted_as_numpy after write_nested_sequence | 0.347731 / 0.419271 (-0.071540) |
| read_unformated after write_array2d | 0.330606 / 0.043533 (0.287074) |
| read_unformated after write_flattened_sequence | 0.250340 / 0.255139 (-0.004799) |
| read_unformated after write_nested_sequence | 0.279151 / 0.283200 (-0.004048) |
| write_array2d | 1.392925 / 0.141683 (1.251242) |
| write_flattened_sequence | 1.433390 / 1.452155 (-0.018765) |
| write_nested_sequence | 1.468934 / 1.492716 (-0.023782) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
| :--- | :--- |
| get_batch_of_1024_random_rows | 0.008028 / 0.018006 (-0.009978) |
| get_batch_of_1024_rows | 0.426851 / 0.000490 (0.426361) |
| get_first_row | 0.004157 / 0.000200 (0.003957) |
| get_last_row | 0.000152 / 0.000054 (0.000098) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
| :--- | :--- |
| select | 0.031811 / 0.037411 (-0.005600) |
| shard | 0.020682 / 0.014526 (0.006156) |
| shuffle | 0.023724 / 0.176557 (-0.152832) |
| sort | 0.038104 / 0.737135 (-0.699031) |
| train_test_split | 0.024197 / 0.296338 (-0.272141) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
| :--- | :--- |
| read 5000 | 0.305415 / 0.215209 (0.090206) |
| read 50000 | 3.040730 / 2.077655 (0.963075) |
| read_batch 50000 10 | 1.498402 / 1.504120 (-0.005718) |
| read_batch 50000 100 | 1.333278 / 1.541195 (-0.207916) |
| read_batch 50000 1000 | 1.350236 / 1.468490 (-0.118254) |
| read_formatted numpy 5000 | 5.148586 / 4.584777 (0.563809) |
| read_formatted pandas 5000 | 4.389965 / 3.745712 (0.644253) |
| read_formatted tensorflow 5000 | 6.477180 / 5.269862 (1.207318) |
| read_formatted torch 5000 | 5.807956 / 4.565676 (1.242280) |
| read_formatted_batch numpy 5000 10 | 0.499551 / 0.424275 (0.075276) |
| read_formatted_batch numpy 5000 1000 | 0.009147 / 0.007607 (0.001540) |
| shuffled read 5000 | 0.451514 / 0.226044 (0.225469) |
| shuffled read 50000 | 4.497385 / 2.268929 (2.228457) |
| shuffled read_batch 50000 10 | 2.146803 / 55.444624 (-53.297822) |
| shuffled read_batch 50000 100 | 1.809569 / 6.876477 (-5.066908) |
| shuffled read_batch 50000 1000 | 1.900331 / 2.142072 (-0.241741) |
| shuffled read_formatted numpy 5000 | 5.329532 / 4.805227 (0.524305) |
| shuffled read_formatted_batch numpy 5000 10 | 3.368877 / 6.500664 (-3.131787) |
| shuffled read_formatted_batch numpy 5000 1000 | 3.815016 / 0.075469 (3.739547) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
| :--- | :--- |
| filter | 8.123219 / 1.841788 (6.281431) |
| map fast-tokenizer batched | 11.067707 / 8.074308 (2.993399) |
| map identity | 25.316428 / 10.191392 (15.125036) |
| map identity batched | 0.738316 / 0.680424 (0.057892) |
| map no-op batched | 0.495172 / 0.534201 (-0.039029) |
| map no-op batched numpy | 0.643747 / 0.579283 (0.064464) |
| map no-op batched pandas | 0.478725 / 0.434364 (0.044361) |
| map no-op batched pytorch | 0.547867 / 0.540337 (0.007529) |
| map no-op batched tensorflow | 1.319648 / 1.386936 (-0.067288) |
