Task casting for text classification & question answering #2255

Merged: 66 commits, merged on May 18, 2021.

Commits (66)
9a19cf2  WIP: task templates in datasets info (SBrandeis, Apr 23, 2021)
7ab394a  Add rename_columnS method (SBrandeis, Apr 23, 2021)
e5cef8d  WIP: Add prepare_for-task method (SBrandeis, Apr 23, 2021)
4564eb9  WIP: usage example (SBrandeis, Apr 23, 2021)
29650f8  Code quality (SBrandeis, Apr 23, 2021)
1516ae6  Merge branch 'master' into sbrandeis/task_casting (SBrandeis, Apr 30, 2021)
2157bf5  wip: Add text_classification pipeline (SBrandeis, Apr 30, 2021)
9302893  Decorate TaskTemplate with dataclass (lewtun, May 4, 2021)
02bcfa2  Add text classification template to emotion dataset_infos.json (lewtun, May 6, 2021)
531cf94  Fix key name in task_template_from_dict (lewtun, May 6, 2021)
d3bccca  Refactor TextClassification init and fix from_dict keys (lewtun, May 6, 2021)
6dfd7f2  Remove unused import (lewtun, May 7, 2021)
b0899b1  Add task specification to load_dataset (lewtun, May 7, 2021)
3cf039d  Rename emotion columns for task template testing (lewtun, May 7, 2021)
b2a02c5  Move task casting to builder (lewtun, May 7, 2021)
28e6ab1  Refactor task loading with prepare_for_task method (lewtun, May 7, 2021)
7f0f683  Add docstring and error message to prepare_for_task (lewtun, May 7, 2021)
3a30b62  Merge branch 'master' into sbrandeis/task_casting (lewtun, May 7, 2021)
2ec62a2  Add TODO in docstring (lewtun, May 7, 2021)
5df3b65  Update datas_infos.json for emotion dataset (lewtun, May 7, 2021)
35e6110  Call task preparation in load_dataset instead of builder (lewtun, May 10, 2021)
e21d728  Extend prepare_for_task to handle nested columns (lewtun, May 10, 2021)
8d5427d  Fix dataclass inheritance for QuestionAnswering task template (lewtun, May 10, 2021)
c344557  Unflatten column renaming (lewtun, May 10, 2021)
6220e7f  Replace nested question-answering columns with outer answers column (lewtun, May 10, 2021)
388c5d4  Revert emotion dataset (lewtun, May 10, 2021)
d759f65  Merge branch 'master' into sbrandeis/task_casting (lewtun, May 10, 2021)
6b851aa  Fix imports (lewtun, May 10, 2021)
95d395e  Revert emotion dataset for real! (lewtun, May 10, 2021)
6d74541  Remove label_mapping from TextClassification template (lewtun, May 10, 2021)
0f6513a  Rename TextClassification task labels to "labels" for Trainer compati… (lewtun, May 10, 2021)
d3cdf78  Add unit tests for text classification & question answering tasks (lewtun, May 10, 2021)
34aa06e  Fix style (lewtun, May 10, 2021)
fe61d2c  Remove setup from question answering test (lewtun, May 10, 2021)
46beea9  Clean up dataset memory after test (lewtun, May 10, 2021)
c83e501  Rename task names to use hyphen instead of underscore (lewtun, May 11, 2021)
db52ce7  Integrate review comments (lewtun, May 12, 2021)
3ad17f3  Remove task template from SQuAD (lewtun, May 12, 2021)
14e6e5b  Add hashing of task templates to enable set intersections (lewtun, May 12, 2021)
316cdb0  Filter Nones from task templates intersection (lewtun, May 12, 2021)
8e2299a  Fix style & quality (lewtun, May 12, 2021)
40de679  Clean up docstring (lewtun, May 12, 2021)
b67ed1f  Use context manager for in_memory dataset test (lewtun, May 12, 2021)
944b22f  Handle cases where task templates are none in merge (lewtun, May 12, 2021)
1cb8c93  Add more context managers! (lewtun, May 12, 2021)
6858125  Add tests for task template concatenation (lewtun, May 14, 2021)
6c6b3c9  Mark task name and schemas a class variables (lewtun, May 14, 2021)
4d40558  Replace custom hash functions with frozen dataclass decorators (lewtun, May 14, 2021)
7076902  Replace init with post_init in task templates (lewtun, May 14, 2021)
81674cc  Add supported tasks to docstring (lewtun, May 14, 2021)
3e39872  Add supported tasks to prepare_for_task docstring (lewtun, May 14, 2021)
7abd2f0  Update dataset_dict doctstring (lewtun, May 14, 2021)
441a36e  Add TODO and tweak docstrings (lewtun, May 14, 2021)
894a509  Merge branch 'master' into sbrandeis/task_casting (lewtun, May 14, 2021)
56635e7  Allow TaskTemplate objects to be passed to prepare_for_task (lewtun, May 17, 2021)
1274aa5  Add tests for invalid templates (lewtun, May 17, 2021)
5e48403  Update docstring for prepare_for_task (lewtun, May 17, 2021)
5a6021a  Merge branch 'master' into sbrandeis/task_casting (lewtun, May 17, 2021)
a868d35  Fix docstrings and add to HTML build (lewtun, May 17, 2021)
99200c7  Remove redundant context manager in tests (lewtun, May 17, 2021)
f895389  Add TaskTemplate to typing of load_dataset (lewtun, May 18, 2021)
290b583  Remove default value for label_schema in TextClassificationTemplate (lewtun, May 18, 2021)
3578dbd  Revert default value for label_schema and add TODO (lewtun, May 18, 2021)
7a0de24  Add and improve docstrings for task templates (lewtun, May 18, 2021)
6bf7970  Add list of task templates to docs (lewtun, May 18, 2021)
2dc8d26  Migrate task templates to dedicated section of docs (lewtun, May 18, 2021)
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -78,3 +78,4 @@ The documentation is organized in five parts:
package_reference/builder_classes
package_reference/table_classes
package_reference/logging_methods
package_reference/task_templates
6 changes: 3 additions & 3 deletions docs/source/package_reference/main_classes.rst
@@ -32,7 +32,7 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach
info, split, builder_name, citation, config_name, dataset_size,
description, download_checksums, download_size, features, homepage,
license, size_in_bytes, supervised_keys, version,
from_csv, from_json, from_text,
from_csv, from_json, from_text, prepare_for_task,

.. autofunction:: datasets.concatenate_datasets

@@ -54,7 +54,7 @@ It also has dataset transform methods like map or filter, to process all the spl
flatten_, cast_, remove_columns_, rename_column_,
flatten, cast, remove_columns, rename_column, class_encode_column,
save_to_disk, load_from_disk,
from_csv, from_json, from_text,
from_csv, from_json, from_text, prepare_for_task


``Features``
@@ -113,4 +113,4 @@ The base class ``Metric`` implements a Metric backed by one or several :class:`d

.. autofunction:: datasets.filesystems.extract_path_from_uri

.. autofunction:: datasets.filesystems.is_remote_filesystem
.. autofunction:: datasets.filesystems.is_remote_filesystem
8 changes: 8 additions & 0 deletions docs/source/package_reference/task_templates.rst
@@ -0,0 +1,8 @@
Task templates
----------------------------------------------------

The tasks supported by :class:`datasets.Dataset.prepare_for_task` and :class:`datasets.DatasetDict.prepare_for_task`.

.. automodule:: datasets.tasks
:members:
:exclude-members: TaskTemplate
39 changes: 39 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -57,6 +57,7 @@
from .search import IndexableMixin
from .splits import NamedSplit
from .table import ConcatenationTable, InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
from .tasks import TaskTemplate
from .utils import map_nested
from .utils.deprecation_utils import deprecated
from .utils.file_utils import estimate_dataset_size
@@ -1384,6 +1385,44 @@ def with_transform(
dataset.set_transform(transform=transform, columns=columns, output_all_columns=output_all_columns)
return dataset

def prepare_for_task(self, task: Union[str, TaskTemplate]) -> "Dataset":
"""Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.

Args:
task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:

- :obj:`"text-classification"`
- :obj:`"question-answering"`
Review comment (Member): question-answering exists in two forms: abstractive and extractive question answering. We can keep a generic question-answering, but then it will probably mean a different input/output schema for each (abstractive will have text for both, while extractive can use span indices as well as text). Or we can also propose to use abstractive-question-answering and extractive-question-answering, for instance.

Review comment (Member): Maybe we could have question-answering-abstractive and question-answering-extractive, if somehow we can use the generic question-answering for completion or search in the future (detail).

Review comment (Member): Actually I see that people organize more in terms of general tasks and sub-tasks, for instance on Papers with Code: https://paperswithcode.com/area/natural-language-processing and on NLP-progress: https://github.com/sebastianruder/NLP-progress/blob/master/english/question_answering.md#squad. Probably the best is to align with one of these in terms of naming; Papers with Code is probably the most active and maintained, and we work with them as well.

Review comment (Member): Our idea was to start by following the pipeline taxonomy in transformers, where question-answering is indeed restricted to the extractive case. But you're 100% correct that we should already start considering the abstractive case, or grouping in terms of the sub-domains that PwC adopts. Since our focus right now is on getting text-classification integrated and tested with AutoNLP, I've opened an issue to tackle this next: #2371

Review comment (Member): PwC uses "Question Answering" (meaning extractive) vs "Generative Question Answering" (includes abstractive), which encapsulates most of the QA tasks except open/closed-domain QA and knowledge-base QA (those require no context, since the knowledge is not part of the query). I'm fine with these names, or simply extractive and abstractive.


If :obj:`TaskTemplate`, must be one of the task templates in :py:mod:`datasets.tasks`.
"""
# TODO(lewtun): Add support for casting nested features like answers.text and answers.answer_start in SQuAD
if isinstance(task, str):
tasks = [template.task for template in (self.info.task_templates or [])]
compatible_templates = [template for template in (self.info.task_templates or []) if template.task == task]
if not compatible_templates:
raise ValueError(f"Task {task} is not compatible with this dataset! Available tasks: {tasks}")

if len(compatible_templates) > 1:
raise ValueError(
f"Expected 1 task template but found {len(compatible_templates)}! Please ensure that `datasets.DatasetInfo.task_templates` contains a unique set of task types."
)
template = compatible_templates[0]
elif isinstance(task, TaskTemplate):
template = task
else:
raise ValueError(
f"Expected a `str` or `datasets.tasks.TaskTemplate` object but got task {task} with type {type(task)}."
)
column_mapping = template.column_mapping
columns_to_drop = [column for column in self.column_names if column not in column_mapping]
dataset = self.remove_columns(columns_to_drop)
dataset = dataset.rename_columns(column_mapping)
dataset = dataset.cast(features=template.features)
return dataset

def _getitem(
self,
key: Union[int, slice, str],
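For illustration only, here is a minimal sketch of the new method in action on a toy in-memory dataset; the "review"/"sentiment" column names and the two labels are invented for the example and are not part of this diff.

from datasets import Dataset
from datasets.tasks import TextClassification

# Toy dataset with non-standard column names (hypothetical)
ds = Dataset.from_dict({"review": ["great film", "terrible film"], "sentiment": [1, 0]})

# Passing a TaskTemplate object directly, so no entry in DatasetInfo.task_templates is required
template = TextClassification(labels=["neg", "pos"], text_column="review", label_column="sentiment")
ds = ds.prepare_for_task(template)

print(ds.column_names)        # ['text', 'labels'] after renaming and casting
print(ds.features["labels"])  # ClassLabel built from the template's labels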
17 changes: 17 additions & 0 deletions src/datasets/dataset_dict.py
@@ -13,6 +13,7 @@
from .features import Features
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .table import Table
from .tasks import TaskTemplate
from .utils.deprecation_utils import deprecated
from .utils.typing import PathLike

@@ -790,3 +791,19 @@ def from_text(
return TextDatasetReader(
path_or_paths, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
).read()

def prepare_for_task(self, task: Union[str, TaskTemplate]):
"""Prepare a dataset for the given task by casting the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.

Casts :attr:`datasets.DatasetInfo.features` according to a task-specific schema.

Args:
task (:obj:`Union[str, TaskTemplate]`): The task to prepare the dataset for during training and evaluation. If :obj:`str`, supported tasks include:

- :obj:`"text-classification"`
- :obj:`"question-answering"`

If :obj:`TaskTemplate`, must be one of the task templates in :py:mod:`datasets.tasks`.
"""
self._check_values_type()
return DatasetDict({k: dataset.prepare_for_task(task=task) for k, dataset in self.items()})
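The DatasetDict variant simply applies the same casting split by split. A short sketch, reusing the made-up "review"/"sentiment" columns from the Dataset example above:

from datasets import Dataset, DatasetDict
from datasets.tasks import TextClassification

splits = DatasetDict(
    {
        "train": Dataset.from_dict({"review": ["good"], "sentiment": [1]}),
        "test": Dataset.from_dict({"review": ["bad"], "sentiment": [0]}),
    }
)
template = TextClassification(labels=["neg", "pos"], text_column="review", label_column="sentiment")
splits = splits.prepare_for_task(template)  # applied to every split
print({name: ds.column_names for name, ds in splits.items()})  # both splits end up with ['text', 'labels']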
27 changes: 27 additions & 0 deletions src/datasets/info.py
@@ -39,6 +39,7 @@
from . import config
from .features import Features, Value
from .splits import SplitDict
from .tasks import TaskTemplate, task_template_from_dict
from .utils import Version
from .utils.logging import get_logger

@@ -108,6 +109,7 @@ class DatasetInfo:
post_processing_size (int, optional):
dataset_size (int, optional):
size_in_bytes (int, optional):
task_templates (List[TaskTemplate], optional):
"""

# Set in the dataset scripts
@@ -118,6 +120,7 @@
features: Optional[Features] = None
post_processed: Optional[PostProcessedInfo] = None
supervised_keys: Optional[SupervisedKeysData] = None
task_templates: Optional[List[TaskTemplate]] = None

# Set later by the builder
builder_name: Optional[str] = None
@@ -150,6 +153,19 @@ def __post_init__(self):
else:
self.supervised_keys = SupervisedKeysData(**self.supervised_keys)

if self.task_templates is not None:
if isinstance(self.task_templates, (list, tuple)):
templates = [
template if isinstance(template, TaskTemplate) else task_template_from_dict(template)
for template in self.task_templates
]
self.task_templates = [template for template in templates if template is not None]
elif isinstance(self.task_templates, TaskTemplate):
self.task_templates = [self.task_templates]
else:
template = task_template_from_dict(self.task_templates)
self.task_templates = [template] if template is not None else []

def _license_path(self, dataset_info_dir):
return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)

@@ -188,6 +204,16 @@ def unique(values):
license = "\n\n".join(unique(info.license for info in dataset_infos))
features = None
supervised_keys = None
task_templates = None

# Find common task templates across all dataset infos
all_task_templates = [info.task_templates for info in dataset_infos if info.task_templates is not None]
if len(all_task_templates) > 1:
task_templates = list(set(all_task_templates[0]).intersection(*all_task_templates[1:]))
elif len(all_task_templates):
task_templates = list(set(all_task_templates[0]))
# If no common task templates found, replace empty list with None
task_templates = task_templates if task_templates else None

return cls(
description=description,
@@ -196,6 +222,7 @@
license=license,
features=features,
supervised_keys=supervised_keys,
task_templates=task_templates,
)

@classmethod
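To make the intersection logic above concrete, here is a small sketch (not taken from this PR's test suite) of how merging two DatasetInfo objects keeps only the templates they have in common; the labels are invented:

from datasets import DatasetInfo
from datasets.tasks import QuestionAnswering, TextClassification

# Hashing via the frozen dataclasses is what makes the set intersection possible
tc = TextClassification(labels=["neg", "pos"])
qa = QuestionAnswering()

info1 = DatasetInfo(task_templates=[tc, qa])
info2 = DatasetInfo(task_templates=[tc])
merged = DatasetInfo.from_merge([info1, info2])
print(merged.task_templates)  # only the shared text-classification template survives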
6 changes: 6 additions & 0 deletions src/datasets/load.py
@@ -38,6 +38,7 @@
from .metric import Metric
from .packaged_modules import _PACKAGED_DATASETS_MODULES, hash_python_lines
from .splits import Split
from .tasks import TaskTemplate
from .utils.download_manager import GenerateMode
from .utils.file_utils import (
DownloadConfig,
@@ -635,6 +636,7 @@ def load_dataset(
save_infos: bool = False,
script_version: Optional[Union[str, Version]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
task: Optional[Union[str, TaskTemplate]] = None,
**config_kwargs,
) -> Union[DatasetDict, Dataset]:
"""Load a dataset.
@@ -694,6 +696,7 @@
You can specify a different version that the default "main" by using a commit sha or a git tag of the dataset repository.
use_auth_token (``str`` or ``bool``, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
If True, will get token from `"~/.huggingface"`.
task (``str``): The task to prepare the dataset for during training and evaluation. Casts the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.
**config_kwargs: Keyword arguments to be passed to the :class:`BuilderConfig` and used in the :class:`DatasetBuilder`.

Returns:
@@ -752,6 +755,9 @@
keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size)
)
ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
# Rename and cast features to match task schema
if task is not None:
ds = ds.prepare_for_task(task)
if save_infos:
builder_instance._save_infos()

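For readers skimming the diff, the net effect of the new argument is a one-liner; the dataset name below is a placeholder and assumes its DatasetInfo ships a matching template (as the emotion dataset does in intermediate commits of this PR):

from datasets import load_dataset

# Equivalent to load_dataset(...) followed by .prepare_for_task("text-classification")
ds = load_dataset("emotion", split="train", task="text-classification")
print(ds.column_names)  # ['text', 'labels'] once the template has been applied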
22 changes: 22 additions & 0 deletions src/datasets/tasks/__init__.py
@@ -0,0 +1,22 @@
from typing import Optional

from .base import TaskTemplate
from .question_answering import QuestionAnswering
from .text_classification import TextClassification


__all__ = ["TaskTemplate", "QuestionAnswering", "TextClassification"]


NAME2TEMPLATE = {QuestionAnswering.task: QuestionAnswering, TextClassification.task: TextClassification}


def task_template_from_dict(task_template_dict: dict) -> Optional[TaskTemplate]:
Review comment (Member): Is this function supposed to be used by the user? If so it should have a docstring and docs.

Review comment (Member): No, this is not supposed to be used by the user (at least for now). Still, no harm in having a docstring, so I've added:

"""Create one of the supported task templates in :obj:`datasets.tasks` from a dictionary."""

"""Create one of the supported task templates in :py:mod:`datasets.tasks` from a dictionary."""
task_name = task_template_dict.get("task")
if task_name is None:
return None
template = NAME2TEMPLATE.get(task_name)
if template is None:
return None
return template.from_dict(task_template_dict)
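Although this helper is not meant to be called by users (see the thread above), a small round-trip sketch shows what it does with a dict of the kind stored in dataset_infos.json; the column names and labels here are hypothetical:

from datasets.tasks import task_template_from_dict

template_dict = {
    "task": "text-classification",
    "text_column": "tweet",
    "label_column": "label",
    "labels": ["neg", "pos"],
}
template = task_template_from_dict(template_dict)
print(type(template).__name__)  # TextClassification
print(template.column_mapping)  # {'tweet': 'text', 'label': 'labels'}

# Unknown or missing task names simply yield None instead of raising
print(task_template_from_dict({"task": "summarization"}))  # None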
26 changes: 26 additions & 0 deletions src/datasets/tasks/base.py
@@ -0,0 +1,26 @@
import abc
from dataclasses import dataclass
from typing import ClassVar, Dict

from ..features import Features


@dataclass(frozen=True)
class TaskTemplate(abc.ABC):
task: ClassVar[str]
input_schema: ClassVar[Features]
label_schema: ClassVar[Features]

@property
def features(self) -> Features:
return Features(**self.input_schema, **self.label_schema)

@property
@abc.abstractmethod
def column_mapping(self) -> Dict[str, str]:
return NotImplemented

@classmethod
@abc.abstractmethod
def from_dict(cls, template_dict: dict) -> "TaskTemplate":
return NotImplemented
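As a reading aid, here is a minimal sketch of what a new template could look like on top of this base class; the Summarization task name, columns, and schema are invented for the example and are not added by this PR:

from dataclasses import dataclass
from typing import ClassVar, Dict

from datasets import Features, Value
from datasets.tasks import TaskTemplate


@dataclass(frozen=True)
class Summarization(TaskTemplate):
    # Hypothetical task, shown only to illustrate the base class contract
    task: ClassVar[str] = "summarization"
    input_schema: ClassVar[Features] = Features({"text": Value("string")})
    label_schema: ClassVar[Features] = Features({"summary": Value("string")})
    text_column: str = "text"
    summary_column: str = "summary"

    @property
    def column_mapping(self) -> Dict[str, str]:
        # Map the dataset's own column names onto the standardized ones
        return {self.text_column: "text", self.summary_column: "summary"}

    @classmethod
    def from_dict(cls, template_dict: dict) -> "Summarization":
        return cls(
            text_column=template_dict["text_column"],
            summary_column=template_dict["summary_column"],
        )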
41 changes: 41 additions & 0 deletions src/datasets/tasks/question_answering.py
@@ -0,0 +1,41 @@
from dataclasses import dataclass
from typing import Dict

from ..features import Features, Sequence, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class QuestionAnswering(TaskTemplate):
Review comment (Member): QuestionAnsweringExtractive?

Review comment (Member): Thanks for the suggestion! I'll keep this unchanged for now since it will soon be overhauled by #2371.

task = "question-answering"
input_schema = Features({"question": Value("string"), "context": Value("string")})
Review comment (Member): Maybe you want to check with a few QA datasets that this schema makes sense. Typically NaturalQuestions and TriviaQA can be good second datasets to compare to, to be sure of the generality of the schema. A good recent list of QA datasets to compare schemas among is, for instance, in the UnitedQA paper: https://arxiv.org/abs/2101.00178

Review comment (Member): Thanks for the tip! Added to #2371.

label_schema = Features(
{
"answers": Sequence(
{
"text": Value("string"),
"answer_start": Value("int32"),
}
)
}
)
question_column: str = "question"
context_column: str = "context"
answers_column: str = "answers"

def __post_init__(self):
object.__setattr__(self, "question_column", self.question_column)
object.__setattr__(self, "context_column", self.context_column)
object.__setattr__(self, "answers_column", self.answers_column)

@property
def column_mapping(self) -> Dict[str, str]:
return {self.question_column: "question", self.context_column: "context", self.answers_column: "answers"}

@classmethod
def from_dict(cls, template_dict: dict) -> "QuestionAnswering":
return cls(
question_column=template_dict["question_column"],
context_column=template_dict["context_column"],
answers_column=template_dict["answers_column"],
)
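A quick sketch of how the column renaming falls out of this template; the source column names are made up:

from datasets.tasks import QuestionAnswering

template = QuestionAnswering(
    question_column="query",
    context_column="passage",
    answers_column="answer_info",
)
print(template.column_mapping)
# {'query': 'question', 'passage': 'context', 'answer_info': 'answers'}
print(template.features)
# string-valued "question"/"context" plus the nested "answers" Sequence defined above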
42 changes: 42 additions & 0 deletions src/datasets/tasks/text_classification.py
@@ -0,0 +1,42 @@
from dataclasses import dataclass
from typing import Dict, List

from ..features import ClassLabel, Features, Value
from .base import TaskTemplate


@dataclass(frozen=True)
class TextClassification(TaskTemplate):
task = "text-classification"
input_schema = Features({"text": Value("string")})
# TODO(lewtun): Since we update this in __post_init__ do we need to set a default? We'll need it for __init__ so
# investigate if there's a more elegant approach.
label_schema = Features({"labels": ClassLabel})
Review comment (Contributor, author): Nit: since we update this in __post_init__, do we need to declare a default?

Review comment (Member): Good catch! I'll fix it :)

Review comment (Member): Ah, actually we do need to declare a default, because label_schema is a required argument for __init__. I've reverted the change and added a TODO so we can think about a more elegant approach as we iterate on the other tasks.

labels: List[str]
text_column: str = "text"
label_column: str = "labels"

def __post_init__(self):
assert sorted(set(self.labels)) == sorted(self.labels), "Labels must be unique"
# Cast labels to tuple to allow hashing
object.__setattr__(self, "labels", tuple(sorted(self.labels)))
object.__setattr__(self, "text_column", self.text_column)
object.__setattr__(self, "label_column", self.label_column)
self.label_schema["labels"] = ClassLabel(names=self.labels)
object.__setattr__(self, "label2id", {label: idx for idx, label in enumerate(self.labels)})
object.__setattr__(self, "id2label", {idx: label for label, idx in self.label2id.items()})

@property
def column_mapping(self) -> Dict[str, str]:
return {
self.text_column: "text",
self.label_column: "labels",
}

@classmethod
def from_dict(cls, template_dict: dict) -> "TextClassification":
return cls(
text_column=template_dict["text_column"],
label_column=template_dict["label_column"],
labels=template_dict["labels"],
)
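And the analogous sketch for the text-classification template, with invented labels and column names, showing the label2id/id2label mappings derived in __post_init__:

from datasets.tasks import TextClassification

template = TextClassification(
    labels=["negative", "positive"],
    text_column="review",
    label_column="rating",
)
print(template.column_mapping)  # {'review': 'text', 'rating': 'labels'}
print(template.label2id)        # {'negative': 0, 'positive': 1}
print(template.id2label)        # {0: 'negative', 1: 'positive'}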