Enable torchdynamo with torch_tensorrt(fx path) #17765

Merged - 28 commits - Jul 13, 2022

Conversation

@frank-wei (Contributor) commented Jun 18, 2022

What does this PR do?

Adding support for TorchDynamo with torch_tensorrt (fx2trt module). Detailed context is available at #17724.
This diff adds an extra inference backend, building on #17308.
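
For context, a minimal usage sketch of what this enables (the `torchdynamo` training argument and its accepted values are taken from this PR's diff; the model and dataset setup is assumed and elided):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    # One of "eager", "nvfuser", "fx2trt", "fx2trt-fp16", per the diff in this PR.
    torchdynamo="fx2trt-fp16",
)
# `model` and `train_dataset` are assumed to be defined elsewhere.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()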

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

To reproduce and set up the environment

# install torch-nightly
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch-nightly

# install functorch (and reinstall after `git pull` later if need to sync up)
git clone https://github.com/pytorch/functorch
cd functorch
rm -rf build
pip install -e .[aot]

cd ..
git clone https://github.com/pytorch/torchdynamo
cd torchdynamo
pip install -r requirements.txt
python setup.py develop

# install TensorRT
pip install nvidia-pyindex
pip install nvidia-tensorrt==8.2.4.2

# install torch_tensorrt (fx path)
cd ..
git clone https://github.com/pytorch/TensorRT.git
cd TensorRT/py
python setup.py install --fx-only

cc HF @stas00
cc Meta @yinghai @Chillee
cc NV @ncomly-nvidia @narendasan

@frank-wei frank-wei changed the title enable fx2trt Enable torchdynamo with torch_tensorrt(fx2trt module) Jun 18, 2022
@frank-wei frank-wei changed the title Enable torchdynamo with torch_tensorrt(fx2trt module) Enable torchdynamo with torch_tensorrt(fx module) Jun 24, 2022
@frank-wei frank-wei changed the title Enable torchdynamo with torch_tensorrt(fx module) Enable torchdynamo with torch_tensorrt(fx path) Jun 24, 2022
@frank-wei (Contributor Author)

Hi @stas00, just a friendly ping. I updated the installation section, so it should be easy to reproduce if needed.

@stas00 (Contributor) left a comment

I was away, catching up now.

I'm trying to figure out how to install TensorRT, so I can't yet properly assess your PR. Meanwhile there is already work to do - please see comments.

Additionally, the installation info you shared needs to be part of the docs, since otherwise users won't know what to do. Ideally it should live on your side, with the HF docs linking to it, so that you can easily keep it up to date. Thank you.

And of course the new API options need to be documented in the user-facing documentation as well, otherwise nobody will use this feature. @anijain2305, could you please complete that long-overdue documentation? We won't be able to finish this PR until the doc task from #17308 (comment) is completed. Thank you! You can of course do it as part of this PR and document the overdue options and the new ones in one go.

@stas00 (Contributor) commented Jun 27, 2022

So I followed your instructions except I used the .deb package installer.

(Oh, and please link to https://docs.nvidia.com/deeplearning/tensorrt/archives/index.html so that users will know how to install TensorRT.)

Why do I get:

Traceback (most recent call last):
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/optimizations/backends.py", line 45, in inner
    return fn(model, **kwargs)
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/optimizations/backends.py", line 313, in fx2trt
    from torch_tensorrt.fx.fx2trt import InputTensorSpec
ModuleNotFoundError: No module named 'torch_tensorrt'

Is it the module from tensorrt-8.2.5.1-cp38-none-linux_x86_64.whl?

Oh, I see it failed to build:

/hf/00pytorch/TensorRT/py [pytorch/TensorRT|master]> python setup.py install --fx-only
Could not find bazel in PATH

I installed bazel and it still fails:

pip install bazel
Collecting bazel
  Downloading bazel-0.0.0.20200723.tar.gz (1.4 kB)
Building wheels for collected packages: bazel
  Building wheel for bazel (setup.py) ... done
  Created wheel for bazel: filename=bazel-0.0.0.20200723-py3-none-any.whl size=1708 sha256=518429e9ce158eb7e4ffc2cefa782eb7935d39d317d67801c5ae9b7346af0500
  Stored in directory: /home/stas/.cache/pip/wheels/9b/80/e4/8d16b3eeeda264ac8105dd7fa29a124431113b2f1f5dd703bc
Successfully built bazel
Installing collected packages: bazel
Successfully installed bazel-0.0.0.20200723
(py38-pt112) /hf/00pytorch/TensorRT/py [pytorch/TensorRT|master]> python setup.py install --fx-only
Could not find bazel in PATH

So it's not a Python package that it wants but a system-wide bazel? There is no apt package - probably need to add a new apt repo? This doc appears to be really outdated: https://docs.bazel.build/versions/main/install-ubuntu.html

In any case this obviously requires explicit instructions.

I will wait for your instructions before proceeding.

@frank-wei (Contributor Author) commented Jun 27, 2022

Thanks for your time and efforts @stas00 !

  1. Yes, TRT does seem to give new users some trouble the first time they try to install it. I just found a way to install the Python version of TRT, so you do not need to download the TRT tarball and unzip the files (this Python installation pulls in all the dependent libs, like the TensorRT lib and the cuDNN lib). I added these instructions to our doc as a PR: [FX] Create getting_started_with_fx_path.rst pytorch/TensorRT#1145
    $ pip3 install nvidia-pyindex
    $ pip3 install nvidia-tensorrt==8.2.4.2
  2. I have a PR to disable the bazel check: [FX] --fx-only does not need to check bazel pytorch/TensorRT#1147 (merged).
    That said, the bazel behavior is a bit odd. I am on CentOS in a conda environment; there the command is `conda install -c conda-forge bazel`. It looks like your bazel installation location is not added to $PATH, but `which bazel` can help check. In any case, with my diff 1147 (merged), we should not need bazel anymore.
    Below is the complete set of instructions to install TRT, PyTorch, and torch_tensorrt.fx, which I just verified to work:
    $ conda create --name python_env python=3.8
    $ conda activate python_env
    # Recommended: PyTorch 1.12 or later
    $ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
    # Install TensorRT python package
    $ pip3 install nvidia-pyindex
    $ pip3 install nvidia-tensorrt==8.2.4.2
    $ git clone https://github.com/pytorch/TensorRT.git
    $ cd TensorRT/py && python setup.py install --fx-only && cd ..
    # check torch_tensorrt.fx is installed
    $ python -c "import torch_tensorrt.fx"

Hope it solves your problem.

@stas00 (Contributor) commented Jun 28, 2022

conda install -c conda-forge bazel did the trick. The pip version gave nothing with which bazel - not a PATH issue, but a package issue I think, though probably related.


    $ pip3 install nvidia-pyindex
    $ pip3 install nvidia-tensorrt==8.2.4.2

That did the trick. The tests have run successfully.

So let's update the OP with the above two fixes.

@stas00 (Contributor) commented Jun 28, 2022

Ah, one more user-facing documentation nit: if you want users to use your magic code, you will want to provide some enticement. A small benchmark table showing what these features do usually goes a long way toward getting a user excited to try them. So this is something else to consider. It's not a showstopper, but as you can see, if the docs aren't added right away they never get added, so it's best to do it in one go. It's still just a recommendation and I'm fine merging as is; it's just not going to be used much without enticing docs.

@frank-wei (Contributor Author)

> Ah, one more user-facing documentation nit: if you want users to use your magic code, you will want to provide some enticement. A small benchmark table showing what these features do usually goes a long way toward getting a user excited to try them. So this is something else to consider. It's not a showstopper, but as you can see, if the docs aren't added right away they never get added, so it's best to do it in one go. It's still just a recommendation and I'm fine merging as is; it's just not going to be used much without enticing docs.

I will try to add the doc there. But it is better to have @anijain2305 include the AOT part. :-)

@stas00 (Contributor) commented Jun 28, 2022

> I will try to add the doc there. But it is better to have @anijain2305 include the AOT part. :-)

Yeah, I was hoping that you'd only need to add the incremental part relevant for this PR.

@stas00 (Contributor) commented Jun 28, 2022

re: CI - yes, and it's complicated.

Basically, the live CI that you see reporting in this PR runs only CPU tests, since CircleCI doesn't have GPUs.

Then we have another set of CI workflows that runs on our machines via GitHub Actions, and that's where we test all the complex/slow cases.

And yes, I completely forgot that as part of this PR we need to set up our CI to install all these packages as well, so that these tests will be run.

So once we've polished this, let's not forget that part. We will have to run all those instructions on our pt-nightly docker image - but actually there is a problem with this idea: how will the docker builder be able to download TensorRT packages if they require an NVIDIA user account?

@frank-wei (Contributor Author) commented Jun 28, 2022

re: CI
Actually, CircleCI has GPU resources to use (V100, T4, P4). I just added them to our project :-) pytorch/TensorRT#1137
These two commands are our savior:

 $ pip3 install nvidia-pyindex
 $ pip3 install nvidia-tensorrt==8.2.4.2

Do you think we need a @require_torchtensorrt.fx decorator? It would help us check whether torch_tensorrt.fx is installed in the test.

@stas00 (Contributor) commented Jun 28, 2022

> Actually, CircleCI has GPU resources to use (V100, T4, P4). I just added them to our project :-) pytorch/TensorRT#1137

That's great to know - thank you very much - I will pass this info on.

> These two commands are our savior:
>  $ pip3 install nvidia-pyindex
>  $ pip3 install nvidia-tensorrt==8.2.4.2

Ah, right! So no need for an NVIDIA user account! Super - let's use that in the instructions then.

> Do you think we need a @require_torchtensorrt.fx decorator? It would help us check whether torch_tensorrt.fx is installed in the test.

Absolutely, yes!
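
For reference, a minimal sketch of such a decorator, mirroring the pattern that shows up in the CI traceback further down (testing_utils.require_torch_tensorrt_fx backed by import_utils.is_torch_tensorrt_fx_available):

import importlib.util
import unittest

def is_torch_tensorrt_fx_available():
    # True if the torch_tensorrt.fx submodule can be located.
    # Note: find_spec with a dotted name imports the parent package, so this
    # raises ModuleNotFoundError if torch_tensorrt itself is missing entirely,
    # exactly as seen in the CI trace below.
    return importlib.util.find_spec("torch_tensorrt.fx") is not None

def require_torch_tensorrt_fx(test_case):
    # Skip the decorated test unless Torch-TensorRT FX is installed.
    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)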

@frank-wei (Contributor Author)

@stas00, just wondering if the CircleCI is flaky? Some test errors are unrelated, e.g. run_example_torch and check_code_quality.

@stas00 (Contributor) commented Jun 30, 2022

It appears that the CI is very broken at the moment; I asked and will know more tomorrow morning.

Thank you for the heads up, @frank-wei - it doesn't look like any of the failures are related to your work. Especially since the live CI won't run any of your tests.

@stas00 (Contributor) commented Jun 30, 2022

OK, so for the quality one - please rebase this PR on main. Thank you.

The other issue I don't have an answer for yet.

Update: I rebased - let's see how it goes with the update.

@stas00 (Contributor) commented Jun 30, 2022

OK, so to fix check_code_quality you need to run `make style` and push.

After rebasing, most of the CI failures are now coming from this PR:


==================================== ERRORS ====================================
______________ ERROR collecting tests/deepspeed/test_deepspeed.py ______________
ImportError while importing test module '/home/circleci/transformers/tests/deepspeed/test_deepspeed.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/usr/local/lib/python3.7/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/deepspeed/test_deepspeed.py:26: in <module>
    from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
tests/trainer/test_trainer.py:586: in <module>
    class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
tests/trainer/test_trainer.py:1803: in TrainerIntegrationTest
    @require_torch_tensorrt_fx
src/transformers/testing_utils.py:499: in require_torch_tensorrt_fx
    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
src/transformers/utils/import_utils.py:421: in is_torch_tensorrt_fx_available
    return importlib.util.find_spec("torch_tensorrt.fx") is not None
/usr/local/lib/python3.7/importlib/util.py:94: in find_spec
    parent = __import__(parent_name, fromlist=['__path__'])
E   ModuleNotFoundError: No module named 'torch_tensorrt'

Let me know if you need help with sorting it out.

@stas00 (Contributor) commented Jul 8, 2022

One of the tests is failing for me:

$ CUDA_VISIBLE_DEVICES=0 RUN_SLOW=1 pyt tests/trainer/test_trainer.py -k test_torchdynamo_memory -sv
        # AOT Autograd recomputation and nvfuser recomputation optimization
        # aggressively fuses the operations and reduces the memory footprint.
>       self.assertGreater(orig_peak_mem, peak_mem * 2)
E       AssertionError: 100664832 not greater than 201330688

Let me know what details you need - this is on an A100.

Oh, it actually crashed before that:

========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/convert_frame.py", line 295, in _convert_frame_assert
    code = transform_code_object(frame.f_code, transform)
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/bytecode_transformation.py", line 338, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/convert_frame.py", line 261, in transform
    tracer = InstructionTranslator(
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/symbolic_convert.py", line 1220, in __init__
    self.symbolic_locals = collections.OrderedDict(
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/symbolic_convert.py", line 1221, in <genexpr>
    (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/variables/builder.py", line 104, in __call__
    return self._wrap(value).clone(**self.options())
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/variables/builder.py", line 130, in _wrap
    return self.wrap_tensor(value)
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/variables/builder.py", line 327, in wrap_tensor
    tensor_variable = TensorVariable.create(
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/variables/tensor.py", line 121, in create
    cls.wrap_to_fake_tensor, fake_mode=tx.fake_mode
  File "/mnt/nvme0/code/github/00pytorch/torchdynamo/torchdynamo/symbolic_convert.py", line 1136, in fake_mode
    return self._fake_mode
AttributeError: 'InstructionTranslator' object has no attribute '_fake_mode'

This is not great - shouldn't the test have failed here, rather than at a misleading comparison later on?

Comment on lines 1223 to 1246
self.ctx_manager_torchdynamo = contextlib.nullcontext()
if self.torchdynamo:
    if not is_torchdynamo_available():
        raise RuntimeError("Torchdynamo is not installed.")

    import torchdynamo
    from torchdynamo.optimizations import backends
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    if self.torchdynamo == "eager":
        self.ctx_manager_torchdynamo = torchdynamo.optimize("eager")
    elif self.torchdynamo == "nvfuser":
        self.ctx_manager_torchdynamo = torchdynamo.optimize(aot_autograd_speedup_strategy)
    elif self.torchdynamo == "fx2trt-fp16":
        if not is_torch_tensorrt_fx_available():
            raise RuntimeError("Torch-TensorRT FX path is not installed.")
        self.ctx_manager_torchdynamo = torchdynamo.optimize(backends.fx2trt_compiler_fp16)
    elif self.torchdynamo == "fx2trt":
        if not is_torch_tensorrt_fx_available():
            raise RuntimeError("Torch-TensorRT FX path is not installed.")
        self.ctx_manager_torchdynamo = torchdynamo.optimize(backends.fx2trt_compiler)
    else:
        raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.")

@stas00 (Contributor):

Suggested change (replacing the snippet quoted above with):

if self.torchdynamo:
    if not is_torchdynamo_available():
        raise RuntimeError("Torchdynamo is not installed.")

    import torchdynamo
    from torchdynamo.optimizations import backends
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    def get_ctx():
        # Normal
        if self.torchdynamo == "eager":
            return torchdynamo.optimize("eager")
        elif self.torchdynamo == "nvfuser":
            return torchdynamo.optimize(aot_autograd_speedup_strategy)
        # TensorRT
        if self.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
            if not is_torch_tensorrt_fx_available():
                raise RuntimeError("Torch-TensorRT FX path is not installed.")
            if self.torchdynamo == "fx2trt-fp16":
                return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
            elif self.torchdynamo == "fx2trt":
                return torchdynamo.optimize(backends.fx2trt_compiler)
        raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.")

    self.ctx_manager_torchdynamo = get_ctx()
else:
    self.ctx_manager_torchdynamo = contextlib.nullcontext()

I've tried to make the very long if/else a little easier to read and to remove some repetition. Let me know if this resonates.

Untested.

@stas00 (Contributor) commented Jul 8, 2022

And if the number of options continues to grow, since it's quite symmetric, we could just build a dict lookup table mapping each backend to its requirements and its context, and look those up. But this is probably good enough for now.
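
A hypothetical sketch of that dict lookup (names and structure here are illustrative only, not code from this PR; the import path for the availability helper added in this PR is an assumption):

import torchdynamo
from torchdynamo.optimizations import backends
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

# Assumed import path for the helper added in this PR.
from transformers.utils import is_torch_tensorrt_fx_available

def _require_trt():
    # Shared requirement check for the TensorRT-backed options.
    if not is_torch_tensorrt_fx_available():
        raise RuntimeError("Torch-TensorRT FX path is not installed.")

# Backend name -> (requirement check or None, context factory).
_DYNAMO_TABLE = {
    "eager": (None, lambda: torchdynamo.optimize("eager")),
    "nvfuser": (None, lambda: torchdynamo.optimize(aot_autograd_speedup_strategy)),
    "fx2trt": (_require_trt, lambda: torchdynamo.optimize(backends.fx2trt_compiler)),
    "fx2trt-fp16": (_require_trt, lambda: torchdynamo.optimize(backends.fx2trt_compiler_fp16)),
}

def get_ctx(option):
    if option not in _DYNAMO_TABLE:
        raise RuntimeError(f"Torchdynamo backend {option} is not supported.")
    check, make_ctx = _DYNAMO_TABLE[option]
    if check is not None:
        check()
    return make_ctx()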

@frank-wei (Contributor Author):

Thanks for the refactoring effort. I incorporated it into the change.

@frank-wei (Contributor Author)

The failure may be due to an outdated torchdynamo. Could you install the newest torchdynamo? Here are the commands to install it:

git clone https://github.com/pytorch/functorch
cd functorch
rm -rf build
pip install -e .[aot]

cd ..
git clone https://github.com/pytorch/torchdynamo
cd torchdynamo
pip install -r requirements.txt
python setup.py develop

It looks good from my testing:

(mypy38-fx-only) [wwei6@devgpu005.ftw6 /data/users/wwei6/Work/transformers] CUDA_VISIBLE_DEVICES=6 pytest tests/trainer/test_trainer.py  -k test_torchdynamo_memory -sv
===================================================================================== test session starts =====================================================================================
platform linux -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0 -- /data/users/wwei6/miniconda3/envs/mypy38-fx-only/bin/python
cachedir: .pytest_cache
benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/data/users/wwei6/Work/transformers/.hypothesis/examples')
rootdir: /data/users/wwei6/Work/transformers, configfile: setup.cfg
plugins: benchmark-3.4.1, hydra-core-1.1.2, hypothesis-6.49.1
collected 70 items / 69 deselected / 1 selected                                                                                                                                               

tests/trainer/test_trainer.py::TrainerIntegrationTest::test_torchdynamo_memory PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
PASSED

====================================================================================== warnings summary =======================================================================================
../../miniconda3/envs/mypy38-fx-only/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py:4
  /data/users/wwei6/miniconda3/envs/mypy38-fx-only/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
    if not hasattr(tensorboard, "__version__") or LooseVersion(

../../miniconda3/envs/mypy38-fx-only/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py:6
  /data/users/wwei6/miniconda3/envs/mypy38-fx-only/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
    ) < LooseVersion("1.15"):

tests/trainer/test_trainer.py::TrainerIntegrationTest::test_torchdynamo_memory
  /data/users/wwei6/miniconda3/envs/mypy38-fx-only/lib/python3.8/site-packages/torch/nn/utils/_stateless.py:5: DeprecationWarning: The `torch.nn.utils._stateless` code is deprecated now that it is publicly available. Please use `torch.nn.utils.stateless instead.
    warnings.warn("The `torch.nn.utils._stateless` code is deprecated now that "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================== 1 passed, 69 deselected, 3 warnings in 7.51s =========================================================================

@stas00 (Contributor) commented Jul 9, 2022

I suspected that was the case, but what I was trying to say is that the test should have failed on the torchdynamo error and not on the mismatch in values; i.e., something is trapping the real error, and the user could be none the wiser that their torchdynamo is broken - e.g. when there are a lot of logs.

It needs to assert on the actual error. Does that make sense?
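
For illustration only, one hypothetical way the test could surface the real error, assuming the "TorchDynamo Stack Trace" banner from the trace above is written to stderr (if torchdynamo routes it through logging instead, unittest's assertLogs would be the analogue):

import contextlib
import io

def run_step_and_assert_no_dynamo_errors(test_case, step_fn):
    # Hypothetical helper: run one training step and fail immediately if
    # torchdynamo swallowed an internal error and printed its stack-trace
    # banner instead of raising.
    buf = io.StringIO()
    with contextlib.redirect_stderr(buf):
        result = step_fn()
    test_case.assertNotIn(
        "TorchDynamo Stack Trace",
        buf.getvalue(),
        "torchdynamo hit an internal error; see the captured stderr",
    )
    return result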

@frank-wei (Contributor Author) commented Jul 9, 2022

> I suspected that was the case, but what I was trying to say is that the test should have failed on the torchdynamo error and not on the mismatch in values; i.e., something is trapping the real error, and the user could be none the wiser that their torchdynamo is broken - e.g. when there are a lot of logs.
>
> It needs to assert on the actual error. Does that make sense?

Hm, that is outside my expertise, as it relates to torchdynamo. If it is torch_tensorrt related, I'd love to help.

For the CI test error, it seems that test is flaky? I did not find any useful information. Could you help guide/triage that? Thanks.

@@ -1846,7 +1846,6 @@ def test_torchdynamo_full_eval(self):

@require_torch_non_multi_gpu
@require_torchdynamo
@require_torch_tensorrt_fx
@stas00 (Contributor):

I removed this requirement, since it's not being used. Please correct me if I'm wrong.

@frank-wei (Contributor Author):

Thanks for catching that! This test does not use TensorRT.

@stas00 stas00 requested a review from sgugger July 12, 2022 02:53
@stas00 (Contributor) left a comment

Hi @frank-wei, I had to rebuild the whole environment against pt-nightly and now everything works.

I think it'd be good to save the instructions in the OP somewhere so that it's easier for the user and us to be able to rebuild the environment.

Would you like to maintain a section or a file on your side that contains the instructions in the OP and we could point to it?

Other than that, I will just ask Sylvain to have a quick review and we can merge this.

Thank you for your patience.

@frank-wei (Contributor Author) commented Jul 12, 2022

> Hi @frank-wei, I had to rebuild the whole environment against pt-nightly and now everything works.
>
> I think it'd be good to save the instructions in the OP somewhere so that it's easier for the user and us to be able to rebuild the environment.
>
> Would you like to maintain a section or a file on your side that contains the instructions in the OP and we could point to it?
>
> Other than that, I will just ask Sylvain to have a quick review and we can merge this.
>
> Thank you for your patience.

Thanks @stas00, do you think I can add three pointers for the installation of torchdynamo, functorch, and torch_tensorrt in docs/source/en/perf_train_gpu_one.mdx?
Torchdynamo: https://github.com/pytorch/torchdynamo#requirements-and-setup
Functorch: https://github.com/pytorch/functorch#install
Torch-TensorRT (FX): https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation

@stas00 (Contributor) commented Jul 12, 2022

I think that works, @frank-wei

@frank-wei (Contributor Author)

> I think that works, @frank-wei

Cool. Update finished.

@sgugger (Collaborator) left a comment

Thanks for your PR! I left some comments on where the context manager "getter" should live, as it shouldn't be an attribute of TrainingArguments, a dataclass we serialize at each save.

Comment on lines 1223 to 1250
if self.torchdynamo:
    if not is_torchdynamo_available():
        raise RuntimeError("Torchdynamo is not installed.")

    import torchdynamo
    from torchdynamo.optimizations import backends
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    def get_ctx():
        # Normal
        if self.torchdynamo == "eager":
            return torchdynamo.optimize("eager")
        elif self.torchdynamo == "nvfuser":
            return torchdynamo.optimize(aot_autograd_speedup_strategy)
        # TensorRT
        if self.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
            if not is_torch_tensorrt_fx_available():
                raise RuntimeError("Torch-TensorRT FX path is not installed.")
            if self.torchdynamo == "fx2trt-fp16":
                return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
            elif self.torchdynamo == "fx2trt":
                return torchdynamo.optimize(backends.fx2trt_compiler)
        else:
            raise RuntimeError(f"Torchdynamo backend {self.torchdynamo} is not supported.")

    self.ctx_manager_torchdynamo = get_ctx()
else:
    self.ctx_manager_torchdynamo = contextlib.nullcontext()
@sgugger (Collaborator):

Not a big fan of having this as an attribute of TrainingArguments: I think it will break serialization (see here). This could all fit in a function that takes the value of self.torchdynamo (since it's the only field of TrainingArguments it uses) and lives in integrations.py. The code in the trainer file should then be adapted slightly.
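
To illustrate the concern (not code from the PR): a live context manager is not JSON-serializable, so storing one on a dataclass that gets serialized at each save breaks the save.

import contextlib
import json

ctx = contextlib.nullcontext()
try:
    json.dumps({"ctx_manager_torchdynamo": ctx})
except TypeError as err:
    # Prints: Object of type nullcontext is not JSON serializable
    print(f"not serializable: {err}")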

@frank-wei (Contributor Author) commented Jul 12, 2022:

@sgugger, that sounds good to me. I moved the function to integrations.py but hit a circular import issue:

  File "/home/runner/work/transformers/transformers/src/transformers/training_args.py", line 26, in <module>
    from .integrations import get_torchdynamo_ctx
  File "/home/runner/work/transformers/transformers/src/transformers/integrations.py", line 47, in <module>
    from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
  File "/home/runner/work/transformers/transformers/src/transformers/trainer_callback.py", line 27, in <module>
    from .training_args import TrainingArguments
ImportError: cannot import name 'TrainingArguments' from partially initialized module 'transformers.training_args' (most likely due to a circular import) (/home/runner/work/transformers/transformers/src/transformers/training_args.py)

Is it OK to leave the function get_torchdynamo_ctx as a member of TrainingArguments? Or should it live in import_utils.py, together with is_torchdynamo_available()?

@sgugger (Collaborator):

It shouldn't be imported at all in the training_args.py module, only in trainer.py. As I said, you shouldn't add new attributes to TrainingArguments that are not serializable.

@stas00 (Contributor) commented Jul 12, 2022:

@sgugger, the original implementation calculated the context on every call - that's why I suggested moving the logic to the argparse stage, since it only needs to be done once per program run.

What would be a good place, then, to perform this? In the trainer's init, probably, right?

@sgugger (Collaborator):

That works, yes.

@stas00 (Contributor):

@frank-wei, please let me know if you need help here - moving to the trainer's init, that is.

@frank-wei (Contributor Author)

@stas00 @sgugger, please check the change. The failed test seems flaky and unrelated.

@sgugger (Collaborator) left a comment

Way better this way, thanks for iterating! Left a last nit.

Comment on lines 603 to 625
if self.args.torchdynamo:
    if not is_torchdynamo_available():
        raise RuntimeError("Torchdynamo is not installed.")
    import torchdynamo
    from torchdynamo.optimizations import backends
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

    def get_ctx():
        # Normal
        if self.args.torchdynamo == "eager":
            return torchdynamo.optimize("eager")
        elif self.args.torchdynamo == "nvfuser":
            return torchdynamo.optimize(aot_autograd_speedup_strategy)
        # TensorRT
        if self.args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
            if not is_torch_tensorrt_fx_available():
                raise RuntimeError("Torch-TensorRT FX path is not installed.")
            if self.args.torchdynamo == "fx2trt-fp16":
                return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
            elif self.args.torchdynamo == "fx2trt":
                return torchdynamo.optimize(backends.fx2trt_compiler)
        else:
            raise RuntimeError(f"Torchdynamo backend {self.args.torchdynamo} is not supported.")
@sgugger (Collaborator):

All self.args -> args, we've defined a shortcut :-)

@frank-wei (Contributor Author):

Cool. Fixed that.

@sgugger (Collaborator):

Thanks!

@sgugger (Collaborator) commented Jul 13, 2022

@stas00 Are you good with this last iteration (as long as all tests pass)?

@stas00 (Contributor) commented Jul 13, 2022

Let me run the tests.

@stas00 (Contributor) commented Jul 13, 2022

All tests pass. Good to merge once the CI is green.

I created a new task #18127 to handle the CI requirements.

@sgugger sgugger merged commit 7ea6ccc into huggingface:main Jul 13, 2022
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* enable fx2trt

* Update perf_train_gpu_one.mdx

* Update perf_train_gpu_one.mdx

* add lib check

* update

* format

* update

* fix import check

* fix isort

* improve doc

* refactor ctx manager

* fix isort

* black format

* isort fix

* fix format

* update args

* update black

* cleanups

* Update perf_train_gpu_one.mdx

* code refactor

* code refactor to init

* remove redundancy

* isort

* replace self.args with args

Co-authored-by: Stas Bekman <stas@stason.org>