
Fused transposed elementwise ops in dispatch region causing extra shared memory allocation #12523

Closed
Abhishek-Varma opened this issue Mar 6, 2023 · 50 comments · Fixed by #13823
Labels: bug 🐞 Something isn't working · codegen/spirv SPIR-V code generation compiler backend


@Abhishek-Varma
Contributor

What happened?

When running the IR through iree-run-module, I get the following error:

C:\A\iree\runtime\src\iree\hal\drivers\vulkan\native_executable.cc:157: UNAVAILABLE; VK_ERROR_INITIALIZATION_FAILED; while invoking native function hal.executable.create; while calling import;
[ 1]   native hal.executable.create:0 -
[ 0] bytecode module.__init:268 .\dispatch\module_forward_dispatch_28_vulkan_spirv_fb.mlir:2:3

This takes place for --iree-vulkan-target-triple=rdna2-unknown-windows

Steps to reproduce your issue

Download module_forward_dispatch_28_vulkan_spirv_fb.mlir.

Step 1.

.\iree-compile.exe module_forward_dispatch_28_vulkan_spirv_fb.mlir --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-windows -o test_28.vmfb

Step 2.

.\iree-run-module.exe --module=test_28.vmfb --device=vulkan --function=forward_dispatch_28_matmul_4096x512x512

What component(s) does this issue relate to?

Compiler, Runtime

Version information

No response

Additional context

No response

@powderluv
Collaborator

This used to work at some point and stopped working in the recent past. We have pinned to an earlier version to work around it.

@antiagainst
Contributor

This is because we have a transposed linalg.generic fused together with linalg.matmul:

%10 = linalg.matmul
  ins(%4, %5 : tensor<4096x512xf16>, tensor<512x512xf16>)
  outs(%9 : tensor<4096x512xf16>) -> tensor<4096x512xf16>
%11 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, 
                   affine_map<(d0, d1) -> (d1, d0)>], 
  iterator_types = ["parallel", "parallel"]
} ins(%10, %6 : tensor<4096x512xf16>, tensor<512xf16>) outs(%7 : tensor<512x4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
  %12 = arith.addf %in, %in_0 : f16
  linalg.yield %12 : f16
} -> tensor<512x4096xf16>

So bufferization allocated a buffer from workgroup memory to hold the intermediate result for linalg.matmul, and that exceeds the total amount of allowed workgroup memory.

The issue is likely at a higher level. I'd be interested to know how we generate such dispatches in flow. Did you perform layout transformations or handle transposes in a way that could cause this?
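For contrast, here is a minimal sketch of what the same bias add would look like without the transposed output map (purely illustrative; %init stands in for a hypothetical tensor.empty of the matching shape). With an identity output map the elementwise op can consume the matmul result tile directly, so no staging buffer in workgroup memory is needed:

%bias_add = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
                   affine_map<(d0, d1) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel"]
} ins(%10, %6 : tensor<4096x512xf16>, tensor<512xf16>) outs(%init : tensor<4096x512xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
  %add = arith.addf %in, %in_0 : f16
  linalg.yield %add : f16
} -> tensor<4096x512xf16>

The (d1, d0) output map in the snippet above is what forces bufferization to materialize the whole matmul result before the transposed write.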

@antiagainst antiagainst added codegen/spirv SPIR-V code generation compiler backend and removed awaiting-triage labels Mar 7, 2023
@benvanik
Collaborator

benvanik commented Mar 7, 2023

ew, using a temporary buffer is really unfortunate - one shouldn't be needed here for a simple bias add.

@powderluv
Collaborator

So it was working until about a week ago, when we hit the issue. We have pinned torch-mlir to an older version to avoid it. Is the issue at a higher level in torch-mlir, or in the flow dialect?

@Abhishek-Varma
Contributor Author

This is because we have a transposed linalg.generic fused together with linalg.matmul:

%10 = linalg.matmul
  ins(%4, %5 : tensor<4096x512xf16>, tensor<512x512xf16>)
  outs(%9 : tensor<4096x512xf16>) -> tensor<4096x512xf16>
%11 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, 
                   affine_map<(d0, d1) -> (d1, d0)>], 
  iterator_types = ["parallel", "parallel"]
} ins(%10, %6 : tensor<4096x512xf16>, tensor<512xf16>) outs(%7 : tensor<512x4096xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
  %12 = arith.addf %in, %in_0 : f16
  linalg.yield %12 : f16
} -> tensor<512x4096xf16>

So bufferization allocated a buffer from workgroup memory to hold the intermediate result for linalg.matmul, and that exceeds the total amount of allowed workgroup memory.

The issue is likely at a higher level. I'd be interested to know how we generate such dispatches in flow. Did you perform layout transformations or handle transposes in a way that could cause this?

So I used the linalg IR I get from torch-mlir to dump the dispatches using iree-compile and find the "culprit" dispatch.
As @powderluv mentioned, we have pinned to an older version of torch-mlir to avoid this issue.

We aren't using any special layout transformations or transpose handling.

@powderluv
Collaborator

Can we confirm the IR difference between the two versions? And post both too?

@Abhishek-Varma
Contributor Author

Abhishek-Varma commented Mar 9, 2023

The op count differences between the OLD and LATEST torch-mlir:

At torch level:

torch.aten.broadcast_to        , 126 vs 120
torch.aten.mm                  , 1 vs 4
torch.aten.view                , 81 vs 78
torch.prim.ListConstruct       , 147 vs 146

At linalg level:

linalg.batch_matmul      , 5 vs 2
linalg.generic           , 794 vs 790
linalg.matmul            , 1 vs 4
linalg.yield             , 844 vs 840
tensor.collapse_shape    , 65 vs 66
tensor.empty             , 39 vs 38
tensor.expand_shape      , 126 vs 129

Elided Linalg IR which has the issue
Elided Linalg IR which we have from the older version of torch-mlir

The main difference I see is the 3 "extra" matmuls in the newer IR (%94, %97 and %100), forming a matmul + expand_shape + generic op set.

I tried experimenting with different torch versions and found that the issue is specifically with torch==2.0.0.dev20230228.
If we keep the latest torch-mlir but pin to older torch versions, say torch==2.0.0.dev20230220 or torch==2.0.0.dev20230227, the issue doesn't appear.

There are just a handful of decompositions we use in our pipeline as can be seen here.

So I inspected the decompositions we mainly use in our pipeline, and the only "relevant" delta I see between torch==2.0.0.dev20230228 and torch==2.0.0.dev20230220 is fix embedding_backward_dense.
I tried reverting that one-line change to see if it has any effect, but to no avail.

@antiagainst
Contributor

antiagainst commented Mar 10, 2023

Thanks for the full input IR. However, I cannot reproduce the issue. With e2151d3 and tools/iree-compile --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-windows --compile-to=flow running on the new Linalg IR that is supposed to have the issue, I see that 1) all the dispatch regions containing fused matmul and elementwise ops do not involve a transpose as discovered above; and 2) transpose linalg.generic ops are in their own dispatch regions. Please double-check that this is still a problem (maybe some transformation done on your own side is causing it?).

@powderluv
Collaborator

I think this bug is magically fixed. I don't know what changed. We will revert the pin and keep an eye out and report back. Thank you for investigating.

@powderluv
Collaborator

Yeah, there is no magic :( I was testing via the WMMA pipeline, which worked. This failure only happens on the RDNA2/SIMT pipeline. Re-opening.

@powderluv powderluv reopened this Mar 11, 2023
@powderluv
Collaborator

Adding @MaheshRavishankar too for guidance, since we think it may be something at the flow level.

@powderluv
Collaborator

@yzhang93 / @qedawkins could this be because the tunings changed or are now invalid since @antiagainst tried on top of master?

@powderluv
Collaborator

Yup, confirmed that the tunings we apply cause this crash at runtime. Maybe we can have runtime checks for them? I don't know how the verifier let this go. Maybe we need to enhance the verifier to catch this failure too.

@qedawkins
Contributor

The verifier might not be considering fused ops when determining shared memory requirements. Also, I think it only verifies named matmuls (everything else just passes straight through), but I'd have to double-check the verifier to be sure (I'm away from my desk).

@yzhang93
Contributor

Yup, confirmed that the tunings we apply cause this crash at runtime. Maybe we can have runtime checks for them? I don't know how the verifier let this go. Maybe we need to enhance the verifier to catch this failure too.

I'm sure previously both the tuned and untuned models had the failure. But maybe something has changed and the untuned model works fine now. I think it's the VAE model, and we don't apply lowering configs to it. I'll check whether the Winograd transform caused the problem.

@antiagainst
Contributor

Not sure tuning is the problem: tuning just adjusts the tile sizes and such after seeing the dispatch region. The problem is having a fused transposed elementwise op in the dispatch region from the beginning, which causes bufferization to insert extra allocations in shared memory.

It would be beneficial to understand why we are forming such dispatch regions, e.g., by running with --mlir-print-ir-after-all on the SHARK fork to see how such a dispatch region is generated, as I cannot repro this with IREE top of tree. (I assume it's some combination of patterns causing this.) The solution would be either to avoid forming such dispatch regions (e.g., separating the transpose into its own dispatch region) or to teach bufferization to be smarter (not sure how feasible that is here).
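As a rough sketch of the first option (names like %fused_matmul_bias and %init are hypothetical; shapes are taken from the dispatch above), the matmul would keep a layout-preserving bias add fused with it, and the transpose would be left as a standalone copy-like linalg.generic so that dispatch region formation puts it in its own region:

%transposed = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel"]
} ins(%fused_matmul_bias : tensor<4096x512xf16>) outs(%init : tensor<512x4096xf16>) {
^bb0(%in: f16, %out: f16):
  linalg.yield %in : f16
} -> tensor<512x4096xf16>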

@antiagainst antiagainst changed the title IREE Runtime - UNAVAILABLE; VK_ERROR_INITIALIZATION_FAILED error Fused transposed elementwise ops in dispatch region causing extra shared memory allocation Mar 11, 2023
@antiagainst
Contributor

Or actually, maybe I'm not using the proper command-line options to reproduce the issue from the full model. I was using tools/iree-compile --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-windows from the issue, but that's just for the single dispatch. I recall there are preprocessing steps. Let me know.

@powderluv
Collaborator

Here are our typical flags:

 C:\Users\foo\AppData\Local\Temp\_MEI83882\iree\compiler\tools\..\_mlir_libs\iree-compile.exe - --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvm-embedded-linker-path=C:\Users\foo\AppData\Local\Temp\_MEI83882\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-util-zero-fill-elided-attrs -iree-vulkan-target-triple=rdna2-7900-windows

@yzhang93
Contributor

yzhang93 commented Mar 12, 2023

@antiagainst I just tested with the latest nightly IREE python package (https://github.com/openxla/iree/releases/tag/candidate-20230311.455) and the above dispatch still fails with the same error. And I confirmed it has nothing to do with tuning. I tested on my Navi3 system with the following commands:
iree-compile module_forward_dispatch_28_vulkan_spirv_fb.mlir --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux -o test_28.vmfb
iree-run-module --module=test_28.vmfb --device=vulkan --function=forward_dispatch_28_matmul_4096x512x512

The above dispatch is from the VAE model, and the whole model failed with the same error. We use the preprocessing flags when compiling the whole model.

@powderluv
Collaborator

@yzhang93 can we please run with --mlir-print-ir-after-all too

@powderluv
Collaborator

Here is the output of --mlir-print-ir-after-all

@powderluv
Collaborator

@antiagainst if this is an AMD driver issue, is there a temporary way to override the dispatch_28 creation until it is fixed?

@powderluv
Collaborator

Also, should dispatch_28 not be formed for cards that expose less shared memory?

@MaheshRavishankar
Contributor

How much is the shared memory usage?

@MaheshRavishankar
Contributor

It seems like the tile size selection (which affects shared memory usage) is not accounting for the shared memory limit... So this is a backend issue.
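As a rough illustration of how quickly that adds up (these tile sizes are made up, not the ones actually chosen here): staging a 128x128 f16 tile of the matmul result alone costs 128 * 128 * 2 B = 32 KiB of workgroup memory, on top of whatever the matmul itself promotes to shared memory, which can easily exceed the limit the Vulkan device reports via maxComputeSharedMemorySize.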

@powderluv
Collaborator

@yzhang93 can we tune dispatch_28 for RDNA2 so it doesn't crash?

@powderluv
Collaborator

It seems like the tile size selection (which affects shared memory usage) is not accounting for the shared memory limit... So this is a backend issue.

Backend here is the AMD Vulkan driver?

@MaheshRavishankar
Contributor

It seems like the tile size selection (which affects shared memory usage) is not accounting for the shared memory limit... So this is a backend issue.

Backend here is the AMD Vulkan driver?

No, this should be the SPIR-V backend in IREE.

@powderluv
Collaborator

Ah, ok. So it's our generated SPIR-V.

@antiagainst
Contributor

Thanks for the repro steps in #12523 (comment). I've reproduced the issue, figured out what went wrong, and put up a fix in #12627.

@antiagainst
Contributor

With the above, we won't fuse such cases. @powderluv or somebody else, if you can help verify that this works, that'd be nice.

@powderluv
Collaborator

Thank you. I can confirm it works OK:

(shark.venv) PS C:\g\shark> iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-util-zero-fill-elided-attrs --iree-vulkan-target-triple=rdna2-unknown-linux --iree-preprocessing-pass-pipeline='builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-preprocessing-pad-linalg-ops{pad-size=32}))' C:\Users\anush\Downloads\vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan_torch.mlir -o test.vmfb
(shark.venv) PS C:\g\shark> iree-benchmark-module --module=test.vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
2023-03-14T00:10:14-07:00
Running C:\g\shark\shark.venv\Lib\site-packages\iree\runtime\scripts\iree_benchmark_module\..\..\iree-benchmark-module
Run on (32 X 4491.57 MHz CPU s)
CPU Caches:
 L1 Data 32 KiB (x16)
 L1 Instruction 32 KiB (x16)
 L2 Unified 1024 KiB (x16)
 L3 Unified 32768 KiB (x2)
--------------------------------------------------------------------------------------------
Benchmark                                  Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------------------
BM_forward/process_time/real_time        792 ms        0.000 ms            1 items_per_second=1.26315/s

@powderluv powderluv added this to the IREE + NOD Model Coverage milestone Mar 15, 2023
@powderluv
Collaborator

@antiagainst When the right fix lands, would we be back to the original memory usage? We are carrying this locally on SHARK-Runtime, and some end users on 8GB cards are running out of VRAM.

@aaron-schneider

aaron-schneider commented Mar 20, 2023

Silly Q, but this issue is marked "Done" in status, but open in the IREE project (https://github.com/orgs/openxla/projects/13?pane=issue&itemId=22177652). Any idea why? (I can click "Close with Comment", but I guess I don't get the difference here.) Thanks!

@antiagainst
Contributor

@antiagainst When the right fix lands, would we be back to the original memory usage? We are carrying this locally on SHARK-Runtime, and some end users on 8GB cards are running out of VRAM.

I'm not sure this one would address memory usage issues. We were never able to handle such fusion cases before; we just never hit them previously. So it's not that we've stopped fusing cases that used to be fused. The memory usage issue is likely different.

@antiagainst
Contributor

Silly Q, but this issue is marked "Done" in status, but open in the IREE project (https://github.com/orgs/openxla/projects/13?pane=issue&itemId=22177652). Any idea why? (I can click "Close with Comment", but I guess I don't get the difference here.) Thanks!

Ha, interesting. This is not done, so I moved it back to "In Progress". The fix in #12627 is not the long-term way to go. I'll spend some time doing it properly later.

@aaron-schneider

Hi - double checking on this P0 issue. More to say? Ok to close or lower priority? Thanks!

@powderluv
Collaborator

Moving this to a P1 since we have a workaround (for SHARK at least)

@GMNGeoffrey
Contributor

P1 is ok, but this is a release blocker, I think. I foolishly started talking about things in discord, but cross-posting here.

Looking at #12627 it seems like we have the option to drop a feature in order to fix the bug. I would advocate for that or hiding it behind a flag so that we don't have to wait for a big rewrite to fix this issue.

@MaheshRavishankar
Contributor

P1 is ok, but this is a release blocker, I think. I foolishly started talking about things in discord, but cross-posting here.

Looking at #12627 it seems like we have the option to drop a feature in order to fix the bug. I would advocate for that or hiding it behind a flag so that we don't have to wait for a big rewrite to fix this issue.

It works on other backends, and it's just a can we kicked down the road for a while...

@GMNGeoffrey
Contributor

If this is indeed a release blocker then I think we need to revert the offending feature. This looks like we are miscompiling and we have head and unstable releases in a known-broken state.

@antiagainst
Contributor

I'm going to fix this in the proper way next. The issue is triggered by some new IR patterns from torch-mlir which we didn't see before. So it's not that we have a regression; previous releases won't support it either. I don't want to block the release on my implementation, so I'm fine rolling forward with the release.

@GMNGeoffrey
Contributor

Got it, then I think this is not a release blocker. Thanks for clarifying (and for fixing 🙂 )

antiagainst added a commit that referenced this issue May 26, 2023
This just needs to optimize vector transfer ops after vectorization and
before folding memref aliases. At that time we still have memref
subviews using the same indices.

Fixes #12523
Closes #12627
NatashaKnk pushed a commit to NatashaKnk/iree that referenced this issue Jul 6, 2023
nhasabni pushed a commit to plaidml/iree that referenced this issue Aug 24, 2023