diff --git a/examples/pytorch/question-answering/run_seq2seq_qa.py b/examples/pytorch/question-answering/run_seq2seq_qa.py index e64ab4d355e208..19da4d30e99c25 100644 --- a/examples/pytorch/question-answering/run_seq2seq_qa.py +++ b/examples/pytorch/question-answering/run_seq2seq_qa.py @@ -25,22 +25,20 @@ from typing import List, Optional, Tuple import datasets -import nltk -import numpy as np from datasets import load_dataset, load_metric import transformers +from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, HfArgumentParser, - Seq2SeqTrainer, Seq2SeqTrainingArguments, set_seed, ) -from transformers.trainer_utils import EvalPrediction, get_last_checkpoint +from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint from transformers.utils import check_min_version from transformers.utils.versions import require_version @@ -411,7 +409,7 @@ def main(): ) max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - def preprocess_sqaud_batch( + def preprocess_squad_batch( examples, question_column: str, context_column: str, @@ -422,14 +420,14 @@ def preprocess_sqaud_batch( answers = examples[answer_column] def generate_input(_question, _context): - return " ".join(["question:", _question, "context:", _context]) + return " ".join(["question:", _question.lstrip(), "context:", _context.lstrip()]) inputs = [generate_input(question, context) for question, context in zip(questions, contexts)] targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] return inputs, targets def preprocess_function(examples): - inputs, targets = preprocess_sqaud_batch(examples, question_column, context_column, answer_column) + inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column) model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True) # Setup the tokenizer for targets @@ -446,6 +444,45 @@ def preprocess_function(examples): model_inputs["labels"] = labels["input_ids"] return model_inputs + # Validation preprocessing + def preprocess_validation_function(examples): + inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column) + + model_inputs = tokenizer( + inputs, + max_length=max_seq_length, + padding=padding, + truncation=True, + return_overflowing_tokens=True, + return_offsets_mapping=True, + ) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = model_inputs.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + if training_args.do_train: if "train" not in raw_datasets: raise ValueError("--do_train requires a train dataset") @@ -477,7 +514,7 @@ def preprocess_function(examples): # Validation Feature Creation with training_args.main_process_first(desc="validation dataset map pre-processing"): eval_dataset = eval_examples.map( - preprocess_function, + preprocess_validation_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, @@ -498,7 +535,7 @@ def preprocess_function(examples): # Predict Feature Creation with training_args.main_process_first(desc="prediction dataset map pre-processing"): predict_dataset = predict_examples.map( - preprocess_function, + preprocess_validation_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, @@ -518,50 +555,53 @@ def preprocess_function(examples): pad_to_multiple_of=8 if training_args.fp16 else None, ) - # Post-processing: - def postprocess_text(preds, labels): - preds = [" ".join(pred) for pred in preds] - preds = [pred.strip() for pred in preds] - labels = [label.strip() for label in labels] + metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") - # rougeLSum expects newline after each sentence - preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] - labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) - return preds, labels - - metric = load_metric("rouge") - - def compute_metrics(eval_preds: EvalPrediction): - preds, labels = eval_preds + # Post-processing: + def post_processing_function( + examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval" + ): + # Decode the predicted tokens. + preds = outputs.predictions if isinstance(preds, tuple): preds = preds[0] - decoded_preds = [tokenizer.batch_decode(pred, skip_special_tokens=True) for pred in preds] - if data_args.ignore_pad_token_for_loss: - # Replace -100 in the labels as we can't decode them. - labels = np.where(labels != -100, labels, tokenizer.pad_token_id) - decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) - - # Some simple post-processing - decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) - result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) - # Extract a few results from ROUGE - result = {key: value.mid.fmeasure * 100 for key, value in result.items()} - - prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] - result["gen_len"] = np.mean(prediction_lens) - result = {k: round(v, 4) for k, v in result.items()} - return result + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)} + predictions = {} + # Let's loop over all the examples! + for example_index, example in enumerate(examples): + # This is the index of the feature associated to the current example. + feature_index = feature_per_example[example_index] + predictions[example["id"]] = decoded_preds[feature_index] + + # Format the result to the format the metric expects. + if data_args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) # Initialize our Trainer - trainer = Seq2SeqTrainer( + trainer = QuestionAnsweringSeq2SeqTrainer( model=model, args=training_args, train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, + eval_examples=eval_examples if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, + post_process_function=post_processing_function, ) # Training diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py new file mode 100644 index 00000000000000..ac260dadbc338f --- /dev/null +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A subclass of `Trainer` specific to Question-Answering tasks +""" +from typing import Dict, List, Optional + +from torch.utils.data import Dataset + +from transformers import Seq2SeqTrainer, is_torch_tpu_available +from transformers.trainer_utils import PredictionOutput + + +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + + +class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer): + def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.post_process_function = post_process_function + + # def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + eval_examples=None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + max_length: Optional[int] = None, + num_beams: Optional[int] = None, + ) -> Dict[str, float]: + self._max_length = max_length if max_length is not None else self.args.generation_max_length + self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is not None and self.compute_metrics is not None: + eval_preds = self.post_process_function(eval_examples, eval_dataset, output) + metrics = self.compute_metrics(eval_preds) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + self.log(metrics) + else: + metrics = {} + + if self.args.tpu_metrics_debug or self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + return metrics + + def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): + predict_dataloader = self.get_test_dataloader(predict_dataset) + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + predict_dataloader, + description="Prediction", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is None or self.compute_metrics is None: + return output + + predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") + metrics = self.compute_metrics(predictions) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 6c9f7e1cd751e7..de045630d8d3a7 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -274,10 +274,8 @@ def test_run_squad_seq2seq(self): with patch.object(sys, "argv", testargs): run_squad_seq2seq.main() result = get_results(tmp_dir) - self.assertGreaterEqual(result["eval_rouge1"], 10) - self.assertGreaterEqual(result["eval_rouge2"], 10) - self.assertGreaterEqual(result["eval_rougeL"], 10) - self.assertGreaterEqual(result["eval_rougeLsum"], 10) + self.assertGreaterEqual(result["eval_f1"], 30) + self.assertGreaterEqual(result["eval_exact"], 30) def test_run_swag(self): stream_handler = logging.StreamHandler(sys.stdout)