[trainer] move secondary methods into a separate file #10363

Merged: 3 commits, Feb 24, 2021
61 changes: 3 additions & 58 deletions src/transformers/trainer.py
@@ -19,7 +19,6 @@
import collections
import gc
import inspect
import json
import math
import os
import re
@@ -82,7 +81,6 @@
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_learning_rate,
nested_concat,
nested_detach,
nested_numpify,
@@ -226,6 +224,8 @@ class Trainer:

"""

from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics

def __init__(
self,
model: Union[PreTrainedModel, torch.nn.Module] = None,
@@ -1130,7 +1130,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
tr_loss -= tr_loss

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = get_learning_rate(self)
logs["learning_rate"] = self._get_learning_rate()

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
@@ -1345,61 +1345,6 @@ def log(self, logs: Dict[str, float]) -> None:
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format

Args:
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict

Returns:
metrics (:obj:`Dict[str, float]`): The reformatted metrics
"""

metrics_copy = metrics.copy()
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
metrics_copy[k] = round(v, 4)

return metrics_copy

def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way

Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""

logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")

def save_metrics(self, split, metrics):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.

Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)

def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
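
The class-body import added above (`from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics`) works because any function object placed in a class body becomes a method of that class. A minimal, hypothetical single-file sketch of that binding mechanism follows; the assignment stands in for the cross-module import the PR performs, and `Widget` and `_describe` are invented names used only for illustration:

def _describe(self) -> str:
    # Module-level helper; `self` will be the instance of whatever class adopts it.
    return f"{type(self).__name__}(value={self.value})"


class Widget:
    # In the PR this slot is filled by `from .trainer_pt_utils import ...`; any
    # function object bound in the class body is looked up through the descriptor
    # protocol, so `w.describe()` receives `self` like a normally defined method.
    describe = _describe

    def __init__(self, value: int):
        self.value = value


w = Widget(3)
print(w.describe())  # -> Widget(value=3)
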
112 changes: 88 additions & 24 deletions src/transformers/trainer_pt_utils.py
@@ -16,11 +16,13 @@
Torch utilities for the Trainer class.
"""

import json
import math
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, List, Optional, Union
from typing import Dict, Iterator, List, Optional, Union

import numpy as np
import torch
@@ -263,29 +265,6 @@ def _get_first_shape(arrays):
return arrays.shape


def get_learning_rate(trainer):
if trainer.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = trainer.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
trainer.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else trainer.lr_scheduler.get_lr()[0]
)
return last_lr


class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
@@ -563,3 +542,88 @@ def __iter__(self) -> Iterator:
assert len(indices) == self.num_samples

return iter(indices)


# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here


def _get_learning_rate(self):
if self.deepspeed:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
last_lr = 0
else:
raise
else:
last_lr = (
# backward compatibility for pytorch schedulers
self.lr_scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else self.lr_scheduler.get_lr()[0]
)
return last_lr


def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
"""
Reformat Trainer metrics values to a human-readable format

Args:
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict

Returns:
metrics (:obj:`Dict[str, float]`): The reformatted metrics
"""

metrics_copy = metrics.copy()
for k, v in metrics_copy.items():
if "_mem_" in k:
metrics_copy[k] = f"{ v >> 20 }MB"
elif k == "total_flos":
metrics_copy[k] = f"{ int(v) >> 30 }GF"
elif type(metrics_copy[k]) == float:
metrics_copy[k] = round(v, 4)

return metrics_copy


def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way

Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""

logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")


def save_metrics(self, split, metrics):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.

Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
"""
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)
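
For reference, a small standalone sketch of what the moved formatting helpers produce. The metric names and values below are invented, and the loop only mirrors the `metrics_format` rules above: keys containing `_mem_` are right-shifted 20 bits to megabytes, `total_flos` is right-shifted 30 bits to gigaflos, and plain floats are rounded to 4 decimals.

metrics = {
    "train_runtime": 123.456789,                   # plain float, rounded to 4 decimals
    "train_mem_gpu_alloc_delta": 3 * 2**20 + 512,  # "_mem_" key, shifted 20 bits -> MB
    "total_flos": 5 * 2**30 + 7,                   # shifted 30 bits -> GF
}

formatted = {}
for k, v in metrics.items():
    if "_mem_" in k:
        formatted[k] = f"{v >> 20}MB"
    elif k == "total_flos":
        formatted[k] = f"{int(v) >> 30}GF"
    elif isinstance(v, float):
        formatted[k] = round(v, 4)

print(formatted)
# {'train_runtime': 123.4568, 'train_mem_gpu_alloc_delta': '3MB', 'total_flos': '5GF'}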