hal.fence.join: OUT_OF_RANGE; count 90 of iree_hal_fence > 32 #13543
Comments
Is the solution here as simple as just bumping up the number from 32 to 128 or something like that?
We could raise it slightly as a whack-a-mole fix, but it's guarding a stack allocation, so we'd want to keep it small and special-case anything larger with a heap allocation. I think the real issue is that you have 90 non-folded fences: those are (relatively) expensive and indicate that some earlier stage of the pipeline is being really inefficient. Run with --compile-to=stream, post the IR, and we can see if the cause is obvious.
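The trade-off described above (a small fixed-capacity stack buffer, with a heap fallback for the rare large count instead of an OUT_OF_RANGE error) can be sketched in C. This is a minimal illustration under stated assumptions, not IREE's iree_hal_fence code: `fence_t`, `join_fences`, and `INLINE_FENCE_CAPACITY` are hypothetical names, and the "join" simply takes the maximum timepoint as a stand-in for waiting on every fence.

```c
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical fence handle; stands in for whatever the real runtime uses. */
typedef struct fence_t {
  uint64_t timepoint;
} fence_t;

/* Same value as the limit in the error message; small enough for the stack. */
#define INLINE_FENCE_CAPACITY 32

/* Hypothetical join: copies the fence pointers into a scratch array and
 * reports the maximum timepoint as a stand-in for "wait on all of them".
 * Returns 0 on success, -1 if the heap fallback allocation fails. */
static int join_fences(const fence_t* const* fences, size_t count,
                       uint64_t* out_timepoint) {
  const fence_t* inline_storage[INLINE_FENCE_CAPACITY];
  const fence_t** scratch = inline_storage;
  if (count > INLINE_FENCE_CAPACITY) {
    /* Uncommon large case: use the heap instead of failing with OUT_OF_RANGE. */
    scratch = malloc(count * sizeof(*scratch));
    if (!scratch) return -1;
  }
  if (count) memcpy(scratch, fences, count * sizeof(*scratch));

  uint64_t max_timepoint = 0;
  for (size_t i = 0; i < count; ++i) {
    if (scratch[i]->timepoint > max_timepoint) {
      max_timepoint = scratch[i]->timepoint;
    }
  }
  *out_timepoint = max_timepoint;

  /* Only free what we allocated; the inline buffer lives on the stack. */
  if (scratch != inline_storage) free(scratch);
  return 0;
}
```

Keeping the inline capacity small bounds stack usage per call, and the heap path costs an allocation only in the uncommon case; it does not address why so many fences reach the join in the first place.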
@pavanimajety Curious if there is any update on this?
@benvanik here is the --compile-to=stream output you requested: https://gist.github.com/silvasean/1038fec28172cf61872cdd6e000523ad
For context, this is an LLM training workload. The IR is structured as the forward pass, the backward pass, and then the optimizer updates. The forward and backward passes each have their own stablehlo.while op. In the current IR, each loop actually iterates only once (not sure if we fold it away; I see a ton of control flow in the --compile-to=stream IR, more than I would expect for this). In the user workload, this corresponds to microbatching (see the image here).
@pjannaty @pavanimajety Is this being worked on by your team?
@allieculp I am taking a look; I was out of office the last few days.
Thanks, Allie and Pavani. Allie, this looks like a high-priority bug. We are still ramping up and don't mean to block anyone else who wants to take a look; happy to collaborate on this as well.
The issue is that FoldBlockArgumentsPattern is not properly folding the duplicate block arguments. I'll see if I can get that fixed.
The existing code would give up on particular args if multiple branch sites had non-identical duplicate arg sets for that arg. Fixes #13543.
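One way to avoid the give-up case described in the fix is to fold a block argument into an earlier one only when the two receive the same value at every branch site, which amounts to intersecting the per-site duplicate sets rather than requiring them to be identical. The C sketch below illustrates that idea on a simplified model; it is not IREE's FoldBlockArgumentsPattern, which rewrites MLIR ops. The integer value-id matrix and the `compute_arg_replacements` name are hypothetical.

```c
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical model: ops[site * num_args + arg] is an integer id for the SSA
 * value that branch site `site` passes for block argument `arg`. */

/* Writes into replacement[j] the index of the argument that j can be folded
 * into (j itself if it must stay). Argument j folds into the lowest earlier
 * argument i that receives the same value as j at every branch site. Column
 * equality is an equivalence relation, so the chosen i always maps to itself. */
static void compute_arg_replacements(const int* ops, size_t num_sites,
                                     size_t num_args, size_t* replacement) {
  for (size_t j = 0; j < num_args; ++j) {
    replacement[j] = j;
    for (size_t i = 0; i < j; ++i) {
      bool same_everywhere = true;
      for (size_t s = 0; s < num_sites && same_everywhere; ++s) {
        same_everywhere = ops[s * num_args + i] == ops[s * num_args + j];
      }
      if (same_everywhere) {
        replacement[j] = i; /* lowest matching argument is the representative */
        break;
      }
    }
  }
}
```

For example, if site A passes (x, x, x) and site B passes (y, y, z), the per-site duplicate sets for argument 0 differ ({0, 1, 2} versus {0, 1}), yet argument 1 can still be folded into argument 0 while argument 2 stays.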
What happened?
Using ir19.flagfile and ir19.no_sharding.mlir, run the following commands:
(The --iree-codegen-llvmgpu-use-transform-dialect= --iree-codegen-llvmgpu-enable-transform-dialect-jit=false flags are there to work around #13419.)
I see the following error:
(full log here)
Steps to reproduce your issue
See above
What component(s) does this issue relate to?
Runtime
Version information
iree.git @ 4d6c2b8 (I have a couple trivial local modifications, but should be unrelated to this bug)
Additional context
No response