
Conversation

@BowenBao (Contributor) commented Jul 26, 2023

Stack from ghstack (oldest at bottom):

Retrieving the torch dtype is slow because `JitScalarType` is fetched through pybind.

A recent change in the PyTorch exporter invokes this method heavily to produce diagnostic info for operator arguments, which leads to a significant slowdown.

Before this PR:
```
...
└─ 17.767 FxOnnxInterpreter.run  torch/onnx/_internal/fx/fx_onnx_interpreter.py:418
   └─ 17.763 run_node  <@beartype(torch.onnx._internal.fx.fx_onnx_interpreter.FxOnnxInterpreter.run_node) at 0x7f7c73513e50>:1
         [2 frames hidden]  <@beartype(torch.onnx._internal.fx.fx...
            17.761 wrapper  torch/onnx/_internal/diagnostics/infra/decorator.py:70
            ├─ 15.735 format_function_signature_in_markdown  <@beartype(torch.onnx._internal.diagnostics.infra.decorator.format_function_signature_in_markdown) at 0x7f7c8c946040>:1
            │     [2 frames hidden]  <@beartype(torch.onnx._internal.diagn...
            │        15.733 format_function_signature_in_markdown  torch/onnx/_internal/diagnostics/infra/decorator.py:30
            │        └─ 15.658 format_argument  torch/onnx/_internal/fx/diagnostics.py:35
            │           └─ 15.503 _dict  torch/onnx/_internal/fx/diagnostics.py:150
            │              └─ 15.031 format_argument  torch/onnx/_internal/fx/diagnostics.py:35
            │                 ├─ 13.234 _onnxscript_torch_script_tensor  torch/onnx/_internal/fx/diagnostics.py:165
            │                 │  ├─ 11.797 TorchScriptTensor.dtype  onnxscript/function_libs/torch_lib/graph_building.py:150
            │                 │  │  └─ 11.207 from_value  <@beartype(torch.onnx._type_utils.JitScalarType.from_value) at 0x7f7c8d909e50>:1
            │                 │  │        [46 frames hidden]  <@beartype(torch.onnx._type_utils.Jit...
            │                 │  │           7.431 JitScalarType.from_value  torch/onnx/_type_utils.py:148
            │                 │  │           ├─ 4.399 _from_name  <@beartype(torch.onnx._type_utils.JitScalarType._from_name) at 0x7f7c8d909b80>:1
            │                 │  │           │     [46 frames hidden]  <@beartype(torch.onnx._type_utils.Jit...
            │                 │  │           └─ 2.813 [self]  None
            │                 │  └─ 0.779 _stringify_shape  torch/onnx/_internal/fx/diagnostics.py:181
...
```

After this PR:
```
...
└─ 4.689 FxOnnxInterpreter.run  torch/onnx/_internal/fx/fx_onnx_interpreter.py:418
   └─ 4.687 run_node  <@beartype(torch.onnx._internal.fx.fx_onnx_interpreter.FxOnnxInterpreter.run_node) at 0x7f5a630b5c10>:1
         [4 frames hidden]  <@beartype(torch.onnx._internal.fx.fx...
            4.684 wrapper  torch/onnx/_internal/diagnostics/infra/decorator.py:70
            ├─ 3.113 format_function_signature_in_markdown  <@beartype(torch.onnx._internal.diagnostics.infra.decorator.format_function_signature_in_markdown) at 0x7f5a77f14b80>:1
            │     [2 frames hidden]  <@beartype(torch.onnx._internal.diagn...
            │        3.110 format_function_signature_in_markdown  torch/onnx/_internal/diagnostics/infra/decorator.py:30
            │        └─ 3.049 format_argument  torch/onnx/_internal/fx/diagnostics.py:35
            │           └─ 2.894 _dict  torch/onnx/_internal/fx/diagnostics.py:150
            │              └─ 2.409 format_argument  torch/onnx/_internal/fx/diagnostics.py:35
            │                 └─ 1.551 _onnxscript_torch_script_tensor  torch/onnx/_internal/fx/diagnostics.py:165
            │                    └─ 0.628 _stringify_shape  torch/onnx/_internal/fx/diagnostics.py:181
...
```

BowenBao added a commit that referenced this pull request Jul 26, 2023
ghstack-source-id: dcb8131
Pull Request resolved: #922
@BowenBao added the `module: torchlib` label (Related to the torch/aten function lib in development) Jul 26, 2023
codecov bot commented Jul 26, 2023

Codecov Report

Merging #922 (b343c3f) into gh/BowenBao/8/base (f24797f) will decrease coverage by 0.02%.
The diff coverage is 33.33%.

```
@@                  Coverage Diff                   @@
##           gh/BowenBao/8/base     #922      +/-   ##
======================================================
- Coverage               76.81%   76.80%   -0.02%     
======================================================
  Files                     112      112              
  Lines                   13547    13552       +5     
  Branches                 1377     1378       +1     
======================================================
+ Hits                    10406    10408       +2     
- Misses                   2798     2801       +3     
  Partials                  343      343              
```

| Files Changed | Coverage | Δ |
| --- | --- | --- |
| ...nxscript/function_libs/torch_lib/graph_building.py | 83.77% <33.33%> | -0.66% ⬇️ |

@justinchuby added the `change base before merge` label (Remember to change the merge base to main when the PR is ready to merge) Jul 26, 2023
@justinchuby changed the title from "TorchScriptTensor to cache torch dtype for performance" to "TorchScriptTensor to cache torch dtype for performance | feat(torchlib)" Jul 26, 2023
@BowenBao BowenBao changed the base branch from gh/BowenBao/8/base to main July 26, 2023 16:15
@BowenBao BowenBao merged commit 0b1a0e3 into main Jul 26, 2023
@BowenBao BowenBao deleted the gh/BowenBao/8/head branch July 26, 2023 16:16
BowenBao added a commit to pytorch/pytorch that referenced this pull request Jul 26, 2023
…dict in diagnostics"


BowenBao added a commit to pytorch/pytorch that referenced this pull request Jul 26, 2023
…ist/tuple/dict in diagnostics"


pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jul 26, 2023
…gnostics (#106048)

In a recent change, diagnostics started logging the contents of tuple/list/dict
arguments and return values of diagnosed functions. This slowed down export
because some container instances are extremely large, such as the fx-to-onnx
node mapping dictionary.

This PR adds a limit on how many elements the diagnostics record for these
types. Together with microsoft/onnxscript#922, the performance of export with
diagnostics is restored and improved, as shown by pyinstrument:

GPT2 time for `fx_to_onnx_interpreter.run` 17.767s -> 1.961s
xcit_large_24_p8_224 time for `fx_to_onnx_interpreter.run` 144.729s -> 4.067s
Pull Request resolved: #106048
Approved by: https://github.com/titaiwangms, https://github.com/justinchuby
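The element limit described in the commit above can be sketched as follows. This is a simplified, hypothetical version — the function name, the recursion structure, and the limit of 8 are illustrative, not the actual values or code in pytorch/pytorch#106048 — showing how a diagnostics formatter can record only the first few elements of a large container and note how many were omitted:

```python
# Hypothetical sketch: cap how many container elements a diagnostics
# formatter records, so huge dicts/lists don't dominate export time.
_MAX_ELEMENTS = 8  # illustrative limit, not the real value


def format_argument(arg):
    """Format an argument for diagnostics, truncating large containers."""
    if isinstance(arg, (list, tuple)):
        shown = [format_argument(x) for x in arg[:_MAX_ELEMENTS]]
        omitted = len(arg) - _MAX_ELEMENTS
        if omitted > 0:
            shown.append(f"... ({omitted} more)")
        return "[" + ", ".join(shown) + "]"
    if isinstance(arg, dict):
        items = list(arg.items())[:_MAX_ELEMENTS]
        shown = [f"{k!r}: {format_argument(v)}" for k, v in items]
        omitted = len(arg) - _MAX_ELEMENTS
        if omitted > 0:
            shown.append(f"... ({omitted} more)")
        return "{" + ", ".join(shown) + "}"
    return repr(arg)
```

The key property is that formatting cost is bounded by the limit rather than by the container size, which is what restores export performance on models with very large fx-to-onnx mapping dictionaries.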