Introduce logging_strategy training argument (#10267)
Introduce logging_strategy training argument
in TrainingArguments and TFTrainingArguments. (#9838)
tanmay17061 committed Feb 19, 2021
1 parent 34df26e commit 709c86b
Showing 4 changed files with 39 additions and 5 deletions.
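Taken together, the change lets a user pick a logging cadence independent of `logging_steps`. A minimal usage sketch (a hypothetical script, not part of this commit, assuming a build that includes it; `"out"` is a placeholder output directory):

```python
from transformers import TrainingArguments

# Log once at the end of every epoch; the steps-based check no longer fires.
args_epoch = TrainingArguments(output_dir="out", logging_strategy="epoch")

# Log every 100 update steps (the previous, step-based behavior).
args_steps = TrainingArguments(
    output_dir="out", logging_strategy="steps", logging_steps=100
)

# Disable in-training logging entirely.
args_off = TrainingArguments(output_dir="out", logging_strategy="no")
```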
13 changes: 11 additions & 2 deletions src/transformers/trainer_callback.py
@@ -24,7 +24,7 @@
 import numpy as np
 from tqdm.auto import tqdm
 
-from .trainer_utils import EvaluationStrategy
+from .trainer_utils import EvaluationStrategy, LoggingStrategy
 from .training_args import TrainingArguments
 from .utils import logging
 
@@ -403,7 +403,11 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         # Log
         if state.global_step == 1 and args.logging_first_step:
             control.should_log = True
-        if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
+        if (
+            args.logging_strategy == LoggingStrategy.STEPS
+            and args.logging_steps > 0
+            and state.global_step % args.logging_steps == 0
+        ):
             control.should_log = True
 
         # Evaluate
@@ -423,6 +427,11 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         return control
 
     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        # Log
+        if args.logging_strategy == LoggingStrategy.EPOCH:
+            control.should_log = True
+
         # Evaluate
         if args.evaluation_strategy == EvaluationStrategy.EPOCH:
             control.should_evaluate = True
         if args.load_best_model_at_end:
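The two hooks above split the decision: `on_step_end` handles the `"steps"` strategy and `on_epoch_end` the `"epoch"` one, with `logging_first_step` firing regardless of strategy. A standalone sketch of that decision logic (an illustrative paraphrase, not the library code):

```python
from enum import Enum

class LoggingStrategy(Enum):  # mirrors the enum added in trainer_utils.py
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"

def should_log_on_step_end(strategy, logging_steps, global_step, logging_first_step):
    # The first step logs when logging_first_step is set, regardless of strategy.
    if global_step == 1 and logging_first_step:
        return True
    # Otherwise only the "steps" strategy logs, once every logging_steps updates.
    return (
        strategy is LoggingStrategy.STEPS
        and logging_steps > 0
        and global_step % logging_steps == 0
    )

def should_log_on_epoch_end(strategy):
    # The "epoch" strategy logs exactly once per epoch, at its end.
    return strategy is LoggingStrategy.EPOCH

assert should_log_on_step_end(LoggingStrategy.STEPS, 500, 1000, False)
assert not should_log_on_step_end(LoggingStrategy.EPOCH, 500, 1000, False)
assert should_log_on_epoch_end(LoggingStrategy.EPOCH)
```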
6 changes: 6 additions & 0 deletions src/transformers/trainer_utils.py
@@ -107,6 +107,12 @@ class EvaluationStrategy(ExplicitEnum):
     EPOCH = "epoch"
 
 
+class LoggingStrategy(ExplicitEnum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+
+
 class BestRun(NamedTuple):
     """
     The best run found by an hyperparameter search (see :class:`~transformers.Trainer.hyperparameter_search`).
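Because `ExplicitEnum` members carry their string values, the new enum round-trips cleanly between CLI strings and members. A quick illustrative sketch (assuming a build with this commit installed):

```python
from transformers.trainer_utils import LoggingStrategy

assert LoggingStrategy("epoch") is LoggingStrategy.EPOCH  # string coerces to member
assert LoggingStrategy.EPOCH.value == "epoch"

# An unknown value raises a ValueError that lists the valid choices,
# which is the point of ExplicitEnum over a plain Enum.
try:
    LoggingStrategy("every_epoch")
except ValueError as err:
    print(err)
```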
16 changes: 14 additions & 2 deletions src/transformers/training_args.py
@@ -25,7 +25,7 @@
     is_torch_tpu_available,
     torch_required,
 )
-from .trainer_utils import EvaluationStrategy, SchedulerType
+from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
 from .utils import logging
 

@@ -139,10 +139,17 @@ class TrainingArguments:
         logging_dir (:obj:`str`, `optional`):
             `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
             `runs/**CURRENT_DATETIME_HOSTNAME**`.
+        logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
+            The logging strategy to adopt during training. Possible values are:
+
+                * :obj:`"no"`: No logging is done during training.
+                * :obj:`"epoch"`: Logging is done at the end of each epoch.
+                * :obj:`"steps"`: Logging is done every :obj:`logging_steps`.
+
         logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to log and evaluate the first :obj:`global_step` or not.
         logging_steps (:obj:`int`, `optional`, defaults to 500):
-            Number of update steps between two logs.
+            Number of update steps between two logs if :obj:`logging_strategy="steps"`.
         save_steps (:obj:`int`, `optional`, defaults to 500):
             Number of updates steps before two checkpoint saves.
         save_total_limit (:obj:`int`, `optional`):
@@ -339,6 +346,10 @@ class TrainingArguments:
     warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
 
     logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
+    logging_strategy: LoggingStrategy = field(
+        default="steps",
+        metadata={"help": "The logging strategy to use."},
+    )
     logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
     save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
@@ -482,6 +493,7 @@ def __post_init__(self):
         if self.disable_tqdm is None:
             self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
         self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
+        self.logging_strategy = LoggingStrategy(self.logging_strategy)
         self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
         if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
             self.do_eval = True
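The one-line addition to `__post_init__` means the field accepts either a string or a `LoggingStrategy` member and normalizes it to the enum. A short sketch (assumes this commit; `"out"` is a placeholder output directory):

```python
from transformers import TrainingArguments
from transformers.trainer_utils import LoggingStrategy

args = TrainingArguments(output_dir="out", logging_strategy="epoch")
# __post_init__ has already coerced the string into the enum member.
assert args.logging_strategy is LoggingStrategy.EPOCH
```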
9 changes: 8 additions & 1 deletion src/transformers/training_args_tf.py
@@ -102,10 +102,17 @@ class TFTrainingArguments(TrainingArguments):
         logging_dir (:obj:`str`, `optional`):
             `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
             `runs/**CURRENT_DATETIME_HOSTNAME**`.
+        logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
+            The logging strategy to adopt during training. Possible values are:
+
+                * :obj:`"no"`: No logging is done during training.
+                * :obj:`"epoch"`: Logging is done at the end of each epoch.
+                * :obj:`"steps"`: Logging is done every :obj:`logging_steps`.
+
         logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to log and evaluate the first :obj:`global_step` or not.
         logging_steps (:obj:`int`, `optional`, defaults to 500):
-            Number of update steps between two logs.
+            Number of update steps between two logs if :obj:`logging_strategy="steps"`.
         save_steps (:obj:`int`, `optional`, defaults to 500):
             Number of updates steps before two checkpoint saves.
         save_total_limit (:obj:`int`, `optional`):
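`TFTrainingArguments` only updates its docstring here; the field and the `__post_init__` coercion are inherited from `TrainingArguments`, so the TF side behaves identically. A sketch (assumes TensorFlow is installed alongside this commit):

```python
from transformers import TFTrainingArguments

tf_args = TFTrainingArguments(output_dir="out", logging_strategy="epoch")
assert tf_args.logging_strategy.value == "epoch"  # same inherited coercion
```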
