
[trainer] add tf32-mode control #14606

Merged
merged 11 commits on Dec 3, 2021

Changes from 7 commits
4 changes: 4 additions & 0 deletions docs/source/performance.md
@@ -358,8 +358,12 @@ Like all cases with reduced precision this may or may not be satisfactory for your

 If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.
 
+The 🤗 Trainer has this mode enabled by default, but it can be disabled by passing `--tf32 0`.
+
+Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.to(dtype=torch.tf32)` as `torch.tf32` doesn't exist.
+
 Note: you need `torch>=1.7` to enjoy this feature.
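
(Not part of the diff, but for context: tf32 has no tensor dtype, so at the PyTorch level it is toggled through backend flags. A minimal sketch; the cuDNN flag is included only for completeness and isn't touched by this PR.)

import torch

# tf32 is controlled through backend flags rather than a dtype
torch.backends.cuda.matmul.allow_tf32 = True  # matmuls (the flag this PR wires to --tf32)
torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions (untouched by this PR)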


### Gradient Checkpointing

20 changes: 19 additions & 1 deletion src/transformers/file_utils.py
@@ -341,7 +341,7 @@ def is_torch_bf16_available():
             return False
         if int(torch.version.cuda.split(".")[0]) < 11:
             return False
-        if not version.parse(torch.__version__) >= version.parse("1.10"):
+        if version.parse(torch.__version__) < version.parse("1.10"):
             return False
         if not hasattr(torch, "autocast"):
             return False
@@ -351,6 +351,24 @@ def is_torch_bf16_available():
         return False
 
 
+def is_torch_tf32_available():
+    if is_torch_available():
+        import torch
+
+        if not torch.cuda.is_available() or torch.version.cuda is None:
+            return False
+        if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
+            return False
+        if int(torch.version.cuda.split(".")[0]) < 11:
+            return False
+        if version.parse(torch.__version__) < version.parse("1.7"):
+            return False
+
+        return True
+    else:
+        return False
+
+
 _torch_fx_available = _torch_onnx_dict_inputs_support_available = False
 if _torch_available:
     torch_version = version.parse(importlib_metadata.version("torch"))
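
(A quick way to sanity-check the new helper, as a hypothetical snippet assuming a checkout with this PR installed:)

from transformers.file_utils import is_torch_tf32_available

# True only on an Ampere-or-newer GPU with cuda>=11 and torch>=1.7
print(is_torch_tf32_available())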
13 changes: 11 additions & 2 deletions src/transformers/testing_utils.py
@@ -50,6 +50,7 @@
     is_tokenizers_available,
     is_torch_available,
     is_torch_bf16_available,
+    is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
     is_vision_available,
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):


 def require_torch_bf16(test_case):
-    """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
+    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
     if not is_torch_bf16_available():
-        return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
+        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
     else:
         return test_case
 
 
+def require_torch_tf32(test_case):
+    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
+    if not is_torch_tf32_available():
+        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
+    else:
+        return test_case

13 changes: 13 additions & 0 deletions src/transformers/training_args.py
@@ -29,6 +29,7 @@
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_available,
+    is_torch_tf32_available,
     is_torch_tpu_available,
     torch_required,
 )
@@ -227,6 +228,9 @@ class TrainingArguments:
     fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
         Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
         metric values.
+    tf32 (:obj:`bool`, `optional`, defaults to :obj:`True`):
+        Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API
+        and it may change.
     local_rank (:obj:`int`, `optional`, defaults to -1):
         Rank of the process during distributed training.
     xpu_backend (:obj:`str`, `optional`):
@@ -548,6 +552,12 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
     )
+    tf32: bool = field(
+        default=True,
Collaborator

The default should be whatever PyTorch has by default, so None here and the user can set it to True or False to force/unforce it.

I understand it's True for versions >= 1.7 and < 1.10 but False after?

Contributor Author

ah, good idea! let the user decide!

Contributor Author

True for versions >= 1.7 and < 1.11 and probably False after - the nightly is still True as of today.

+        metadata={
+            "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
+        },
+    )
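
(For illustration, the new field would be passed like any other training argument; a minimal sketch, with `output_dir` only present because the API requires it:)

from transformers import TrainingArguments

# tf32=True is already the default in this commit; passed explicitly for clarity
args = TrainingArguments(output_dir="test-output", tf32=True)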
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
xpu_backend: str = field(
default=None,
@@ -802,6 +812,9 @@ def __post_init__(self):
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
)

if is_torch_available() and is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = True if self.tf32 else False
Collaborator

So here we should only change that boolean if the value set was not None. If the value is True, there should be an error if is_torch_tf32_available() is False so the user is not surprised if they don't get what they want.

Contributor Author

thanks for this great feedback, Sylvain. Please have another look.
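
(A rough sketch of the logic being suggested here; hypothetical follow-up code, not part of this commit, assuming the `tf32` field defaults to `None`:)

if self.tf32 is not None:
    if self.tf32:
        if is_torch_tf32_available():
            torch.backends.cuda.matmul.allow_tf32 = True
        else:
            raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
    else:
        torch.backends.cuda.matmul.allow_tf32 = False
# when self.tf32 is None, PyTorch's own default is left untouched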


         if self.report_to is None:
             logger.info(
                 "The default value for the training argument `--report_to` will change in v5 (from all installed "
10 changes: 10 additions & 0 deletions tests/test_trainer.py
@@ -57,6 +57,7 @@
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
+    require_torch_tf32,
     require_torch_up_to_2_gpus,
     slow,
 )
@@ -492,6 +493,15 @@ def test_mixed_bf16(self):

     # will add more specific tests once there are some bugs to fix
 
+    @require_torch_gpu
+    @require_torch_tf32
Collaborator

Out of curiosity, do we have a setup that has the right CUDA version and GPU capabilities?

Contributor Author

I have an rtx-3090 if that's what you're asking.

Running benchmarks now - will post those shortly.

Collaborator

I was wondering for our testing machines on the automatic CI :-)

Contributor Author

one day we will have those newer gpus.

+    def test_tf32(self):
+
+        # very basic test
+        trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
+        trainer.train()
+        self.check_trained_model(trainer.model)


 @require_torch
 @require_sentencepiece
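
(For reference, the new test can be run in isolation with a standard pytest invocation, assuming a source checkout and a tf32-capable GPU:)

pytest tests/test_trainer.py -k test_tf32 -v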