
stablehlo.sort pathologically slow #13865

Open
silvasean opened this issue May 31, 2023 · 26 comments

Labels: codegen (Shared code generation infrastructure and dialects) · performance ⚡ (Performance/optimization related work across the compiler and runtime)

@silvasean (Contributor)

What happened?

Compile and run the IR below. The result takes roughly 10 seconds to run on my A100, which is over 1000x off in performance.

$ iree-compile --iree-hal-target-backends=cuda --iree-input-type=xla --iree-hal-cuda-llvm-target-arch=sm_80 sort.mlir >sort.vmfb
$ iree-benchmark-module --device=cuda --module=sort.vmfb --function=main --input=4x8000xf32=0 --input=4x8000xi32=0 --benchmark_repetitions=2
module {
  func.func @main(%arg0: tensor<4x8000xf32>, %arg1: tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>) {
    %0:2 = "stablehlo.sort"(%arg0, %arg1) ({
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<i32>, %arg5: tensor<i32>):
      %1 = stablehlo.compare  GT, %arg2, %arg3,  TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    }) {dimension = 1 : i64, is_stable = true} : (tensor<4x8000xf32>, tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>)
    return %0#0, %0#1 : tensor<4x8000xf32>, tensor<4x8000xi32>
  }
}

Steps to reproduce your issue

See above

What component(s) does this issue relate to?

Compiler

Version information

No response

Additional context

No response

@silvasean silvasean added the bug 🐞 Something isn't working label May 31, 2023
@benvanik (Collaborator)

Yeah, it would be great to get some eyes on this. It's been ignored for ages because most models were doing sorts of 5-10 elements, but it gets really bad as sizes scale up. IIRC this is related to #13745 (I believe that is coming from sort, or is at least adjacent to it).

@benvanik benvanik added codegen Shared code generation infrastructure and dialects performance ⚡ Performance/optimization related work across the compiler and runtime and removed bug 🐞 Something isn't working labels May 31, 2023
@silvasean (Contributor, Author)

Yep, the surrounding IR in the model looks like:

      %10308 = stablehlo.iota dim = 0 : tensor<8000xi32>
      %10309 = stablehlo.broadcast_in_dim %10308, dims = [1] : (tensor<8000xi32>) -> tensor<4x8000xi32>
      %10310:2 = "stablehlo.sort"(%10289, %10309) ({
      ^bb0(%arg386: tensor<f32>, %arg387: tensor<f32>, %arg388: tensor<i32>, %arg389: tensor<i32>):
        %10525 = stablehlo.compare  GT, %arg386, %arg387,  TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
        stablehlo.return %10525 : tensor<i1>
      }) {dimension = 1 : i64, is_stable = true} : (tensor<4x8000xf32>, tensor<4x8000xi32>) -> (tensor<4x8000xf32>, tensor<4x8000xi32>)

@MaheshRavishankar (Contributor)

The iree_linalg_ext.sort operation implements the TilingInterface. It's a small step from there to the PartialReductionOpInterface, which would allow handling sort similarly to split-k. Both the interface and its implementation/use in IREE need work, though.

@qcolombet (Contributor)

Assigning to Güray for an assessment of what needs to happen here.

@silvasean (Contributor, Author)

This occurs in a topk pattern:

      %10308 = stablehlo.iota dim = 0 : tensor<8000xi32> loc(#loc11490)
      %10309 = stablehlo.broadcast_in_dim %10308, dims = [1] : (tensor<8000xi32>) -> tensor<4x8000xi32> loc(#loc11491)
      %10310:2 = "stablehlo.sort"(%10289, %10309) ({
      ^bb0(%arg386: tensor<bf16> loc("<stdin>":11110:12), %arg387: tensor<bf16> loc("<stdin>":11110:35), %arg388: tensor<i32> loc("<stdin>":11110:58), %arg389: tensor<i32> loc("<stdin>":11110:80)):
        %10525 = stablehlo.compare  GT, %arg386, %arg387,  TOTALORDER : (tensor<bf16>, tensor<bf16>) -> tensor<i1> loc(#loc11497)
        stablehlo.return %10525 : tensor<i1> loc(#loc11498)
      }) {dimension = 1 : i64, is_stable = true} : (tensor<4x8000xbf16>, tensor<4x8000xi32>) -> (tensor<4x8000xbf16>, tensor<4x8000xi32>) loc(#loc11492)
      %10311 = "stablehlo.slice"(%10310#0) {limit_indices = dense<[4, 40]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x8000xbf16>) -> tensor<4x40xbf16> loc(#loc11499)
      %10427 = "stablehlo.slice"(%10310#1) {limit_indices = dense<[4, 40]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x8000xi32>) -> tensor<4x40xi32> loc(#loc11615)

@MaheshRavishankar (Contributor)

@KoolJBlack did some work on making topk fast on CUDA. Maybe we can route the lowering through that.

@allieculp

Adding @kuhar as an assignee for interfacing with MHLO.

@hanhanW (Contributor) commented Jun 1, 2023

Kojo has #9383 for top-k improvements and already landed some split reduction optimization for top-k, see #9807. Maybe we need a direct lowering to linalg_ext.top_k; then we can reuse what Kojo has done for top-k.
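
For reference, a rough sketch of what such a direct lowering could produce, using the shapes from the repro and assuming k = 40 based on the slices later in the model (the exact iree_linalg_ext.topk syntax and init handling here are approximate, not verified against the current op definition):

// Hypothetical lowering target; %values, %indices, and k = 40 are assumptions.
%out_v = tensor.empty() : tensor<4x40xf32>
%out_i = tensor.empty() : tensor<4x40xi32>
%0:2 = iree_linalg_ext.topk dimension(1)
    ins(%values, %indices : tensor<4x8000xf32>, tensor<4x8000xi32>)
    outs(%out_v, %out_i : tensor<4x40xf32>, tensor<4x40xi32>) {
^bb0(%arg0: f32, %arg1: f32):
  // Keep the larger element, matching the GT comparator of the original sort.
  %1 = arith.cmpf ogt, %arg0, %arg1 : f32
  iree_linalg_ext.yield %1 : i1
} -> tensor<4x40xf32>, tensor<4x40xi32>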

@kuhar (Member) commented Jun 1, 2023

> Maybe we need a direct lowering to linalg_ext.top_k; then we can reuse what Kojo has done for top-k.

Don't we have that already? https://github.com/openxla/iree/blob/0d81062a8ffabf3561e9d6cc758b8250b2f607d3/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgExt.cpp#L431

@benvanik (Collaborator) commented Jun 1, 2023

These models don't seem to be using chlo.topk (or it's getting lowered to stablehlo/mhlo prior to that pass). Ideally we'd still be able to do OK on something like sorting without requiring chlo; there are a lot of ways sorts can be used without being a topk.
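
For context, a minimal sketch of what the chlo form looks like, assuming k = 40 to match the slices in this model (syntax approximate):

// Hypothetical chlo form of the sort+slice topk pattern; k = 40 is an assumption.
%values, %indices = chlo.top_k(%input, k = 40)
    : tensor<4x8000xf32> -> (tensor<4x40xf32>, tensor<4x40xi32>)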

@MaheshRavishankar (Contributor)

> These models don't seem to be using chlo.topk (or it's getting lowered to stablehlo/mhlo prior to that pass). Ideally we'd still be able to do OK on something like sorting without requiring chlo; there are a lot of ways sorts can be used without being a topk.

Agreed, but this is more about priorities. I think we should invest in making all these operations better (including split-k), and I am happy to help with that. Instead of another one-off special case for sort, I'd rather just reuse the one-off things that were done for topk. If someone wants to generalize and build a generic way of handling this by pushing on the interfaces above, I'd be glad to help.

@benvanik (Collaborator) commented Jun 1, 2023

Oh yeah, I'm more saying that pattern-matching this to the linalg_ext op may be the quick workaround versus needing to change the original model to use chlo (if it isn't already; someone needs to look). Matching iota + broadcast + sort seems reasonable.

@rsuderman (Contributor) commented Jun 2, 2023

Filed specific pattern-matching bugs here:
#13905
#13906

@grypp (Contributor) commented Jun 2, 2023

I studied the current top-k op and its performance on GPUs. It is not optimal.

As far as I can see, topk is executed sequentially, or with little parallelism, by default. Manually setting a value for the --iree-flow-topk-split-reduction flag leverages more parallelism. I can implement a heuristic to set that value automatically; this would be the quickest way to get at least a 7x speedup.
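
As a concrete illustration, this is how the flag could be set manually on the repro from this issue (assuming the input actually gets matched to linalg_ext.topk; the split value 16 is an illustrative guess, not a tuned choice):

$ iree-compile --iree-hal-target-backends=cuda --iree-input-type=xla \
    --iree-hal-cuda-llvm-target-arch=sm_80 \
    --iree-flow-topk-split-reduction=16 sort.mlir >sort.vmfb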

But even if we do that, the generated code isn't cutting-edge. XLA GPU uses a custom kernel when k < 16; I ran some benchmarks showing it is faster than IREE, though still not optimal. We might want to use a custom kernel, or a microkernel if we want to fuse producers.

@MaheshRavishankar (Contributor)

> I studied the current top-k op and its performance on GPUs. It is not optimal.
>
> As far as I can see, topk is executed sequentially, or with little parallelism, by default. Manually setting a value for the --iree-flow-topk-split-reduction flag leverages more parallelism. I can implement a heuristic to set that value automatically; this would be the quickest way to get at least a 7x speedup.
>
> But even if we do that, the generated code isn't cutting-edge. XLA GPU uses a custom kernel when k < 16; I ran some benchmarks showing it is faster than IREE, though still not optimal. We might want to use a custom kernel, or a microkernel if we want to fuse producers.

I think fusing with producers here is not a high priority. Using a custom PTX blob to dispatch for sort will probably remove it from the blocker list. If we want to handle sort properly, there is a bit of design/exploration needed, plus better usage of the two interfaces I mentioned above; if we want to go down that route I can help (with design, but I can't take on implementation). If we just want to handle sort, I'd just write a CUDA kernel, compile it to PTX, and ship the PTX for now.

@benvanik (Collaborator) commented Jun 2, 2023

Hrrmmm, we can't close this issue if the solution is "write some PTX". That may unblock the user here, but it's definitely still a major issue for the platform, as Vulkan/Metal/CPU will have extremely slow sorts, and we don't want to be making one-off workarounds in the core platform in lieu of actually solving the problem better. If the heuristic makes things better we should do that, and then if the user needs even better performance they can add custom PTX in the nvgpu plugin.

@MaheshRavishankar (Contributor)

> Hrrmmm, we can't close this issue if the solution is "write some PTX". That may unblock the user here, but it's definitely still a major issue for the platform, as Vulkan/Metal/CPU will have extremely slow sorts, and we don't want to be making one-off workarounds in the core platform in lieu of actually solving the problem better. If the heuristic makes things better we should do that, and then if the user needs even better performance they can add custom PTX in the nvgpu plugin.

Yes, I am not talking about doing this in-tree; custom PTX in the nvgpu plugin is what I am suggesting above. As I mentioned internally, the topk implementation isn't done using these interfaces and was a one-off (I did push for using interfaces then, but got pushback). Having a sort one-off as well would be sad. So if we want to solve this and be done with it, I'd happily help, but it will require some heavy lifting. If we aren't prepared to do that now, then we might as well go with the easiest solution to avoid building another heavy-weight one-off.

@qcolombet (Contributor)

@grypp Could you go ahead with the PTX implementation in the nvgpu plugin?

@silvasean (Contributor, Author)

> I studied the current top-k op and its performance on GPUs. It is not optimal.
>
> As far as I can see, topk is executed sequentially, or with little parallelism, by default. Manually setting a value for the --iree-flow-topk-split-reduction flag leverages more parallelism. I can implement a heuristic to set that value automatically; this would be the quickest way to get at least a 7x speedup.
>
> But even if we do that, the generated code isn't cutting-edge. XLA GPU uses a custom kernel when k < 16; I ran some benchmarks showing it is faster than IREE, though still not optimal. We might want to use a custom kernel, or a microkernel if we want to fuse producers.

How bad is IREE's linalg_ext.top_k with no special compiler flags? As long as it is within 10x of XLA:GPU, I would consider that acceptable for the immediate-term goals.

cc @kuhar

@jpienaar (Member) commented Jun 3, 2023

> Ideally we'd still be able to do OK on something like sorting without requiring chlo; there are a lot of ways sorts can be used without being a topk.

+1. If the focus here is only on sort used for topk, then this issue should be renamed or a new issue filed for topk.

> These models don't seem to be using chlo.topk (or it's getting lowered to stablehlo/mhlo prior to that pass).

The models here not using chlo.topk is due to the IR being looked at after expansion to MHLO, lowering to XLA HLO protos, and importing back; we are looking at the model after a lot of passes have run and destroyed the structure.

@allieculp

@grypp @kuhar Can you update this issue with the latest?

@kuhar (Member) commented Jun 5, 2023

> @grypp @kuhar Can you update this issue with the latest?

In the medium-term, we are going to propose and implement a dedicated topk op for stablehlo: openxla/stablehlo#1514

@grypp (Contributor) commented Jun 6, 2023

I am working on improving the performance of topk. As it is a different issue, I created #13960.

@allieculp commented Jun 8, 2023

@kuhar Adding openxla/stablehlo#1593 to show the status of adding topk to StableHLO. Please add any other relevant updates here!

@allieculp

@kuhar Do you have any update on this issue?

@kuhar (Member) commented Jun 20, 2023

> @kuhar Do you have any update on this issue?

@allieculp I posted an RFC for adding a topk op on the StableHLO GitHub and presented it during the last OpenXLA community meeting. I don't have an ETA for the new op being available in IREE, but this is not on the critical path given the workarounds from Natasha and Rob.
