[examples] Add trainer support for question-answering #4829

Merged 11 commits on Jul 7, 2020
4 changes: 2 additions & 2 deletions examples/question-answering/README.md
@@ -77,7 +77,7 @@ exact_match = 86.91
```

This fine-tuned model is available as a checkpoint under the reference
-`bert-large-uncased-whole-word-masking-finetuned-squad`.
+[`bert-large-uncased-whole-word-masking-finetuned-squad`](https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad).

#### Fine-tuning XLNet on SQuAD

@@ -178,4 +178,4 @@ python run_tf_squad.py \
--optimizer_name adamw
```

-For the moment the evaluation is not available in the Tensorflow Trainer only the training.
+For the moment, evaluation is not available in the TensorFlow Trainer, only training.
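
As a quick aside on the checkpoint referenced earlier in this README: it can be loaded directly for inference. Below is a minimal sketch using the question-answering pipeline (the question and context strings are made up for illustration):

```python
from transformers import pipeline

# Load the referenced fine-tuned SQuAD checkpoint.
qa = pipeline(
    "question-answering",
    model="bert-large-uncased-whole-word-masking-finetuned-squad",
)

result = qa(
    question="What does the Trainer handle?",
    context="The Trainer handles the training loop, evaluation, logging and checkpointing.",
)
print(result["answer"], result["score"])
```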
174 changes: 174 additions & 0 deletions examples/question-answering/run_squad_trainer.py
@@ -0,0 +1,174 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for question-answering."""


import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorForQuestionAnswering,
    HfArgumentParser,
    SquadDataset,
)
from transformers import SquadDataTrainingArguments as DataTrainingArguments
from transformers import Trainer, TrainingArguments


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
    # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
    # or just modify its tokenizer_config.json.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Prepare Question-Answering task
    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    is_language_sensitive = hasattr(model.config, "lang2id")
    train_dataset = (
        SquadDataset(
            data_args, tokenizer=tokenizer, is_language_sensitive=is_language_sensitive, cache_dir=model_args.cache_dir
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        SquadDataset(
            data_args,
            tokenizer=tokenizer,
            mode="dev",
            is_language_sensitive=is_language_sensitive,
            cache_dir=model_args.cache_dir,
        )
        if training_args.do_eval
        else None
    )
    data_collator = DataCollatorForQuestionAnswering()

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
16 changes: 14 additions & 2 deletions src/transformers/__init__.py
@@ -343,8 +343,20 @@

# Trainer
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
-from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
-from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
+from .data.data_collator import (
+    DefaultDataCollator,
+    DataCollator,
+    DataCollatorForLanguageModeling,
+    DataCollatorForQuestionAnswering,
+)
+from .data.datasets import (
+    GlueDataset,
+    TextDataset,
+    LineByLineTextDataset,
+    GlueDataTrainingArguments,
+    SquadDataset,
+    SquadDataTrainingArguments,
+)

# Benchmarks
from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments
18 changes: 18 additions & 0 deletions src/transformers/data/data_collator.py
@@ -142,3 +142,21 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels


@dataclass
class DataCollatorForQuestionAnswering(DataCollator):
    """
    Data collator used for question answering.
    - collates batches of tensors
    - preprocesses batches for question answering
    """

    def collate_batch(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        keys = batch[0].keys()
        inputs = {}

        for key in keys:
            inputs[key] = torch.stack([example[key] for example in batch])

        return inputs
Review comments on this collator:

Member:
Pretty clean implementation of the data collator. cc @sgugger

Collaborator:
Yes, quick question though: why does the default not work here?

Member:
(I think it's ~identical to the default one, so maybe we can merge them and just use the default here.)

Contributor (author):
The default collator needs List[InputDataClass], but I'm returning List[Dict[str, torch.Tensor]], so I was not able to use the default one.

Collaborator (@sgugger, Jun 15, 2020):
InputDataClass is just a type alias for anything. I think the default should work fine for you (but we can change its implementation to use this).
Note that #5015 will change how data collators work a little bit, and a few names, so we'll need a few adjustments here.

Contributor (author):
> InputDataClass is just a type alias for anything.

Yes, but I think it assumes that it will be a class, since it uses getattr and vars. How should I proceed?

Member:
Maybe we can change the default implementation to accept more input types.

Member:
Let me know if you want to do it, @sgugger or @patil-suraj.

Contributor (author):
@julien-c can we just add a simple type check in the default collator, i.e. if the input is a dict we can call example.get instead of getattr, or maybe convert the dict to a simple class?

Collaborator:
I'll add this this morning with the backward compatibility to the old DataCollator style.
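
The type check being discussed might look roughly like the following. This is a hypothetical sketch with an illustrative function name (`collate_examples`), not the change that actually landed:

```python
from typing import Any, Dict, List

import torch


def collate_examples(examples: List[Any]) -> Dict[str, torch.Tensor]:
    """Hypothetical dict-aware default collation: accept either plain dicts
    or dataclass-style objects and stack their tensor fields into a batch."""
    first = examples[0]
    if isinstance(first, dict):
        # Dict inputs (the List[Dict[str, torch.Tensor]] case above): read fields by key.
        return {key: torch.stack([example[key] for example in examples]) for key in first.keys()}
    # Otherwise assume dataclass-style objects: keep the existing getattr/vars path.
    return {key: torch.stack([getattr(example, key) for example in examples]) for key in vars(first).keys()}
```

Stacking by key lookup covers the dict case raised above, while the getattr branch preserves the old dataclass-style behaviour.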

1 change: 1 addition & 0 deletions src/transformers/data/datasets/__init__.py
@@ -4,3 +4,4 @@

from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import LineByLineTextDataset, TextDataset
from .squad import SquadDataset, SquadDataTrainingArguments