[examples] Add trainer support for question-answering (#4829)
* add SquadDataset
* add DataCollatorForQuestionAnswering
* update __init__
* add run_squad with trainer
* add DataCollatorForQuestionAnswering in __init__
* pass data_collator to trainer
* doc tweak
* Update run_squad_trainer.py
* Update __init__.py
* Update __init__.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent fbd8792 · commit e49393c
Showing 5 changed files with 362 additions and 4 deletions.
@@ -0,0 +1,160 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for question-answering."""


import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset
from transformers import SquadDataTrainingArguments as DataTrainingArguments
from transformers import Trainer, TrainingArguments


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
    # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
    # or just modify its tokenizer_config.json.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
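    # (Illustrative, not part of this commit) such a JSON config is a flat mapping of
    # flag names to values across all three dataclasses, e.g.
    # {"model_name_or_path": "bert-base-uncased", "data_dir": "./squad",
    #  "output_dir": "./out", "do_train": true, "do_eval": true}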

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Prepare Question-Answering task
    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
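    # XLM-style checkpoints expose config.lang2id and expect per-token language ids,
    # which SquadDataset adds when is_language_sensitive is set.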
    is_language_sensitive = hasattr(model.config, "lang2id")
    train_dataset = (
        SquadDataset(
            data_args, tokenizer=tokenizer, is_language_sensitive=is_language_sensitive, cache_dir=model_args.cache_dir
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        SquadDataset(
            data_args,
            tokenizer=tokenizer,
            mode="dev",
            is_language_sensitive=is_language_sensitive,
            cache_dir=model_args.cache_dir,
        )
        if training_args.do_eval
        else None
    )

    # Initialize our Trainer
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
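A typical launch is then just `python run_squad_trainer.py --model_name_or_path bert-base-uncased --data_dir ./squad --do_train --do_eval --output_dir ./out` (model name and paths illustrative). For reference, a minimal sketch, not part of the commit, of how HfArgumentParser splits such flags across the three dataclasses; the trimmed-down ModelArguments below is an illustrative stand-in for the one defined in the script above:

from dataclasses import dataclass, field

from transformers import HfArgumentParser, SquadDataTrainingArguments, TrainingArguments


@dataclass
class ModelArguments:
    # Illustrative stand-in for the fuller ModelArguments in the script above.
    model_name_or_path: str = field(default="bert-base-uncased")


parser = HfArgumentParser((ModelArguments, SquadDataTrainingArguments, TrainingArguments))
# Each dataclass field becomes a --flag; one parse call fills all three objects.
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "bert-base-uncased", "--data_dir", "./squad", "--output_dir", "./out", "--do_train"]
)
print(model_args.model_name_or_path, data_args.data_dir, training_args.do_train)  # bert-base-uncased ./squad True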
@@ -0,0 +1,189 @@
import logging
import os
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union

import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset

from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ...tokenization_utils import PreTrainedTokenizer
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features


logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class SquadDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    model_type: str = field(
        default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
    )
    data_dir: str = field(
        default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    max_query_length: int = field(
        default=64,
        metadata={
            "help": "The maximum number of tokens for the question. Questions longer than this will "
            "be truncated to this length."
        },
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
    )
    n_best_size: int = field(
        default=20, metadata={"help": "The total number of n-best predictions to generate."}
    )
    lang_id: int = field(
        default=0,
        metadata={
            "help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
        },
    )
    threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})


class Split(Enum):
    train = "train"
    dev = "dev"


class SquadDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach soon.
    """

    args: SquadDataTrainingArguments
    features: List[SquadFeatures]
    mode: Split
    is_language_sensitive: bool

    def __init__(
        self,
        args: SquadDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        is_language_sensitive: Optional[bool] = False,
        cache_dir: Optional[str] = None,
    ):
        self.args = args
        self.is_language_sensitive = is_language_sensitive
        self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        self.mode = mode
        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(args.max_seq_length),),
        )
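        # e.g. "./squad/cached_train_BertTokenizer_384" (name illustrative); the split,
        # tokenizer class and max_seq_length keep caches from different setups apart.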

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info(
                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )
            else:
                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)

                self.features = squad_convert_examples_to_features(
                    examples=examples,
                    tokenizer=tokenizer,
                    max_seq_length=args.max_seq_length,
                    doc_stride=args.doc_stride,
                    max_query_length=args.max_query_length,
                    is_training=mode == Split.train,
                    threads=args.threads,
                )

                start = time.time()
                torch.save(self.features, cached_features_file)
                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Convert to Tensors and build dataset
        feature = self.features[i]

        input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
        attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
        token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
        cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
        p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
        is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

        if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
            del inputs["token_type_ids"]

        if self.args.model_type in ["xlnet", "xlm"]:
            inputs.update({"cls_index": cls_index, "p_mask": p_mask})
            if self.args.version_2_with_negative:
                inputs.update({"is_impossible": is_impossible})
            if self.is_language_sensitive:
                inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})

        if self.mode == Split.train:
            start_positions = torch.tensor(feature.start_position, dtype=torch.long)
            end_positions = torch.tensor(feature.end_position, dtype=torch.long)
            inputs.update({"start_positions": start_positions, "end_positions": end_positions})

        return inputs
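For context, a minimal sketch, not from the commit, of how the dataset plugs into a vanilla DataLoader; the model name, paths, and batch size are illustrative, and data_dir is assumed to contain SQuAD-style json files:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, SquadDataset, SquadDataTrainingArguments

# Illustrative setup; SquadV1Processor looks for train-v1.1.json in data_dir.
args = SquadDataTrainingArguments(data_dir="./squad", max_seq_length=384, doc_stride=128)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

dataset = SquadDataset(args, tokenizer=tokenizer, mode="train")
loader = DataLoader(dataset, batch_size=8)

batch = next(iter(loader))
# Default collation stacks each key into a batch tensor: input_ids,
# attention_mask, token_type_ids of shape [8, 384]; start_positions,
# end_positions of shape [8].
print({k: tuple(v.shape) for k, v in batch.items()})

Because __getitem__ already returns a dict keyed by the model's forward-pass argument names, Trainer can consume these batches without a task-specific collator.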