handle long documents in squad qa datasource and models (facebookresearch#975)

Summary:
Pull Request resolved: facebookresearch#975

BERT encoder has a maximum sequence length, which means long paragraphs can get cut off when training or evaluating QA models. In the original Squad dataset it is very rare for a paragraph and answer span to get truncated; however, truncation can be more common in other datasets.

This diff implements chunking for long paragraphs such that:
- each chunk has a fixed size in terms of character length
- a minimum overlap can be specified as a fraction of the chunk size (see the sketch after this list)
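
For intuition, here is a minimal sketch of the piece-boundary arithmetic; the helper name chunk_bounds and the example numbers are illustrative, not part of the diff:

import math

def chunk_bounds(doc_len, max_len, min_overlap_frac):
    # Mirrors _split_document below: fixed-size pieces whose overlap is at
    # least floor(max_len * min_overlap_frac) characters.
    min_overlap = math.floor(max_len * min_overlap_frac)
    n_pieces = 1 + math.ceil(max(0, doc_len - max_len) / (max_len - min_overlap))
    overlap = (
        math.floor((n_pieces * max_len - doc_len) / (n_pieces - 1))
        if n_pieces > 1
        else 0
    )
    return [
        (n * (max_len - overlap), n * (max_len - overlap) + max_len)
        for n in range(n_pieces)
    ]

# A 3500-character paragraph with max_len=2000 and min_overlap_frac=0.1
# splits into two pieces that share 500 characters:
print(chunk_bounds(3500, 2000, 0.1))  # [(0, 2000), (1500, 3500)]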

In the metric reporter, we aggregate the pieces by sample id and return the highest-scoring span.
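
A minimal sketch of that aggregation, assuming a simplified (sample_id, answer_span, span_score) row shape; the real reporter zips many parallel lists, as in the diff below:

from collections import defaultdict

def best_piece_per_sample(rows):
    # rows: (sample_id, answer_span, span_score) triples, where span_score is
    # the sum of the start- and end-position scores.
    by_id = defaultdict(list)
    for row in rows:
        by_id[row[0]].append(row)
    # Keep, for each sample id, the piece with the highest span score.
    return [max(group, key=lambda row: row[2]) for group in by_id.values()]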

To share the chunking logic between SquadDataSource and SquadTSVDataSource, I merged them into one class, with the added benefit that we can now train on a .json file and test on a .tsv file, and vice versa.

Differential Revision: D17350243

fbshipit-source-id: fb7cfd2b40972168e24dba597bd8d0a812472f3f
borguz authored and facebook-github-bot committed Sep 20, 2019
1 parent 7c69b97 commit e375d1b
Showing 3 changed files with 206 additions and 59 deletions.
173 changes: 129 additions & 44 deletions pytext/data/sources/squad.py
@@ -2,36 +2,133 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import json
import math
from typing import List, Optional

from pytext.data.sources.data_source import DataSource, generator_property
from pytext.data.sources.tsv import TSVDataSource


def unflatten(fname, ignore_impossible):
def _shift_answers(orig_starts, piece_start, piece_end):
# Re-align answer index for each piece when we split a long document.
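    # Starts that fall outside [piece_start, piece_end) are dropped for this piece.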
answer_starts = []
has_answer = False
for start in orig_starts:
if start >= piece_start and start < piece_end:
answer_starts.append(start - piece_start)
has_answer = True
return answer_starts, has_answer


def _split_document(
id,
doc,
question,
answers,
answer_starts,
has_answer,
ignore_impossible,
max_character_length,
min_overlap,
):
pieces = []
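    # min_overlap arrives as a fraction of the chunk size; convert it to characters.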
min_overlap = math.floor(max_character_length * min_overlap)
if has_answer or not ignore_impossible:
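        # Each piece beyond the first covers at most
        # (max_character_length - min_overlap) new characters.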
n_pieces = 1 + math.ceil(
max(0, len(doc) - max_character_length)
/ (max_character_length - min_overlap)
)
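        # Spread the slack evenly so the pieces cover the whole document;
        # the realized overlap is at least the requested minimum.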
overlap = (
math.floor((n_pieces * max_character_length - len(doc)) / (n_pieces - 1))
if n_pieces > 1
else 0
)
for n in range(n_pieces):
start, end = (
n * (max_character_length - overlap),
(n + 1) * (max_character_length - overlap) + overlap,
)
        piece_answer_starts, piece_has_answer = _shift_answers(answer_starts, start, end)
pieces.append(
{
"id": id,
"doc": doc[start:end],
"question": question,
"answers": answers,
"answer_starts": answer_starts,
"has_answer": str(has_answer and piece_has_answer),
}
)
return pieces


def process_squad_json(fname, ignore_impossible, max_character_length, min_overlap):
if not fname:
return
with open(fname) as file:
dump = json.load(file)

id = 0
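    # All pieces split from the same question will share this sample id.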
for article in dump["data"]:
for paragraph in article["paragraphs"]:
doc = paragraph["context"]
for question in paragraph["qas"]:
has_answer = not question.get("is_impossible", False)
if has_answer or not ignore_impossible:
answers = (
question["answers"]
if has_answer
else question["plausible_answers"]
)
yield {
"doc": doc,
"question": question["question"],
"answers": [answer["text"] for answer in answers],
"answer_starts": [int(ans["answer_start"]) for ans in answers],
"has_answer": str(has_answer),
}
answers = (
question["answers"] if has_answer else question["plausible_answers"]
)
question = question["question"]
answer_texts = [answer["text"] for answer in answers]
answer_starts = [int(answer["answer_start"]) for answer in answers]
for piece_dict in _split_document(
id,
doc,
question,
answer_texts,
answer_starts,
has_answer,
ignore_impossible,
max_character_length,
min_overlap,
):
yield piece_dict
id += 1


def process_squad_tsv(fname, ignore_impossible, max_character_length, min_overlap):
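    # Expected TSV columns: doc, question, answers (JSON list),
    # answer_starts (JSON list), has_answer ("True"/"False").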
if not fname:
print(f"Empty file name!")
return

with open(fname) as file:
for id, line in enumerate(file):
doc, question, answers, answer_starts, has_answer = line.rstrip().split(
"\t"
)
answers = json.loads(answers)
answer_starts = json.loads(answer_starts)
for piece_dict in _split_document(
id,
doc,
question,
answers,
answer_starts,
has_answer == "True",
ignore_impossible,
max_character_length,
min_overlap,
):
yield piece_dict


def process_squad(fname, ignore_impossible, max_character_length, min_overlap=0.1):
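    # Dispatch on file extension: ".json" goes through the SQuAD JSON parser;
    # anything else is read as TSV.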
if fname.split(".")[-1] == "json":
return process_squad_json(
fname, ignore_impossible, max_character_length, min_overlap
)
else:
return process_squad_tsv(
fname, ignore_impossible, max_character_length, min_overlap
)


class SquadDataSource(DataSource):
@@ -45,6 +142,8 @@ class Config(DataSource.Config):
test_filename: Optional[str] = "dev-v2.0.json"
eval_filename: Optional[str] = "dev-v2.0.json"
ignore_impossible: bool = True
max_character_length: int = 2000
min_overlap: float = 0.1 # Expressed as a fraction of the max_character_length.

@classmethod
def from_config(cls, config: Config, schema=None):
@@ -53,6 +152,8 @@ def from_config(cls, config: Config, schema=None):
config.test_filename,
config.eval_filename,
config.ignore_impossible,
config.max_character_length,
config.min_overlap,
)

def __init__(
@@ -61,8 +162,11 @@ def __init__(
test_filename=None,
eval_filename=None,
ignore_impossible=Config.ignore_impossible,
max_character_length=Config.max_character_length,
min_overlap=Config.min_overlap,
):
schema = {
"id": int,
"doc": str,
"question": str,
"answers": List[str],
@@ -75,41 +179,22 @@ def __init__(
self.test_filename = test_filename
self.eval_filename = eval_filename
self.ignore_impossible = ignore_impossible
self.max_character_length = max_character_length
self.min_overlap = min_overlap

def process_file(self, fname):
return process_squad(
fname, self.ignore_impossible, self.max_character_length, self.min_overlap
)

@generator_property
def train(self):
return unflatten(self.train_filename, self.ignore_impossible)
return self.process_file(self.train_filename)

@generator_property
def test(self):
return unflatten(self.test_filename, self.ignore_impossible)
return self.process_file(self.test_filename)

@generator_property
def eval(self):
return unflatten(self.eval_filename, self.ignore_impossible)


class SquadTSVDataSource(TSVDataSource):
"""
Squad-like data passed in TSV format.
Will return tuples of (doc, question, answer, answer_start, has_answer)
"""

class Config(TSVDataSource.Config):
field_names: List[str] = [
"doc",
"question",
"answers",
"answer_starts",
"has_answer",
]

def __init__(self, **kwargs):
kwargs["schema"] = {
"doc": str,
"question": str,
"answers": List[str],
"answer_starts": List[int],
"has_answer": str,
}
super().__init__(**kwargs)
return self.process_file(self.eval_filename)
55 changes: 52 additions & 3 deletions pytext/metric_reporters/squad_metric_reporter.py
@@ -6,7 +6,8 @@
from collections import Counter
from typing import Dict, List

from pytext.common.constants import RawExampleFieldName, Stage
import numpy as np
from pytext.common.constants import Stage
from pytext.metric_reporters.channel import Channel, ConsoleChannel, FileChannel
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.metrics.squad_metrics import SquadMetrics
@@ -37,7 +38,7 @@ def gen_content(self, metrics, loss, preds, targets, scores, contexts, *args):
start_pos_scores, end_pos_scores, has_answer_scores = scores
for i in range(len(pred_answers)):
yield [
contexts[RawExampleFieldName.ROW_INDEX][i],
contexts[SquadMetricReporter.ROW_INDEX][i],
contexts[SquadMetricReporter.QUES_COLUMN][i],
contexts[SquadMetricReporter.DOC_COLUMN][i],
pred_answers[i],
@@ -58,7 +59,7 @@ class SquadMetricReporter(MetricReporter):
QUES_COLUMN = "question"
ANSWERS_COLUMN = "answers"
DOC_COLUMN = "doc"
ROW_INDEX = "row_index"
ROW_INDEX = "id"

class Config(MetricReporter.Config):
n_best_size: int = 5
@@ -177,6 +178,54 @@ def batch_context(self, raw_batch, batch):
return context

def calculate_metric(self):
all_rows = zip(
self.all_context[self.ROW_INDEX],
self.all_context[self.ANSWERS_COLUMN],
self.all_context[self.QUES_COLUMN],
self.all_context[self.DOC_COLUMN],
self.all_pred_answers,
self.all_start_pos_preds,
self.all_end_pos_preds,
self.all_has_answer_preds,
self.all_start_pos_targets,
self.all_end_pos_targets,
self.all_has_answer_targets,
self.all_start_pos_scores,
self.all_end_pos_scores,
self.all_has_answer_scores,
)

all_rows_dict = {}
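        # Group pieces by sample id: chunks split from the same document share an id.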
for row in all_rows:
try:
all_rows_dict[row[0]].append(row)
except KeyError:
all_rows_dict[row[0]] = [row]

all_rows = []
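        # For each sample, keep the piece with the highest span score;
        # row[11] and row[12] are the start- and end-position scores.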
for rows in all_rows_dict.values():
argmax = np.argmax([row[11] + row[12] for row in rows])
all_rows.append(rows[argmax])

        # Restore the original sample order; sorted() returns a new list.
        all_rows = sorted(all_rows, key=lambda x: int(x[0]))

(
self.all_context[self.ROW_INDEX],
self.all_context[self.ANSWERS_COLUMN],
self.all_context[self.QUES_COLUMN],
self.all_context[self.DOC_COLUMN],
self.all_pred_answers,
self.all_start_pos_preds,
self.all_end_pos_preds,
self.all_has_answer_preds,
self.all_start_pos_targets,
self.all_end_pos_targets,
self.all_has_answer_targets,
self.all_start_pos_scores,
self.all_end_pos_scores,
self.all_has_answer_scores,
) = zip(*all_rows)

exact_matches, count = self._compute_exact_matches(
self.all_pred_answers,
self.all_context[self.ANSWERS_COLUMN],
