Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Add support for learning from soft labels for Squad (MRC) models #1188

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 164 additions & 35 deletions pytext/data/sources/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pytext.data.sources.data_source import (
DataSource,
JSONString,
SafeFileWrapper,
generator_property,
)
Expand Down Expand Up @@ -125,11 +126,74 @@ def process_squad_tsv(
)

for id, row in enumerate(tsv):
doc, question, answers, answer_starts, has_answer = (
row[f] for f in field_names
)
parts = (row[f] for f in field_names)
doc, question, answers, answer_starts, has_answer = parts
answers = json.loads(answers)
answer_starts = json.loads(answer_starts)

for piece_dict in _split_document(
id,
doc,
question,
answers,
answer_starts,
has_answer == "True",
ignore_impossible,
max_character_length,
min_overlap,
):
yield piece_dict


def process_squad_tsv_for_kd(
fname, ignore_impossible, max_character_length, min_overlap, delimiter, quoted
):
if not fname:
print(f"Empty file name!")
return

field_names = [
"id1",
"doc",
"question",
"answers",
"answer_starts",
"has_answer",
"id2",
"start_logits",
"end_logits",
"has_answer_logits",
"pad_mask",
"segment_labels",
]
tsv_file = SafeFileWrapper(
get_absolute_path(fname), encoding="utf-8", errors="replace"
)
tsv = TSV(
tsv_file,
field_names=field_names,
delimiter=delimiter,
quoted=quoted,
drop_incomplete_rows=True,
)

