
Commit

datasets support
ju-resplande committed Dec 10, 2021
1 parent 8c5e785 commit 0c48982
Showing 1 changed file with 68 additions and 102 deletions.
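
For context, after this change the loader is consumed through the standard Hugging Face datasets API. A minimal usage sketch, assuming the script is published under the hub id "ju-resplande/plue" (an assumption; a local path to plue/plue.py also works) and that config names follow the lowercase GLUE convention used in the diff below:

from datasets import load_dataset

# Load one PLUE configuration; "mrpc" is one of the GLUE-style configs defined in this script.
mrpc = load_dataset("ju-resplande/plue", "mrpc")
# Each example carries the text features plus "label" and "idx", as built in _generate_examples.
print(mrpc["train"][0])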
plue.py → plue/plue.py (170 changes: 68 additions & 102 deletions)
@@ -35,7 +35,7 @@
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/jubs12/PLUE}},
howpublished = {\\url{https://github.com/jubs12/PLUE}},
commit = {CURRENT_COMMIT}
}
@@ -52,18 +52,13 @@
the GLUE benchmark and Scitail using OPUS-MT model and Google Cloud Translation.
"""

_MRPC_DEV_IDS = "https://dl.fbaipublicfiles.com/glue/data/mrpc_dev_ids.tsv"
_MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
_MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
MNLI_URL = "https://github.com/jubs12/PLUE/releases/download/v1.0.0/MNLI.zip"

_MNLI_BASE_KWARGS = dict(
text_features={
"premise": "sentence1",
"hypothesis": "sentence2",
},
text_features={"premise": "sentence1", "hypothesis": "sentence2",},
label_classes=["entailment", "neutral", "contradiction"],
label_column="gold_label",
data_dir="datasets/MNLI",
data_dir="PLUE-1.0.0/datasets/MNLI",
citation=textwrap.dedent(
"""\
@InProceedings{N18-1101,
@@ -101,7 +96,6 @@ def __init__(
self,
text_features,
label_column,
data_url,
data_dir,
citation,
url,
@@ -128,7 +122,9 @@ def __init__(
of the label and processing it to the form required by the label feature
**kwargs: keyword arguments forwarded to super.
"""
super(PlueConfig, self).__init__(version=datasets.Version("1.0.0", ""), **kwargs)
super(PlueConfig, self).__init__(
version=datasets.Version("1.0.0", ""), **kwargs
)
self.text_features = text_features
self.label_column = label_column
self.label_classes = label_classes
@@ -155,7 +151,7 @@ class Plue(datasets.GeneratorBasedBuilder):
text_features={"sentence": "sentence"},
label_classes=["unacceptable", "acceptable"],
label_column="is_acceptable",
data_dir="datasets/CoLA",
data_dir="PLUE-1.0.0/datasets/CoLA",
citation=textwrap.dedent(
"""\
@article{warstadt2018neural,
@@ -179,7 +175,7 @@ class Plue(datasets.GeneratorBasedBuilder):
text_features={"sentence": "sentence"},
label_classes=["negative", "positive"],
label_column="label",
data_dir="datasets/SST-2",
data_dir="PLUE-1.0.0/datasets/SST-2",
citation=textwrap.dedent(
"""\
@inproceedings{socher2013recursive,
@@ -203,7 +199,7 @@ class Plue(datasets.GeneratorBasedBuilder):
text_features={"sentence1": "", "sentence2": ""},
label_classes=["not_equivalent", "equivalent"],
label_column="Quality",
data_dir="datasets/MRPC",
data_dir="PLUE-1.0.0/datasets/MRPC",
citation=textwrap.dedent(
"""\
@inproceedings{dolan2005automatically,
@@ -223,13 +219,10 @@ class Plue(datasets.GeneratorBasedBuilder):
community question-answering website Quora. The task is to determine whether a
pair of questions are semantically equivalent."""
),
text_features={
"question1": "question1",
"question2": "question2",
},
text_features={"question1": "question1", "question2": "question2",},
label_classes=["not_duplicate", "duplicate"],
label_column="is_duplicate",
data_dir="datasets/QQP",
data_dir="PLUE-1.0.0/datasets/QQP_v2",
citation=textwrap.dedent(
"""\
@online{WinNT,
@@ -251,12 +244,9 @@ class Plue(datasets.GeneratorBasedBuilder):
language inference data. Each pair is human-annotated with a similarity score
from 1 to 5."""
),
text_features={
"sentence1": "sentence1",
"sentence2": "sentence2",
},
text_features={"sentence1": "sentence1", "sentence2": "sentence2",},
label_column="score",
data_dir="datasets/STS-B",
data_dir="PLUE-1.0.0/datasets/STS-B",
citation=textwrap.dedent(
"""\
@article{cer2017semeval,
@@ -316,13 +306,10 @@ class Plue(datasets.GeneratorBasedBuilder):
the model select the exact answer, but also removes the simplifying assumptions that the answer
is always present in the input and that lexical overlap is a reliable cue."""
), # pylint: disable=line-too-long
text_features={
"question": "question",
"sentence": "sentence",
},
text_features={"question": "question", "sentence": "sentence",},
label_classes=["entailment", "not_entailment"],
label_column="label",
data_dir="datasets/QNLI",
data_dir="PLUE-1.0.0/datasets/QNLI",
citation=textwrap.dedent(
"""\
@article{rajpurkar2016squad,
@@ -344,13 +331,10 @@ class Plue(datasets.GeneratorBasedBuilder):
constructed based on news and Wikipedia text. We convert all datasets to a two-class split, where
for three-class datasets we collapse neutral and contradiction into not entailment, for consistency."""
), # pylint: disable=line-too-long
text_features={
"sentence1": "sentence1",
"sentence2": "sentence2",
},
text_features={"sentence1": "sentence1", "sentence2": "sentence2",},
label_classes=["entailment", "not_entailment"],
label_column="label",
data_dir="datasets/RTE",
data_dir="PLUE-1.0.0/datasets/RTE",
citation=textwrap.dedent(
"""\
@inproceedings{dagan2005pascal,
@@ -408,13 +392,10 @@ class Plue(datasets.GeneratorBasedBuilder):
between a model's score on this task and its score on the unconverted original task. We
call converted dataset WNLI (Winograd NLI)."""
),
text_features={
"sentence1": "sentence1",
"sentence2": "sentence2",
},
text_features={"sentence1": "sentence1", "sentence2": "sentence2",},
label_classes=["not_entailment", "entailment"],
label_column="label",
data_dir="datasets/WNLI",
data_dir="PLUE-1.0.0/datasets/WNLI",
citation=textwrap.dedent(
"""\
@inproceedings{levesque2012winograd,
@@ -437,13 +418,10 @@ class Plue(datasets.GeneratorBasedBuilder):
the SciTail dataset. The dataset contains 27,026 examples with 10,101 examples with entails label and 16,925 examples
with neutral label"""
),
text_features={
"premise": "premise",
"hypothesis": "hypothesis",
},
text_features={"premise": "premise", "hypothesis": "hypothesis",},
label_classes=["entails", "neutral"],
label_column="label",
data_dir="datasets/SciTail/tsv_format/",
data_dir="PLUE-1.0.0/datasets/SciTail",
citation=""""\
inproceedings{scitail,
Author = {Tushar Khot and Ashish Sabharwal and Peter Clark},
@@ -457,9 +435,14 @@ class Plue(datasets.GeneratorBasedBuilder):
]

def _info(self):
features = {text_feature: datasets.Value("string") for text_feature in self.config.text_features.keys()}
features = {
text_feature: datasets.Value("string")
for text_feature in self.config.text_features.keys()
}
if self.config.label_classes:
features["label"] = datasets.features.ClassLabel(names=self.config.label_classes)
features["label"] = datasets.features.ClassLabel(
names=self.config.label_classes
)
else:
features["label"] = datasets.Value("float32")
features["idx"] = datasets.Value("int32")
@@ -471,34 +454,30 @@ def _info(self):
)

def _split_generators(self, dl_manager):
if self.config.name == "mrpc":
data_dir = None
mrpc_files = dl_manager.download(
{
"dev_ids": _MRPC_DEV_IDS,
"train": _MRPC_TRAIN,
"test": _MRPC_TEST,
}
)
else:
dl_dir = dl_manager.download_and_extract(self.config.data_url)
data_dir = os.path.join(dl_dir, self.config.data_dir)
mrpc_files = None
data_url = MNLI_URL if self.config.name == "mnli" else self.config.data_url
dl_dir = dl_manager.download_and_extract(data_url)
data_dir = os.path.join(dl_dir, self.config.data_dir)

train_split = datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"data_file": os.path.join(data_dir or "", "train.tsv"),
"split": "train",
"mrpc_files": mrpc_files,
},
)
if self.config.name == "mnli":
return [
train_split,
_mnli_split_generator("validation_matched", data_dir, "dev", matched=True),
_mnli_split_generator("validation_mismatched", data_dir, "dev", matched=False),
_mnli_split_generator(
"validation_matched", data_dir, "dev", matched=True
),
_mnli_split_generator(
"validation_mismatched", data_dir, "dev", matched=False
),
_mnli_split_generator("test_matched", data_dir, "test", matched=True),
_mnli_split_generator("test_mismatched", data_dir, "test", matched=False),
_mnli_split_generator(
"test_mismatched", data_dir, "test", matched=False
),
]
elif self.config.name == "mnli_matched":
return [
@@ -518,23 +497,23 @@ def _split_generators(self, dl_manager):
gen_kwargs={
"data_file": os.path.join(data_dir or "", "dev.tsv"),
"split": "dev",
"mrpc_files": mrpc_files,
},
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"data_file": os.path.join(data_dir or "", "test.tsv"),
"split": "test",
"mrpc_files": mrpc_files,
},
),
]

def _generate_examples(self, data_file, split, mrpc_files=None):
def _generate_examples(self, data_file, split):
if self.config.name == "mrpc":
# We have to prepare the MRPC dataset from the original sources ourselves.
examples = self._generate_example_mrpc_files(mrpc_files=mrpc_files, split=split)
examples = self._generate_example_mrpc_files(
data_file=data_file, split=split
)
for example in examples:
yield example["idx"], example
else:
@@ -557,7 +536,10 @@ def _generate_examples(self, data_file, split, mrpc_files=None):
"is_acceptable": row[1],
}

example = {feat: row[col] for feat, col in self.config.text_features.items()}
example = {
feat: row[col]
for feat, col in self.config.text_features.items()
}
example["idx"] = n

if self.config.label_column in row:
@@ -577,46 +559,30 @@ def _generate_examples(self, data_file, split, mrpc_files=None):
else:
yield example["idx"], example

def _generate_example_mrpc_files(self, mrpc_files, split):
if split == "test":
with open(mrpc_files["test"], encoding="utf8") as f:
# The first 3 bytes are the utf-8 BOM \xef\xbb\xbf, which messes with
# the Quality key.
f.seek(3)
reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
for n, row in enumerate(reader):
yield {
"sentence1": row["#1 String"],
"sentence2": row["#2 String"],
"label": int(row["Quality"]),
"idx": n,
}
else:
with open(mrpc_files["dev_ids"], encoding="utf8") as f:
reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
dev_ids = [[row[0], row[1]] for row in reader]
with open(mrpc_files["train"], encoding="utf8") as f:
# The first 3 bytes are the utf-8 BOM \xef\xbb\xbf, which messes with
# the Quality key.
f.seek(3)
reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
for n, row in enumerate(reader):
is_row_in_dev = [row["#1 ID"], row["#2 ID"]] in dev_ids
if is_row_in_dev == (split == "dev"):
yield {
"sentence1": row["#1 String"],
"sentence2": row["#2 String"],
"label": int(row["Quality"]),
"idx": n,
}
def _generate_example_mrpc_files(self, data_file, split):
print(data_file)

with open(data_file, encoding="utf8") as f:
reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
for idx, row in enumerate(reader):
label = row["Quality"] if split != "test" else -1

yield {
"sentence1": row["#1 String"],
"sentence2": row["#2 String"],
"label": int(label),
"idx": idx,
}


def _mnli_split_generator(name, data_dir, split, matched):
return datasets.SplitGenerator(
name=name,
gen_kwargs={
"data_file": os.path.join(data_dir, "%s_%s.tsv" % (split, "matched" if matched else "mismatched")),
"data_file": os.path.join(
data_dir, "%s_%s.tsv" % (split, "matched" if matched else "mismatched")
),
"split": split,
"mrpc_files": None,
},
)
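
For reference, the helper above only maps a split name and a matched flag onto one of the four MNLI TSV files shipped in the release archive. A minimal standalone sketch of that mapping (data_dir is a placeholder; the real value comes from the config's data_dir, e.g. PLUE-1.0.0/datasets/MNLI):

import os

def mnli_file(data_dir: str, split: str, matched: bool) -> str:
    # Mirrors _mnli_split_generator's naming scheme: dev_matched.tsv, dev_mismatched.tsv,
    # test_matched.tsv and test_mismatched.tsv.
    suffix = "matched" if matched else "mismatched"
    return os.path.join(data_dir, "%s_%s.tsv" % (split, suffix))

# mnli_file("PLUE-1.0.0/datasets/MNLI", "dev", matched=False)
# -> 'PLUE-1.0.0/datasets/MNLI/dev_mismatched.tsv'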
