Update Seq2Seq QA example script to use SQuAD metric. #14335

Merged
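This PR moves the Seq2Seq question-answering example from ROUGE to the SQuAD metric and adds a dedicated QuestionAnsweringSeq2SeqTrainer. For reference, the `squad` / `squad_v2` metrics from `datasets` expect predictions and references as lists of dicts keyed by example id; the snippet below is a minimal sketch of that format (not part of the diff, with made-up ids and values), matching the shape the new `post_processing_function` produces.

import datasets

# Minimal sketch of the input format load_metric("squad") expects.
metric = datasets.load_metric("squad")

predictions = [{"id": "example-0", "prediction_text": "Denver Broncos"}]
references = [
    {"id": "example-0", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}
]

# Expected to print {'exact_match': 100.0, 'f1': 100.0} for this toy pair.
print(metric.compute(predictions=predictions, references=references))

For `squad_v2`, each prediction additionally carries a `no_answer_probability` field, which is why the diff below formats predictions differently when `version_2_with_negative` is set.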
120 changes: 80 additions & 40 deletions examples/pytorch/question-answering/run_seq2seq_qa.py
@@ -25,22 +25,20 @@
from typing import List, Optional, Tuple

import datasets
import nltk
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

@@ -411,7 +409,7 @@ def main():
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

def preprocess_sqaud_batch(
def preprocess_squad_batch(
examples,
question_column: str,
context_column: str,
@@ -422,14 +420,14 @@ def preprocess_sqaud_batch(
answers = examples[answer_column]

def generate_input(_question, _context):
return " ".join(["question:", _question, "context:", _context])
return " ".join(["question:", _question.lstrip(), "context:", _context.lstrip()])

inputs = [generate_input(question, context) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets

def preprocess_function(examples):
inputs, targets = preprocess_sqaud_batch(examples, question_column, context_column, answer_column)
inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)

model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
@@ -446,6 +444,45 @@ def preprocess_function(examples):
model_inputs["labels"] = labels["input_ids"]
return model_inputs

# Validation preprocessing
def preprocess_validation_function(examples):
inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)

model_inputs = tokenizer(
inputs,
max_length=max_seq_length,
padding=padding,
truncation=True,
return_overflowing_tokens=True,
return_offsets_mapping=True,
)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)

# Since one example might give us several features if it has a long context, we need a map from a feature to
# its corresponding example. This key gives us just that.
sample_mapping = model_inputs.pop("overflow_to_sample_mapping")

# For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
# corresponding example_id and we will store the offset mappings.
model_inputs["example_id"] = []

for i in range(len(model_inputs["input_ids"])):
# One example can give several spans; this is the index of the example containing this span of text.
sample_index = sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])

# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]

model_inputs["labels"] = labels["input_ids"]
return model_inputs

if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
@@ -477,7 +514,7 @@ def preprocess_function(examples):
# Validation Feature Creation
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_examples.map(
preprocess_function,
preprocess_validation_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
@@ -498,7 +535,7 @@ def preprocess_function(examples):
# Predict Feature Creation
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
predict_dataset = predict_examples.map(
preprocess_function,
preprocess_validation_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
@@ -518,50 +555,53 @@ def preprocess_function(examples):
pad_to_multiple_of=8 if training_args.fp16 else None,
)

# Post-processing:
def postprocess_text(preds, labels):
preds = [" ".join(pred) for pred in preds]
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")

# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)

return preds, labels

metric = load_metric("rouge")

def compute_metrics(eval_preds: EvalPrediction):
preds, labels = eval_preds
# Post-processing:
def post_processing_function(
examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
):
# Decode the predicted tokens.
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = [tokenizer.batch_decode(pred, skip_special_tokens=True) for pred in preds]
if data_args.ignore_pad_token_for_loss:
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

# Some simple post-processing
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

# Build a map from each example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
predictions = {}
# Let's loop over all the examples!
for example_index, example in enumerate(examples):
# This is the index of the feature associated with the current example.
feature_index = feature_per_example[example_index]
predictions[example["id"]] = decoded_preds[feature_index]

# Format the result to the format the metric expects.
if data_args.version_2_with_negative:
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]

references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)

# Initialize our Trainer
trainer = Seq2SeqTrainer(
trainer = QuestionAnsweringSeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
eval_examples=eval_examples if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
post_process_function=post_processing_function,
)

# Training
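Side note on the post-processing added above: `return_overflowing_tokens=True` can split a long context into several features, each tagged with the `example_id` of its originating example, and the dict comprehension in `post_processing_function` keeps one decoded prediction per example (the last feature wins). A standalone sketch of that mapping, with made-up ids and predictions:

# Standalone sketch of the example/feature mapping in post_processing_function.
# Ids and decoded predictions are made up for illustration.
examples = [{"id": "ex-0"}, {"id": "ex-1"}]
features = [
    {"example_id": "ex-0"},
    {"example_id": "ex-1"},  # "ex-1" overflowed into two features
    {"example_id": "ex-1"},
]
decoded_preds = ["answer A", "answer B (span 1)", "answer B (span 2)"]

example_id_to_index = {ex["id"]: i for i, ex in enumerate(examples)}
# Later features overwrite earlier ones, so each example keeps its last feature index.
feature_per_example = {example_id_to_index[f["example_id"]]: i for i, f in enumerate(features)}

predictions = {ex["id"]: decoded_preds[feature_per_example[i]] for i, ex in enumerate(examples)}
print(predictions)  # {'ex-0': 'answer A', 'ex-1': 'answer B (span 2)'}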
120 changes: 120 additions & 0 deletions examples/pytorch/question-answering/trainer_seq2seq_qa.py
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A subclass of `Trainer` specific to Question-Answering tasks
"""
from typing import Dict, List, Optional

from torch.utils.data import Dataset

from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput


if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
super().__init__(*args, **kwargs)
self.eval_examples = eval_examples
self.post_process_function = post_process_function

# def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
eval_examples=None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
max_length: Optional[int] = None,
num_beams: Optional[int] = None,
) -> Dict[str, float]:
self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams

eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_examples = self.eval_examples if eval_examples is None else eval_examples

# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
output = eval_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
)
finally:
self.compute_metrics = compute_metrics

if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
metrics = self.compute_metrics(eval_preds)

# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

self.log(metrics)
else:
metrics = {}

if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())

self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
return metrics

def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
predict_dataloader = self.get_test_dataloader(predict_dataset)

# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
output = eval_loop(
predict_dataloader,
description="Prediction",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
)
finally:
self.compute_metrics = compute_metrics

if self.post_process_function is None or self.compute_metrics is None:
return output

predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
metrics = self.compute_metrics(predictions)

# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
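The prefixing step in `evaluate()` and `predict()` above simply renames metric keys in place so they surface as `eval_f1`, `eval_exact`, and so on. A minimal standalone illustration of the same idiom, with made-up values:

# Standalone illustration of the metric-key prefixing used in evaluate()/predict().
metric_key_prefix = "eval"
metrics = {"exact": 72.5, "f1": 81.3}  # made-up values

for key in list(metrics.keys()):
    if not key.startswith(f"{metric_key_prefix}_"):
        metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

print(metrics)  # {'eval_exact': 72.5, 'eval_f1': 81.3}

This is what the updated test below relies on when it asserts on result["eval_f1"] and result["eval_exact"].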
6 changes: 2 additions & 4 deletions examples/pytorch/test_examples.py
@@ -274,10 +274,8 @@ def test_run_squad_seq2seq(self):
with patch.object(sys, "argv", testargs):
run_squad_seq2seq.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_rouge1"], 10)
self.assertGreaterEqual(result["eval_rouge2"], 10)
self.assertGreaterEqual(result["eval_rougeL"], 10)
self.assertGreaterEqual(result["eval_rougeLsum"], 10)
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)

def test_run_swag(self):
stream_handler = logging.StreamHandler(sys.stdout)