[trainer] add tf32-mode control (#14606)
* [trainer] add --tf32 support

* it's pt>=.17

* it's pt>=.17

* flip the default to True

* add experimental note

* simplify logic

* style

* switch to 3-state logic

* doc

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* re-style code

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
stas00 and sgugger committed Dec 3, 2021
1 parent aada989 commit 71b1bf7
Showing 5 changed files with 90 additions and 27 deletions.
5 changes: 5 additions & 0 deletions docs/source/performance.md
@@ -358,8 +358,13 @@ Like all cases with reduced precision this may or may not be satisfactory for your

If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.

You can enable this mode in the 🤗 Trainer with `--tf32`, or disable it with `--tf32 0` or `--no_tf32`.
If neither is passed, the PyTorch default is used.

Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.to(dtype=torch.tf32)`, as `torch.tf32` doesn't exist.

Note: you need `torch>=1.7` to use this feature.
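
Not part of the diff above, but for reference: the flag ultimately toggles PyTorch's own backend switch (see `training_args.py` later in this commit), so the same control is available outside the Trainer. A minimal sketch, assuming an Ampere or newer GPU, CUDA >= 11 and `torch>=1.7`:

```python
import torch

# tf32 is a CUDA execution mode, not a dtype, so it is switched globally.
torch.backends.cuda.matmul.allow_tf32 = True   # what `--tf32` enables
# torch.backends.cuda.matmul.allow_tf32 = False  # what `--tf32 0` / `--no_tf32` does
```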


### Gradient Checkpointing

68 changes: 43 additions & 25 deletions src/transformers/file_utils.py
@@ -321,34 +321,52 @@ def is_torch_cuda_available():


 def is_torch_bf16_available():
-    if is_torch_available():
-        import torch
-
-        # since currently no utility function is available we build our own.
-        # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
-        # with additional check for torch version
-        # to succeed:
-        # 1. the hardware needs to support bf16 (arch >= Ampere)
-        # 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
-        # 3. CUDA >= 11
-        # 4. torch.autocast exists
-        # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
-        # really only correct for the 0th gpu (or currently set default device if different from 0)
-
-        if not torch.cuda.is_available() or torch.version.cuda is None:
-            return False
-        if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
-            return False
-        if int(torch.version.cuda.split(".")[0]) < 11:
-            return False
-        if not version.parse(torch.__version__) >= version.parse("1.10"):
-            return False
-        if not hasattr(torch, "autocast"):
-            return False
-
-        return True
-    else:
-        return False
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    # since currently no utility function is available we build our own.
+    # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
+    # with additional check for torch version
+    # to succeed:
+    # 1. the hardware needs to support bf16 (arch >= Ampere)
+    # 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
+    # 3. CUDA >= 11
+    # 4. torch.autocast exists
+    # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
+    # really only correct for the 0th gpu (or currently set default device if different from 0)
+
+    if not torch.cuda.is_available() or torch.version.cuda is None:
+        return False
+    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+        return False
+    if int(torch.version.cuda.split(".")[0]) < 11:
+        return False
+    if version.parse(torch.__version__) < version.parse("1.10"):
+        return False
+    if not hasattr(torch, "autocast"):
+        return False
+
+    return True
+
+
+def is_torch_tf32_available():
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    if not torch.cuda.is_available() or torch.version.cuda is None:
+        return False
+    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+        return False
+    if int(torch.version.cuda.split(".")[0]) < 11:
+        return False
+    if version.parse(torch.__version__) < version.parse("1.7"):
+        return False
+
+    return True
 
 
 _torch_fx_available = _torch_onnx_dict_inputs_support_available = False
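A small usage sketch, not from the commit, showing how the new helper can guard the backend flag; it assumes the import path `transformers.file_utils` used above:

```python
import torch

from transformers.file_utils import is_torch_tf32_available

# Only touch the CUDA matmul flag when the stack supports tf32
# (Ampere or newer GPU, CUDA >= 11, torch >= 1.7); otherwise keep PyTorch's default.
if is_torch_tf32_available():
    torch.backends.cuda.matmul.allow_tf32 = True
```
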
13 changes: 11 additions & 2 deletions src/transformers/testing_utils.py
@@ -50,6 +50,7 @@
     is_tokenizers_available,
     is_torch_available,
     is_torch_bf16_available,
+    is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
     is_vision_available,
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):


 def require_torch_bf16(test_case):
-    """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
+    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
     if not is_torch_bf16_available():
-        return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
+        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
     else:
         return test_case
+
+
+def require_torch_tf32(test_case):
+    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
+    if not is_torch_tf32_available():
+        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
+    else:
+        return test_case

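Like the other `require_*` helpers, the new decorator skips a test instead of failing it on unsupported setups. A hedged sketch of using it in a custom test class (the class and test body are illustrative, not from the commit):

```python
import unittest

from transformers.testing_utils import require_torch_tf32


class Tf32SmokeTest(unittest.TestCase):
    @require_torch_tf32  # skipped unless Ampere+ GPU, cuda>=11 and torch>=1.7
    def test_matmul_with_tf32(self):
        import torch

        torch.backends.cuda.matmul.allow_tf32 = True
        a = torch.randn(64, 64, device="cuda")
        self.assertEqual((a @ a).shape, (64, 64))
```
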
21 changes: 21 additions & 0 deletions src/transformers/training_args.py
@@ -29,6 +29,7 @@
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_available,
+    is_torch_tf32_available,
     is_torch_tpu_available,
     torch_required,
 )
@@ -227,6 +228,9 @@ class TrainingArguments:
         fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
             metric values.
+        tf32 (:obj:`bool`, `optional`):
+            Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API
+            and it may change.
         local_rank (:obj:`int`, `optional`, defaults to -1):
             Rank of the process during distributed training.
         xpu_backend (:obj:`str`, `optional`):
@@ -548,6 +552,12 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
     )
+    tf32: bool = field(
+        default=None,
+        metadata={
+            "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
+        },
+    )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
     xpu_backend: str = field(
         default=None,
@@ -802,6 +812,17 @@ def __post_init__(self):
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
)

if is_torch_available() and self.tf32 is not None:
if self.tf32:
if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = True
else:
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
else:
if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = False
# no need to assert on else

if self.report_to is None:
logger.info(
"The default value for the training argument `--report_to` will change in v5 (from all installed "
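From the user side the new argument is three-state, mirroring the `__post_init__` logic above. A rough sketch (only the relevant argument shown; the `output_dir` value is a placeholder):

```python
from transformers import TrainingArguments

# tf32=True  -> sets torch.backends.cuda.matmul.allow_tf32 = True, or raises ValueError if unsupported
# tf32=False -> explicitly disables tf32 on setups that support it
# tf32=None  -> (default) leaves the PyTorch default untouched
args = TrainingArguments(output_dir="tmp-tf32-run", tf32=True)
```
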
10 changes: 10 additions & 0 deletions tests/test_trainer.py
@@ -57,6 +57,7 @@
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
+    require_torch_tf32,
     require_torch_up_to_2_gpus,
     slow,
 )
@@ -492,6 +493,15 @@ def test_mixed_bf16(self):

     # will add more specific tests once there are some bugs to fix
 
+    @require_torch_gpu
+    @require_torch_tf32
+    def test_tf32(self):
+
+        # very basic test
+        trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
+        trainer.train()
+        self.check_trained_model(trainer.model)
 
 
 @require_torch
 @require_sentencepiece
