
hal.fence.join: OUT_OF_RANGE; count 90 of iree_hal_fence > 32 #13543

Closed
silvasean opened this issue May 10, 2023 · 9 comments · Fixed by #13631
Labels: bug 🐞 Something isn't working, collab/nvidia

Comments

@silvasean (Contributor) commented May 10, 2023

What happened?

Using ir19.flagfile and ir19.no_sharding.mlir, run the following commands:

iree-compile --iree-hal-target-backends=cuda --iree-input-type=mhlo --iree-hal-cuda-llvm-target-arch=sm_80 ir19.no_sharding.mlir --iree-codegen-llvmgpu-use-transform-dialect= --iree-codegen-llvmgpu-enable-transform-dialect-jit=false -o ir19.vmfb
iree-benchmark-module --function=main --module=ir19.vmfb --device=cuda --flagfile=ir19.flagfile --benchmark_repetitions=5

(The --iree-codegen-llvmgpu-use-transform-dialect= --iree-codegen-llvmgpu-enable-transform-dialect-jit=false flags are a workaround for #13419.)

I see the following error:

iree/runtime/src/iree/modules/hal/module.c:1119: OUT_OF_RANGE; count 90 of iree_hal_fence > 32; while invoking native function hal.fence.join; while calling import; 
.....

(full log here)

Steps to reproduce your issue

See above

What component(s) does this issue relate to?

Runtime

Version information

iree.git @ 4d6c2b8 (I have a couple of trivial local modifications, but they should be unrelated to this bug)

Additional context

No response

@silvasean (Contributor, Author)

Is the solution here as simple as just bumping up the number from 32 to 128 or something like that?

@benvanik (Collaborator)

We could bump it up slightly as a whack-a-mole fix - it's guarding a stack allocation, though, and we'd want to keep it small and special-case anything larger with a heap allocation. I think the real issue is that you have 90 non-folded fences - those are (relatively) expensive and indicate that some earlier stage of the pipeline is being really inefficient. Run --compile-to=stream, post the IR, and we can see if it's obvious.

@allieculp
@pavanimajety Curious if there is any update on this?

@silvasean (Contributor, Author)

@benvanik here is the --compile-to=stream output you requested: https://gist.github.com/silvasean/1038fec28172cf61872cdd6e000523ad

@silvasean (Contributor, Author)

For context, this is an LLM training workload. The IR structure is the forward pass, backward pass, and then optimizer updates. The forward and backward passes each have their own stablehlo.while op. In the current IR, this actually only iterates once (not sure if we fold it away; I see a ton of control flow in the --compile-to=stream IR, more than I would expect for this). In the user workload this corresponds to microbatching (see the image here).

@allieculp
@pjannaty @pavanimajety Is this being worked on by your team?

@pavanimajety (Contributor)

@allieculp I am taking a look; I was out of office for the last few days.

@pjannaty (Contributor) commented May 15, 2023

Thanks, Allie and Pavani. Allie, this looks like a high-priority bug; we are gradually ramping up and don't mean to block anyone else who also wants to take a look. Happy to collaborate on this as well.

@benvanik (Collaborator)

The issue is FoldBlockArgumentsPattern not properly folding duplicate block arguments. I'll see if I can get that fixed.

@benvanik benvanik self-assigned this May 15, 2023
benvanik added a commit that referenced this issue May 16, 2023
The existing code would give up on particular args if multiple branch
sites had non-identical duplicate arg sets for that arg.

Fixes #13543.
benvanik added a commit that referenced this issue May 16, 2023
NatashaKnk pushed a commit to NatashaKnk/iree that referenced this issue Jul 6, 2023