Enable torchdynamo with torch_tensorrt(fx path) #17765

Merged: 28 commits, Jul 13, 2022 (changes shown from 23 commits)
20 changes: 16 additions & 4 deletions docs/source/en/perf_train_gpu_one.mdx
@@ -11,7 +11,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o

# Efficient Training on a Single GPU

This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many).

In this section we take a look at a few tricks to reduce the memory footprint and speed up training of large models, and how they are integrated into the [`Trainer`] and [🤗 Accelerate](https://huggingface.co/docs/accelerate/). Each method can improve speed or memory usage, as summarized in the table below:

@@ -367,7 +367,7 @@ Samples/second: 10.09
GPU memory occupied: 7275 MB.
```

We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster.

### BF16
If you have access to an Ampere or newer GPU you can use bf16 for your training and evaluation. While bf16 has lower precision than fp16, it has a much bigger dynamic range. Therefore, if in the past you were experiencing overflow issues while training the model, bf16 will prevent this from happening most of the time. Remember that in fp16 the biggest number you can have is `65504` and any number above that will overflow. A bf16 number can be as large as `3.39e+38` (!) which is about the same as fp32 - because both have 8 bits used for the numerical range.
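If you want to see the difference yourself, a quick check of the dtype limits illustrates it (a small standalone snippet, not part of the Trainer API):

```python
import torch

print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38, roughly the fp32 range

x = torch.tensor(70000.0)
print(x.to(torch.float16))   # tensor(inf, dtype=torch.float16) -> overflow in fp16
print(x.to(torch.bfloat16))  # tensor(70144., dtype=torch.bfloat16) -> representable, with reduced precision
```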
@@ -394,7 +394,7 @@ Like all cases with reduced precision this may or may not be satisfactory for yo

If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.

You can enable this mode in the 🤗 Trainer with:
```python
TrainingArguments(tf32=True)
```
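For reference, outside of the Trainer the same behavior can be requested directly through PyTorch's TF32 switches (a minimal sketch; requires an Ampere or newer GPU):

```python
import torch

# Allow TF32 for matrix multiplications and cuDNN convolutions on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```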
@@ -654,7 +654,7 @@ https://github.com/huggingface/transformers/blob/master/src/transformers/trainer


## Choice of GPU
Sometimes, even when applying all the above tweaks, the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example, switching from, let's say, a K80 (which you typically get on Google Colab) to a faster GPU such as a V100 or A100. Although they are more expensive, they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture.

Now, let's take a step back and discuss what we should optimize for when scaling the training of large models.

@@ -718,3 +718,15 @@ For some applications, such as pretraining large language models, applying all t

Another use case for training on many GPUs is if the model does not fit on a single GPU even with all the mentioned tricks. There are still more methods we can apply, although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism, where the model itself is distributed across several GPUs. One can also make use of DeepSpeed, which implements some of these parallelism strategies along with additional optimizations to reduce the memory footprint, such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).

## Inference with TorchDynamo
TorchDynamo is a new tracer that uses Python's frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One option is to use [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as the backend. You can choose one of the options below for a performance boost; a minimal usage sketch follows the installation notes.
```python
TrainingArguments(torchdynamo="eager")  # enable eager mode; no performance boost
TrainingArguments(torchdynamo="nvfuser")  # enable NVFuser
TrainingArguments(torchdynamo="fx2trt")  # enable TensorRT fp32
TrainingArguments(torchdynamo="fx2trt-fp16")  # enable TensorRT fp16
```
This feature involves 3 different libraries. To install them, please follow the instructions below:
- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
- [Functorch installation](https://github.com/pytorch/functorch#install)
- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
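With those installed, here is a minimal end-to-end sketch of running evaluation through the Trainer with one of these backends. The tiny model and dataset are made up purely for illustration:

```python
import torch
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    """Tiny synthetic regression dataset, purely for illustration."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        x = torch.randn(4)
        return {"x": x, "labels": x.sum().unsqueeze(0)}


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x, labels=None):
        preds = self.linear(x)
        loss = torch.nn.functional.mse_loss(preds, labels) if labels is not None else None
        return {"loss": loss, "logits": preds}


# Requires torchdynamo (and functorch for nvfuser) to be installed;
# pick "fx2trt" / "fx2trt-fp16" only if Torch-TensorRT FX is installed.
args = TrainingArguments(output_dir="out", per_device_eval_batch_size=8, torchdynamo="nvfuser")
trainer = Trainer(model=ToyModel(), args=args, eval_dataset=ToyDataset())
metrics = trainer.evaluate()  # evaluation runs inside the TorchDynamo context manager
print(metrics["eval_loss"])
```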
6 changes: 6 additions & 0 deletions src/transformers/testing_utils.py
@@ -71,6 +71,7 @@
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
@@ -494,6 +495,11 @@ def require_torchdynamo(test_case):
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)


def require_torch_tensorrt_fx(test_case):
    """Decorator marking a test that requires Torch-TensorRT FX"""
    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)


def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
12 changes: 1 addition & 11 deletions src/transformers/trainer.py
@@ -142,7 +142,6 @@
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
)
from .utils.generic import ContextManagers
@@ -2291,16 +2290,7 @@ def torchdynamo_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
"""
ctx_manager = contextlib.nullcontext()
if is_torchdynamo_available():
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

if self.args.torchdynamo == "eager":
ctx_manager = torchdynamo.optimize("eager")
elif self.args.torchdynamo == "nvfuser":
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
return ctx_manager
return self.args.ctx_manager_torchdynamo

def autocast_smart_context_manager(self):
"""
33 changes: 32 additions & 1 deletion src/transformers/training_args.py
@@ -41,8 +41,10 @@
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
torch_required,
)
@@ -935,7 +937,7 @@ class TrainingArguments:
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
),
"choices": ["eager", "nvfuser"],
"choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
},
)
ray_scope: Optional[str] = field(
@@ -1218,6 +1220,35 @@ def __post_init__(self):
FutureWarning,
)

if self.torchdynamo:
    if not is_torchdynamo_available():
        raise RuntimeError("Torchdynamo is not installed.")

    import torchdynamo
    from torchdynamo.optimizations import backends
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    def get_ctx():
        # Normal
        if self.torchdynamo == "eager":
            return torchdynamo.optimize("eager")
        elif self.torchdynamo == "nvfuser":
            return torchdynamo.optimize(aot_autograd_speedup_strategy)
        # TensorRT
        if self.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
            if not is_torch_tensorrt_fx_available():
                raise RuntimeError("Torch-TensorRT FX path is not installed.")
            if self.torchdynamo == "fx2trt-fp16":
                return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
            elif self.torchdynamo == "fx2trt":
                return torchdynamo.optimize(backends.fx2trt_compiler)
        else:
            raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.")

    self.ctx_manager_torchdynamo = get_ctx()
else:
    self.ctx_manager_torchdynamo = contextlib.nullcontext()
Collaborator:

Not a big fan of having this as an attribute of the TrainingArguments: I think it will break serialization (see here). This could all fit in a function that takes the value of self.torchdynamo (since it's the only field of TrainingArguments it uses) and lives in integrations.py. The code in the trainer file should then be adapted slightly.

Contributor Author (@frank-wei, Jul 12, 2022):

@sgugger, that sounds good to me. I moved the function to integrations.py but ran into a circular import issue:

  File "/home/runner/work/transformers/transformers/src/transformers/training_args.py", line 26, in <module>
    from .integrations import get_torchdynamo_ctx
  File "/home/runner/work/transformers/transformers/src/transformers/integrations.py", line 47, in <module>
    from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
  File "/home/runner/work/transformers/transformers/src/transformers/trainer_callback.py", line 27, in <module>
    from .training_args import TrainingArguments
ImportError: cannot import name 'TrainingArguments' from partially initialized module 'transformers.training_args' (most likely due to a circular import) (/home/runner/work/transformers/transformers/src/transformers/training_args.py)

Is it good to leave the function get_torchdynamo_ctx as a member of TrainingArguments? Or should it live in import_utils.py together with is_torchdynamo_available()?

Collaborator:

It shouldn't be imported at all in the training_args.py module, only in trainer.py. As I said, you shouldn't add new attributes to TrainingArguments that are not serializable.

Contributor (@stas00, Jul 12, 2022):

@sgugger, the original implementation calculated the context on every call - that's why I suggested moving the logic to the argparse stage, since it only needs to be done once per program run.

What would be a good place then to perform this figuring out? In the trainer's init, probably?

Collaborator:

That works, yes.

Contributor:

@frank-wei, please let me know if you need help here - moving it to the trainer's init, that is.
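For readers following this thread, a rough sketch of the suggested direction: compute the context manager once from `args.torchdynamo` (for example in the Trainer's init) instead of storing it on `TrainingArguments`. The helper name is hypothetical; this is not the code merged in this PR:

```python
import contextlib


def get_torchdynamo_ctx(backend_name):
    """Hypothetical helper illustrating the reviewers' suggestion."""
    if not backend_name:
        return contextlib.nullcontext()

    import torchdynamo
    from torchdynamo.optimizations import backends
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    if backend_name == "eager":
        return torchdynamo.optimize("eager")
    if backend_name == "nvfuser":
        return torchdynamo.optimize(aot_autograd_speedup_strategy)
    if backend_name == "fx2trt":
        return torchdynamo.optimize(backends.fx2trt_compiler)
    if backend_name == "fx2trt-fp16":
        return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
    raise RuntimeError(f"Torchdynamo backend {backend_name} is not supported.")


# In Trainer.__init__ (sketch): compute once per run instead of on every call.
# self.ctx_manager_torchdynamo = get_torchdynamo_ctx(args.torchdynamo)
```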


def __str__(self):
self_as_dict = asdict(self)

1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -132,6 +132,7 @@
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_onnx_dict_inputs_support_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -418,6 +418,12 @@ def is_torchdynamo_available():
return importlib.util.find_spec("torchdynamo") is not None


def is_torch_tensorrt_fx_available():
    if importlib.util.find_spec("torch_tensorrt") is None:
        return False
    return importlib.util.find_spec("torch_tensorrt.fx") is not None


def is_datasets_available():
return _datasets_available

32 changes: 26 additions & 6 deletions tests/trainer/test_trainer.py
@@ -62,6 +62,7 @@
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_tensorrt_fx,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_torchdynamo,
@@ -1799,6 +1800,7 @@ def test_fp16_full_eval(self):

@require_torch_non_multi_gpu
@require_torchdynamo
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
n_gpus = get_gpu_count()
@@ -1827,6 +1829,21 @@ def test_torchdynamo_full_eval(self):
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)

# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)

# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
# fp16 has accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)

@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
@@ -1852,24 +1869,23 @@ def forward(self, x):

mod = MyModule()

# 1. Default - without TorchDynamo
# 1. without TorchDynamo (eager baseline)
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
trainer = CustomTrainer(model=mod)
# warmup
for _ in range(10):
orig_loss = trainer.training_step(mod, {"x": a})

# resets
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer

# Reset the peak for another measurement
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# 2. TorchDynamo nvfuser
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
@@ -1879,7 +1895,11 @@ def forward(self, x):
for _ in range(10):
loss = trainer.training_step(mod, {"x": a})

# resets
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
del trainer