for id, row in enumerate(tsv):
parts = (row[f] for f in field_names)
# All model output for KD are dumped using json serialization.
(
id1,
doc,
question,
answers,
answer_starts,
has_answer,
id2,
start_logits,
end_logits,
has_answer_logits,
pad_mask,
segment_labels,
) = (json.loads(s) for s in parts)
for piece_dict in _split_document(
id,
doc,
Expand All @@ -141,6 +205,15 @@ def process_squad_tsv(
max_character_length,
min_overlap,
):
piece_dict.update(
{
"start_logits": start_logits,
"end_logits": end_logits,
"has_answer_logits": has_answer_logits,
"pad_mask": pad_mask,
"segment_labels": segment_labels,
}
)
yield piece_dict


Expand All @@ -151,19 +224,34 @@ def process_squad(
min_overlap=0.1,
delimiter="\t",
quoted=False,
is_kd=False,
):
if fname.split(".")[-1] == "json":
return process_squad_json(
fname, ignore_impossible, max_character_length, min_overlap
fname=fname,
ignore_impossible=ignore_impossible,
max_character_length=max_character_length,
min_overlap=min_overlap,
)
else:
return process_squad_tsv(
fname,
ignore_impossible,
max_character_length,
min_overlap,
delimiter,
quoted,
return (
process_squad_tsv(
fname=fname,
ignore_impossible=ignore_impossible,
max_character_length=max_character_length,
min_overlap=min_overlap,
delimiter=delimiter,
quoted=quoted,
)
if not is_kd
else process_squad_tsv_for_kd(
fname=fname,
ignore_impossible=ignore_impossible,
max_character_length=max_character_length,
min_overlap=min_overlap,
delimiter=delimiter,
quoted=quoted,
)
)


Expand All @@ -173,6 +261,18 @@ class SquadDataSource(DataSource):
Will return tuples of (doc, question, answer, answer_start, has_answer)
"""

__EXPANSIBLE__ = True

DEFAULT_SCHEMA = {
"id": int,
"doc": str,
"question": str,
"answers": List[str],
"answer_starts": List[int],
"answer_ends": List[int],
"has_answer": str,
}

class Config(DataSource.Config):
train_filename: Optional[str] = "train-v2.0.json"
test_filename: Optional[str] = "dev-v2.0.json"
Expand All @@ -184,16 +284,16 @@ class Config(DataSource.Config):
quoted: bool = False

@classmethod
def from_config(cls, config: Config, schema=None):
def from_config(cls, config: Config, schema=DEFAULT_SCHEMA):
return cls(
config.train_filename,
config.test_filename,
config.eval_filename,
config.ignore_impossible,
config.max_character_length,
config.min_overlap,
config.delimiter,
config.quoted,
train_filename=config.train_filename,
test_filename=config.test_filename,
eval_filename=config.eval_filename,
ignore_impossible=config.ignore_impossible,
max_character_length=config.max_character_length,
min_overlap=config.min_overlap,
delimiter=config.delimiter,
quoted=config.quoted,
)

def __init__(
Expand All @@ -206,16 +306,8 @@ def __init__(
min_overlap=Config.min_overlap,
delimiter=Config.delimiter,
quoted=Config.quoted,
schema=DEFAULT_SCHEMA,
):
schema = {
"id": int,
"doc": str,
"question": str,
"answers": List[str],
"answer_starts": List[int],
"answer_ends": List[int],
"has_answer": str,
}
super().__init__(schema)
self.train_filename = train_filename
self.test_filename = test_filename
Expand All @@ -228,12 +320,12 @@ def __init__(

def process_file(self, fname):
return process_squad(
fname,
self.ignore_impossible,
self.max_character_length,
self.min_overlap,
self.delimiter,
self.quoted,
fname=fname,
ignore_impossible=self.ignore_impossible,
max_character_length=self.max_character_length,
min_overlap=self.min_overlap,
delimiter=self.delimiter,
quoted=self.quoted,
)

@generator_property
Expand All @@ -247,3 +339,40 @@ def test(self):
@generator_property
def eval(self):
return self.process_file(self.eval_filename)


class SquadDataSourceForKD(SquadDataSource):
    """Squad-like data source that additionally carries teacher soft labels.

    Each yielded row has the usual Squad fields plus the teacher model's
    per-token outputs:
    (doc, question, answer, answer_start, has_answer,
     start_logits, end_logits, has_answer_logits, pad_mask, segment_labels)
    """

    def __init__(self, **kwargs):
        # KD rows are JSON-serialized, so the plain-text columns use
        # JSONString, and the teacher outputs are numeric lists.
        kd_schema = {
            "id": int,
            "doc": JSONString,
            "question": JSONString,
            "answers": List[str],
            "answer_starts": List[int],
            "has_answer": JSONString,
            "start_logits": List[float],
            "end_logits": List[float],
            "has_answer_logits": List[float],
            "pad_mask": List[int],
            "segment_labels": List[int],
        }
        # Always install the KD schema, overriding any caller-supplied one.
        kwargs["schema"] = kd_schema
        super().__init__(**kwargs)

    def process_file(self, fname):
        """Read ``fname`` with is_kd=True so the logit columns are parsed."""
        options = dict(
            fname=fname,
            ignore_impossible=self.ignore_impossible,
            max_character_length=self.max_character_length,
            min_overlap=self.min_overlap,
            delimiter=self.delimiter,
            quoted=self.quoted,
            is_kd=True,
        )
        return process_squad(**options)
9 changes: 9 additions & 0 deletions pytext/data/sources/tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __init__(
self.delimiter = delimiter
self.quoted = quoted
self.drop_incomplete_rows = drop_incomplete_rows
self.total_rows_count = 0
self.incomplete_rows_count = 0
self._access_lock = threading.Lock()
csv.field_size_limit(sys.maxsize)

Expand All @@ -48,14 +50,21 @@ def __iter__(self):
)
if self.drop_incomplete_rows:
for row in reader:
self.total_rows_count += 1
if any(map(lambda v: v is None, row.values())): # drop!
self.incomplete_rows_count += 1
continue
yield row
else:
yield from reader
finally:
self._access_lock.release()

def __del__(self):
print("Destroying TSV object")
print(f"Total number of rows read: {self.total_rows_count}")
print(f"Total number of rows dropped: {self.incomplete_rows_count}")


class TSVDataSource(RootDataSource):
"""DataSource which loads data from TSV sources. Uses python's csv library."""
Expand Down
Loading