[examples] Add trainer support for question-answering (#4829)
* add SquadDataset

* add DataCollatorForQuestionAnswering

* update __init__

* add run_squad with trainer

* add DataCollatorForQuestionAnswering in __init__

* pass data_collator to trainer

* doc tweak

* Update run_squad_trainer.py

* Update __init__.py

* Update __init__.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
3 people committed Jul 7, 2020
1 parent fbd8792 commit e49393c
Showing 5 changed files with 362 additions and 4 deletions.
5 changes: 3 additions & 2 deletions examples/question-answering/README.md
@@ -77,7 +77,7 @@ exact_match = 86.91
```

This fine-tuned model is available as a checkpoint under the reference
-`bert-large-uncased-whole-word-masking-finetuned-squad`.
+[`bert-large-uncased-whole-word-masking-finetuned-squad`](https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad).

#### Fine-tuning XLNet on SQuAD

@@ -176,4 +176,5 @@ python run_tf_squad.py \
--doc_stride 128
```

-For the moment the evaluation is not available in the Tensorflow Trainer only the training.
+For the moment, evaluation is not available in the TensorFlow Trainer, only training.
160 changes: 160 additions & 0 deletions examples/question-answering/run_squad_trainer.py
@@ -0,0 +1,160 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for question-answering."""


import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset
from transformers import SquadDataTrainingArguments as DataTrainingArguments
from transformers import Trainer, TrainingArguments


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
    # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
    # or just modify its tokenizer_config.json.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Prepare Question-Answering task
    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    is_language_sensitive = hasattr(model.config, "lang2id")
    train_dataset = (
        SquadDataset(
            data_args, tokenizer=tokenizer, is_language_sensitive=is_language_sensitive, cache_dir=model_args.cache_dir
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        SquadDataset(
            data_args,
            tokenizer=tokenizer,
            mode="dev",
            is_language_sensitive=is_language_sensitive,
            cache_dir=model_args.cache_dir,
        )
        if training_args.do_eval
        else None
    )

    # Initialize our Trainer
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
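
To make the flow of the script above concrete, here is a minimal sketch that wires the same classes together programmatically rather than through `HfArgumentParser` and the CLI. The checkpoint name, data directory, and output directory are illustrative placeholders, not part of this commit.

```python
# Minimal sketch (not part of this commit): driving the same components
# programmatically. "bert-base-uncased", "./squad" and "./out" are placeholders.
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    SquadDataset,
    SquadDataTrainingArguments,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")

# These two dataclasses are what run_squad_trainer.py parses from the CLI.
data_args = SquadDataTrainingArguments(
    model_type="bert", data_dir="./squad", max_seq_length=384, doc_stride=128
)
training_args = TrainingArguments(output_dir="./out", do_train=True)

train_dataset = SquadDataset(data_args, tokenizer=tokenizer)  # mode defaults to train
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
trainer.save_model()
```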
11 changes: 9 additions & 2 deletions src/transformers/__init__.py
@@ -416,14 +416,21 @@
)

# Trainer
-from .trainer import Trainer, torch_distributed_zero_first
+from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
from .data.data_collator import (
    default_data_collator,
    DataCollator,
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
)
-from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
+from .data.datasets import (
+    GlueDataset,
+    TextDataset,
+    LineByLineTextDataset,
+    GlueDataTrainingArguments,
+    SquadDataset,
+    SquadDataTrainingArguments,
+)

# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
1 change: 1 addition & 0 deletions src/transformers/data/datasets/__init__.py
@@ -4,3 +4,4 @@

from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import LineByLineTextDataset, TextDataset
+from .squad import SquadDataset, SquadDataTrainingArguments
189 changes: 189 additions & 0 deletions src/transformers/data/datasets/squad.py
@@ -0,0 +1,189 @@
import logging
import os
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union

import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset

from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ...tokenization_utils import PreTrainedTokenizer
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features


logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class SquadDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    model_type: str = field(
        default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
    )
    data_dir: str = field(
        default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    max_query_length: int = field(
        default=64,
        metadata={
            "help": "The maximum number of tokens for the question. Questions longer than this will "
            "be truncated to this length."
        },
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
    )
    n_best_size: int = field(
        default=20, metadata={"help": "The total number of n-best predictions to generate."}
    )
    lang_id: int = field(
        default=0,
        metadata={
            "help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
        },
    )
    threads: int = field(default=1, metadata={"help": "Multiple threads for converting examples to features."})


class Split(Enum):
    train = "train"
    dev = "dev"


class SquadDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach soon.
    """

    args: SquadDataTrainingArguments
    features: List[SquadFeatures]
    mode: Split
    is_language_sensitive: bool

    def __init__(
        self,
        args: SquadDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        is_language_sensitive: Optional[bool] = False,
        cache_dir: Optional[str] = None,
    ):
        self.args = args
        self.is_language_sensitive = is_language_sensitive
        self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        self.mode = mode
        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            "cached_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(args.max_seq_length)),
        )

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info(
                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )
            else:
                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)

                self.features = squad_convert_examples_to_features(
                    examples=examples,
                    tokenizer=tokenizer,
                    max_seq_length=args.max_seq_length,
                    doc_stride=args.doc_stride,
                    max_query_length=args.max_query_length,
                    is_training=mode == Split.train,
                    threads=args.threads,
                )

                start = time.time()
                torch.save(self.features, cached_features_file)
                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Convert to Tensors and build dataset
        feature = self.features[i]

        input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
        attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
        token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
        cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
        p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
        is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

        if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
            del inputs["token_type_ids"]

        if self.args.model_type in ["xlnet", "xlm"]:
            inputs.update({"cls_index": cls_index, "p_mask": p_mask})
            if self.args.version_2_with_negative:
                inputs.update({"is_impossible": is_impossible})
            if self.is_language_sensitive:
                inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})

        if self.mode == Split.train:
            start_positions = torch.tensor(feature.start_position, dtype=torch.long)
            end_positions = torch.tensor(feature.end_position, dtype=torch.long)
            inputs.update({"start_positions": start_positions, "end_positions": end_positions})

        return inputs
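
Since `__getitem__` returns a plain dict of tensors, the Trainer's default collation can batch items without a dedicated data collator. The sketch below illustrates this; it assumes a local `./squad` directory containing the SQuAD v1.1 JSON files, and all names are placeholders.

```python
# Illustrative sketch: each SquadDataset item is a dict of tensors, so
# default_data_collator can stack them into a batch the model accepts.
# Assumes "./squad" holds train-v1.1.json / dev-v1.1.json (placeholders).
from transformers import AutoTokenizer, SquadDataset, SquadDataTrainingArguments, default_data_collator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
args = SquadDataTrainingArguments(model_type="bert", data_dir="./squad", max_seq_length=384)

dataset = SquadDataset(args, tokenizer=tokenizer, mode="train")
batch = default_data_collator([dataset[0], dataset[1]])
for name, tensor in batch.items():
    print(name, tuple(tensor.shape))
# For a BERT-style model in train mode, expect: input_ids, attention_mask,
# token_type_ids, start_positions, end_positions.
```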
