From 02f51b87b905a84d3446c18158218fbf263e5ce1 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sun, 7 Jun 2020 08:53:55 +0530 Subject: [PATCH 01/10] add SquadDataset --- src/transformers/data/datasets/squad.py | 189 ++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 src/transformers/data/datasets/squad.py diff --git a/src/transformers/data/datasets/squad.py b/src/transformers/data/datasets/squad.py new file mode 100644 index 00000000000000..7e2028c25abc25 --- /dev/null +++ b/src/transformers/data/datasets/squad.py @@ -0,0 +1,189 @@ +import logging +import os +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Union + +import torch +from filelock import FileLock +from torch.utils.data.dataset import Dataset + +from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING +from ...tokenization_utils import PreTrainedTokenizer +from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features + + +logger = logging.getLogger(__name__) + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class SquadDataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + model_type: str = field( + default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)} + ) + data_dir: str = field( + default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."} + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + doc_stride: int = field( + default=128, + metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, + ) + max_query_length: int = field( + default=64, + metadata={ + "help": "The maximum number of tokens for the question. Questions longer than this will " + "be truncated to this length." + }, + ) + max_answer_length: int = field( + default=30, + metadata={ + "help": "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + version_2_with_negative: bool = field( + default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."} + ) + null_score_diff_threshold: float = field( + default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."} + ) + n_best_size: int = field( + default=20, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."} + ) + lang_id: int = field( + default=0, + metadata={ + "help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)" + }, + ) + threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"}) + + +class Split(Enum): + train = "train" + dev = "dev" + + +class SquadDataset(Dataset): + """ + This will be superseded by a framework-agnostic approach + soon. 
+    """
+
+    args: SquadDataTrainingArguments
+    features: List[SquadFeatures]
+    mode: Split
+    is_language_sensitive: bool
+
+    def __init__(
+        self,
+        args: SquadDataTrainingArguments,
+        tokenizer: PreTrainedTokenizer,
+        limit_length: Optional[int] = None,
+        mode: Union[str, Split] = Split.train,
+        is_language_sensitive: Optional[bool] = False,
+        cache_dir: Optional[str] = None,
+    ):
+        self.args = args
+        self.is_language_sensitive = is_language_sensitive
+        self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
+        if isinstance(mode, str):
+            try:
+                mode = Split[mode]
+            except KeyError:
+                raise KeyError("mode is not a valid split name")
+        self.mode = mode
+        # Load data features from cache or dataset file
+        cached_features_file = os.path.join(
+            cache_dir if cache_dir is not None else args.data_dir,
+            "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(args.max_seq_length),),
+        )
+
+        # Make sure only the first process in distributed training processes the dataset,
+        # and the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
+            if os.path.exists(cached_features_file) and not args.overwrite_cache:
+                start = time.time()
+                self.features = torch.load(cached_features_file)
+                logger.info(
+                    f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
+                )
+            else:
+                if mode == Split.dev:
+                    examples = self.processor.get_dev_examples(args.data_dir)
+                else:
+                    examples = self.processor.get_train_examples(args.data_dir)
+
+                self.features = squad_convert_examples_to_features(
+                    examples=examples,
+                    tokenizer=tokenizer,
+                    max_seq_length=args.max_seq_length,
+                    doc_stride=args.doc_stride,
+                    max_query_length=args.max_query_length,
+                    is_training=mode == Split.train,
+                    threads=args.threads,
+                )
+
+                start = time.time()
+                torch.save(self.features, cached_features_file)
+                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+ logger.info( + "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start + ) + + def __len__(self): + return len(self.features) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + # Convert to Tensors and build dataset + feature = self.features[i] + + input_ids = torch.tensor(feature.input_ids, dtype=torch.long) + attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long) + token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long) + cls_index = torch.tensor(feature.cls_index, dtype=torch.long) + p_mask = torch.tensor(feature.p_mask, dtype=torch.float) + is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]: + del inputs["token_type_ids"] + + if self.args.model_type in ["xlnet", "xlm"]: + inputs.update({"cls_index": cls_index, "p_mask": p_mask}) + if self.args.version_2_with_negative: + inputs.update({"is_impossible": is_impossible}) + if self.is_language_sensitive: + inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)}) + + if self.mode == Split.train: + start_positions = torch.tensor(feature.start_position, dtype=torch.long) + end_positions = torch.tensor(feature.end_position, dtype=torch.long) + inputs.update({"start_positions": start_positions, "end_positions": end_positions}) + + return inputs From 98bd259c8440211d6528bb8c314219b4e5a4fa49 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sun, 7 Jun 2020 08:54:09 +0530 Subject: [PATCH 02/10] add DataCollatorForQuestionAnswering --- src/transformers/data/data_collator.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 7cd095651c2fdb..b5aa7e49cb7356 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -142,3 +142,21 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] # The rest of the time (10% of the time) we keep the masked input tokens unchanged return inputs, labels + + +@dataclass +class DataCollatorForQuestionAnswering(DataCollator): + """ + Data collator used for language modeling. 
+ - collates batches of tensors + - preprocesses batches for question answering + """ + + def collate_batch(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + keys = batch[0].keys() + inputs = {} + + for key in keys: + inputs[key] = torch.stack([example[key] for example in batch]) + + return inputs From 39aa94b274bc84df9fb8199991b6c39f146dfaa0 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sun, 7 Jun 2020 08:54:24 +0530 Subject: [PATCH 03/10] update __init__ --- src/transformers/__init__.py | 9 ++++++++- src/transformers/data/datasets/__init__.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7bc19ac9344767..862be112d9602a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -344,7 +344,14 @@ # Trainer from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling - from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments + from .data.datasets import ( + GlueDataset, + TextDataset, + LineByLineTextDataset, + GlueDataTrainingArguments, + SquadDataset, + SquadDataTrainingArguments, + ) # Benchmarks from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments diff --git a/src/transformers/data/datasets/__init__.py b/src/transformers/data/datasets/__init__.py index 74a2147bc5c3e4..ca2ab15e43fbeb 100644 --- a/src/transformers/data/datasets/__init__.py +++ b/src/transformers/data/datasets/__init__.py @@ -4,3 +4,4 @@ from .glue import GlueDataset, GlueDataTrainingArguments from .language_modeling import LineByLineTextDataset, TextDataset +from .squad import SquadDataset, SquadDataTrainingArguments From 4eaa555c41be3a47c115f5e6c2b1f5ea38a44ce1 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sun, 7 Jun 2020 08:54:54 +0530 Subject: [PATCH 04/10] add run_squad with trainer --- .../question-answering/run_squad_trainer.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 examples/question-answering/run_squad_trainer.py diff --git a/examples/question-answering/run_squad_trainer.py b/examples/question-answering/run_squad_trainer.py new file mode 100644 index 00000000000000..04c06acda186a9 --- /dev/null +++ b/examples/question-answering/run_squad_trainer.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Fine-tuning the library models for question-answering.""" + + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset +from transformers import SquadDataTrainingArguments as DataTrainingArguments +from transformers import Trainer, TrainingArguments + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."}) + # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script, + # or just modify its tokenizer_config.json. + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Prepare Question-Answering task + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + ) + model = AutoModelForQuestionAnswering.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + + # Get datasets + is_language_sensitive = hasattr(model.config, "lang2id") + train_dataset = ( + SquadDataset( + data_args, tokenizer=tokenizer, is_language_sensitive=is_language_sensitive, cache_dir=model_args.cache_dir + ) + if training_args.do_train + else None + ) + eval_dataset = ( + SquadDataset( + data_args, + tokenizer=tokenizer, + mode="dev", + is_language_sensitive=is_language_sensitive, + cache_dir=model_args.cache_dir, + ) + if training_args.do_eval + else None + ) + + # Initialize our Trainer + trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,) + + # Training + if training_args.do_train: + trainer.train( + model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None + ) + trainer.save_model() + # For convenience, we also re-save the tokenizer to the same directory, + # so that you can share your model easily on huggingface.co/models =) + if trainer.is_world_master(): + tokenizer.save_pretrained(training_args.output_dir) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() From 5527121d002d62d1729570d12309aed8abbb5545 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sun, 7 Jun 2020 09:45:26 +0530 Subject: [PATCH 05/10] add DataCollatorForQuestionAnswering in __init__ --- src/transformers/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 862be112d9602a..4581ab303efbdf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -343,7 +343,12 @@ # Trainer from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction - from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling + from .data.data_collator import ( + DefaultDataCollator, + DataCollator, + DataCollatorForLanguageModeling, + DataCollatorForQuestionAnswering, + ) from .data.datasets import ( GlueDataset, TextDataset, From 081b40caf5bcab100c101798c2bb976b33abca5f Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sun, 7 Jun 2020 09:45:48 +0530 Subject: [PATCH 06/10] pass data_collator to trainer --- .../question-answering/run_squad_trainer.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/question-answering/run_squad_trainer.py b/examples/question-answering/run_squad_trainer.py index 04c06acda186a9..2abdc7aebb1f77 100644 --- a/examples/question-answering/run_squad_trainer.py +++ b/examples/question-answering/run_squad_trainer.py @@ -22,7 +22,14 @@ from dataclasses import dataclass, field from typing import Optional -from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset +from transformers import ( + AutoConfig, + AutoModelForQuestionAnswering, + AutoTokenizer, + DataCollatorForQuestionAnswering, + HfArgumentParser, + SquadDataset, +) from transformers import 
SquadDataTrainingArguments as DataTrainingArguments from transformers import Trainer, TrainingArguments @@ -135,9 +142,16 @@ def main(): if training_args.do_eval else None ) + data_collator = DataCollatorForQuestionAnswering() # Initialize our Trainer - trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,) + trainer = Trainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) # Training if training_args.do_train: From f54332f378f3ed50dcfa7c01829b0a291cac9750 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 16 Jun 2020 10:25:16 +0000 Subject: [PATCH 07/10] doc tweak --- examples/question-answering/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/question-answering/README.md b/examples/question-answering/README.md index d524957b1bbc2c..6c27a915c27e99 100644 --- a/examples/question-answering/README.md +++ b/examples/question-answering/README.md @@ -77,7 +77,7 @@ exact_match = 86.91 ``` This fine-tuned model is available as a checkpoint under the reference -`bert-large-uncased-whole-word-masking-finetuned-squad`. +[`bert-large-uncased-whole-word-masking-finetuned-squad`](https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad). #### Fine-tuning XLNet on SQuAD @@ -178,4 +178,4 @@ python run_tf_squad.py \ --optimizer_name adamw ``` -For the moment the evaluation is not available in the Tensorflow Trainer only the training. \ No newline at end of file +For the moment evaluation is not available in the Tensorflow Trainer only the training. \ No newline at end of file From 9bf2ce565ba2c89340f4139fe823c779a7d9c065 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 7 Jul 2020 08:39:09 -0400 Subject: [PATCH 08/10] Update run_squad_trainer.py --- .../question-answering/run_squad_trainer.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/examples/question-answering/run_squad_trainer.py b/examples/question-answering/run_squad_trainer.py index 2abdc7aebb1f77..04c06acda186a9 100644 --- a/examples/question-answering/run_squad_trainer.py +++ b/examples/question-answering/run_squad_trainer.py @@ -22,14 +22,7 @@ from dataclasses import dataclass, field from typing import Optional -from transformers import ( - AutoConfig, - AutoModelForQuestionAnswering, - AutoTokenizer, - DataCollatorForQuestionAnswering, - HfArgumentParser, - SquadDataset, -) +from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset from transformers import SquadDataTrainingArguments as DataTrainingArguments from transformers import Trainer, TrainingArguments @@ -142,16 +135,9 @@ def main(): if training_args.do_eval else None ) - data_collator = DataCollatorForQuestionAnswering() # Initialize our Trainer - trainer = Trainer( - model=model, - args=training_args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - ) + trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,) # Training if training_args.do_train: From 4a1307dca06205f8181b9658d528b2f727b1d91b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 7 Jul 2020 08:39:34 -0400 Subject: [PATCH 09/10] Update __init__.py --- src/transformers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/transformers/__init__.py b/src/transformers/__init__.py index f623426ba4566c..db12e4546bf822 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -404,7 +404,6 @@ default_data_collator, DataCollator, DataCollatorForLanguageModeling, - DataCollatorForQuestionAnswering, ) from .data.datasets import ( GlueDataset, From 5497ae615a8ce526111eadba33b27bc6559ee799 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 7 Jul 2020 08:48:48 -0400 Subject: [PATCH 10/10] Update __init__.py --- src/transformers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index db12e4546bf822..56192e6f1a0fab 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -404,6 +404,7 @@ default_data_collator, DataCollator, DataCollatorForLanguageModeling, + DataCollatorForPermutationLanguageModeling, ) from .data.datasets import ( GlueDataset,
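
For reference, a minimal sketch of how the pieces introduced in this series fit together once merged. The model name, data directory, output directory, and sequence-length settings below are illustrative assumptions, not values taken from the patches; after patches 08 and 09 the `Trainer`'s default collator handles batching, so no question-answering-specific collator is passed.

```python
# Illustrative usage sketch (not part of the patches). The checkpoint name, paths,
# and hyperparameters below are assumptions for demonstration only.
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    SquadDataset,
    SquadDataTrainingArguments,
    Trainer,
    TrainingArguments,
)

model_name = "bert-base-uncased"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# data_dir is assumed to contain the SQuAD train/dev .json files
data_args = SquadDataTrainingArguments(
    model_type="bert",
    data_dir="./squad_data",
    max_seq_length=384,
    doc_stride=128,
)

train_dataset = SquadDataset(data_args, tokenizer=tokenizer, mode="train")
eval_dataset = SquadDataset(data_args, tokenizer=tokenizer, mode="dev")

training_args = TrainingArguments(output_dir="./squad_out", do_train=True, do_eval=True)

# No dedicated data collator: SquadDataset.__getitem__ already returns a dict of
# fixed-size tensors, so the Trainer's default collator only needs to stack them.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```

Because `SquadDataset.__getitem__` returns a dict of equally shaped tensors per example, stacking is all the collation the `Trainer` needs, which is why the dedicated `DataCollatorForQuestionAnswering` introduced in patches 02 and 05/06 could be dropped again in patches 08 and 09.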