Skip to content

Commit

Permalink
[examples] Generate argparsers from type hints on dataclasses (#3669)
Browse files Browse the repository at this point in the history
* [examples] Generate argparsers from type hints on dataclasses

* [HfArgumentParser] way simpler API

* Restore run_language_modeling.py for easier diff

* [HfArgumentParser] final tweaks from code review
  • Loading branch information
julien-c authored Apr 10, 2020
1 parent 7a7fdf7 commit b169ac9
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 137 deletions.
182 changes: 45 additions & 137 deletions examples/run_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import logging
import os
import random
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
Expand All @@ -36,6 +38,8 @@
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
HfArgumentParser,
TrainingArguments,
get_linear_schedule_with_warmup,
)
from transformers import glue_compute_metrics as compute_metrics
Expand Down Expand Up @@ -376,137 +380,54 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
return dataset


def main():
parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
)
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""

model_name_or_path: str = field(
metadata={"help": "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)}
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
model_type: str = field(metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)})
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
parser.add_argument(
"--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pre-trained models downloaded from s3"}
)

# Other parameters
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
default="",
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",

@dataclass
class DataProcessingArguments:
task_name: str = field(
metadata={"help": "The name of the task to train selected in the list: " + ", ".join(processors.keys())}
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
data_dir: str = field(
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
)
parser.add_argument(
"--max_seq_length",
max_seq_length: int = field(
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument(
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
parser.add_argument(
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)

parser.add_argument(
"--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
)
parser.add_argument(
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--eval_all_checkpoints",
action="store_true",
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument(
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
args = parser.parse_args()
def main():
parser = HfArgumentParser((ModelArguments, DataProcessingArguments, TrainingArguments))
model_args, dataprocessing_args, training_args = parser.parse_args_into_dataclasses()

# For now, let's merge all the sets of args into one,
# but soon, we'll keep distinct sets of args, with a cleaner separation of concerns.
args = argparse.Namespace(**vars(model_args), **vars(dataprocessing_args), **vars(training_args))

if (
os.path.exists(args.output_dir)
Expand All @@ -515,20 +436,9 @@ def main():
and not args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir
)
f"Output directory ({args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)

# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd

print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()

# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
Expand Down Expand Up @@ -576,18 +486,16 @@ def main():
args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None,
cache_dir=args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None,
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir,
)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None,
cache_dir=args.cache_dir,
)

if args.local_rank == 0:
Expand Down Expand Up @@ -629,7 +537,7 @@ def main():
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
is_tf_available,
is_torch_available,
)
from .hf_argparser import HfArgumentParser

# Model Cards
from .modelcard import ModelCard
Expand Down Expand Up @@ -141,6 +142,7 @@
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .training_args import TrainingArguments


logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand Down
113 changes: 113 additions & 0 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import dataclasses
from argparse import ArgumentParser
from enum import Enum
from typing import Any, Iterable, NewType, Tuple, Union


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)


class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses
to generate arguments.
The class is designed to play well with the native argparse. In particular,
you can add more (non-dataclass backed) arguments to the parser after initialization
and you'll get the output back after parsing as an additional namespace.
"""

dataclass_types: Iterable[DataClassType]

def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
"""
Args:
dataclass_types:
Dataclass type, or list of dataclass types for which we will "fill" instances
with the parsed args.
kwargs:
(Optional) Passed to `argparse.ArgumentParser()` in the regular way.
"""
super().__init__(**kwargs)
if dataclasses.is_dataclass(dataclass_types):
dataclass_types = [dataclass_types]
self.dataclass_types = dataclass_types
for dtype in self.dataclass_types:
self._add_dataclass_arguments(dtype)

def _add_dataclass_arguments(self, dtype: DataClassType):
for field in dataclasses.fields(dtype):
field_name = f"--{field.name}"
kwargs = field.metadata.copy()
# field.metadata is not used at all by Data Classes,
# it is provided as a third-party extension mechanism.
if isinstance(field.type, str):
raise ImportError(
"This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563),"
"which can be opted in from Python 3.7 with `from __future__ import annotations`."
"We will add compatibility when Python 3.9 is released."
)
typestring = str(field.type)
for x in (int, float, str):
if typestring == f"typing.Union[{x.__name__}, NoneType]":
field.type = x
if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = list(field.type)
kwargs["type"] = field.type
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.type is bool:
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True:
field_name = f"--no-{field.name}"
kwargs["dest"] = field.name
else:
kwargs["type"] = field.type
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
else:
kwargs["required"] = True
self.add_argument(field_name, **kwargs)

def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]:
"""
Parse command-line args into instances of the specified dataclass types.
This relies on argparse's `ArgumentParser.parse_known_args`.
See the doc at:
docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
Args:
args:
List of strings to parse. The default is taken from sys.argv.
(same as argparse.ArgumentParser)
return_remaining_strings:
If true, also return a list of remaining argument strings.
Returns:
Tuple consisting of:
- the dataclass instances in the same order as they
were passed to the initializer.abspath
- if applicable, an additional namespace for more
(non-dataclass backed) arguments added to the parser
after initialization.
- The potential list of remaining argument strings.
(same as argparse.ArgumentParser.parse_known_args)
"""
namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype)}
inputs = {k: v for k, v in vars(namespace).items() if k in keys}
for k in keys:
delattr(namespace, k)
obj = dtype(**inputs)
outputs.append(obj)
if len(namespace.__dict__) > 0:
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (*outputs, remaining_args)
else:
return (*outputs,)
Loading

0 comments on commit b169ac9

Please sign in to comment.