HLO iota should fuse with other ops and not (often) require materialization into memory. #13745

Closed
benvanik opened this issue May 23, 2023 · 7 comments · Fixed by #14070
Labels: compiler/dialects (IREE compiler dialects: flow, hal, vm) · performance ⚡ (performance/optimization work across the compiler and runtime)

Comments

@benvanik (Collaborator)

Seeing this input HLO with iota and broadcast ops (among others) that don't get fused, sometimes feeding into sorts and sometimes being inserted directly into tensors:

    %4 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<6xi32>
    %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<6xi32>) -> tensor<2x6xi32>

which lowers to this linalg IR:

    %15 = tensor.empty() : tensor<2xi32>
    %16 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel"]} outs(%15 : tensor<2xi32>) {
    ^bb0(%out: i32):
      %28 = linalg.index 0 : index
      %29 = arith.index_cast %28 : index to i32
      linalg.yield %29 : i32
    } -> tensor<2xi32>
    %17 = tensor.empty() : tensor<2x6x1xi32>
    %18 = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%16 : tensor<2xi32>) outs(%17 : tensor<2x6x1xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
    } -> tensor<2x6x1xi32>

which then gets formed into these dispatches:

  %8 = flow.dispatch.workgroups() : () -> tensor<2x6xi32> =
      (%arg1: !flow.dispatch.tensor<writeonly:tensor<2x6xi32>>) {
    %24 = tensor.empty() : tensor<2x6xi32>
    %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%24 : tensor<2x6xi32>) {
    ^bb0(%out: i32):
      %26 = linalg.index 0 : index
      %27 = arith.index_cast %26 : index to i32
      linalg.yield %27 : i32
    } -> tensor<2x6xi32>
    flow.dispatch.tensor.store %25, %arg1, offsets = [0, 0], sizes = [2, 6], strides = [1, 1] : tensor<2x6xi32> -> !flow.dispatch.tensor<writeonly:tensor<2x6xi32>>
    flow.return
  } count() -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
    flow.return %x, %y, %z : index, index, index
  }
  %9 = flow.tensor.empty : tensor<2x6x2xi32>
  %10 = flow.dispatch.workgroups(%8, %9) : (tensor<2x6xi32>, tensor<2x6x2xi32>) -> %9 =
      (%arg1: !flow.dispatch.tensor<readonly:tensor<2x6xi32>>, %arg2: !flow.dispatch.tensor<readwrite:tensor<2x6x2xi32>>) {
    %24 = flow.dispatch.tensor.load %arg1, offsets = [0, 0], sizes = [2, 6], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x6xi32>> -> tensor<2x6xi32>
    flow.dispatch.tensor.store %24, %arg2, offsets = [0, 0, 0], sizes = [2, 6, 1], strides = [1, 1, 1] : tensor<2x6xi32> -> !flow.dispatch.tensor<readwrite:tensor<2x6x2xi32>>
    flow.return
  } count() -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
    flow.return %x, %y, %z : index, index, index
  }

The iota should definitely end up fused with the broadcast, and should probably even end up in the subsequent sort/consumer: in most cases an iota is something we can derive from workgroup/distribution indices rather than something we need to materialize in memory.
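
For illustration only (a hand-written sketch, not compiler output from this issue): the iota + broadcast pair above is equivalent to a single linalg.generic whose value is derived from the iteration index via linalg.index, so nothing ever needs to be read from a buffer:

    #map = affine_map<(d0, d1) -> (d0, d1)>
    func.func @fused_iota_broadcast() -> tensor<2x6xi32> {
      %empty = tensor.empty() : tensor<2x6xi32>
      // iota_dimension 0 of the 6-element iota maps to d1 of the broadcast
      // result, so each element value is just the column index.
      %fused = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]}
          outs(%empty : tensor<2x6xi32>) {
      ^bb0(%out: i32):
        %col = linalg.index 1 : index
        %val = arith.index_cast %col : index to i32
        linalg.yield %val : i32
      } -> tensor<2x6xi32>
      return %fused : tensor<2x6xi32>
    }

The same reasoning extends to a consumer like a sort: the index computation can live inside the consumer's region (or be recovered from the workgroup/distribution indices) instead of being stored and reloaded.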

Full reproducer with two such iotas in #13729. Most LLMs do this as well; #13729, #13637, and #13648 all share this pattern to varying degrees.

@benvanik added the compiler/dialects and performance ⚡ labels on May 23, 2023
@allieculp
@benvanik Are you working on this, or does it need an owner?

@benvanik (Collaborator, Author)

Needs an owner.

@allieculp
@jpienaar @julianwa @mattwalsh Can we find an owner for this? Or drop to P2?

@benvanik (Collaborator, Author)

This is relevant to LLM memory usage and important for that effort. In some models it adds up to several hundred MB of cumulative iota allocations (from ops like stablehlo.iota dim = 2 : tensor<4x2048x51200xi32> used in argmax), and it scales with batch size.
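
(Rough back-of-envelope arithmetic, not a figure from the issue, assuming i32 elements and that the leading dimension is the batch: 2048 × 51200 × 4 B ≈ 420 MB per batch element, so that single iota is roughly 1.7 GB when fully materialized at batch size 4.)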

@MaheshRavishankar (Contributor)

Strange. This should work out of the box.

@MaheshRavishankar (Contributor)

So

#map2 = affine_map<(d0) -> (d0)>
#map5 = affine_map<(d0, d1, d2) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @test() -> tensor<2x6x1xi32> {
  %15 = tensor.empty() : tensor<2xi32>
  %16 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel"]} outs(%15 : tensor<2xi32>) {
  ^bb0(%out: i32):
    %28 = linalg.index 0 : index
    %29 = arith.index_cast %28 : index to i32
    linalg.yield %29 : i32
  } -> tensor<2xi32>
  %17 = tensor.empty() : tensor<2x6x1xi32>
  %18 = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%16 : tensor<2xi32>) outs(%17 : tensor<2x6x1xi32>) {
  ^bb0(%in: i32, %out: i32):
    linalg.yield %in : i32
  } -> tensor<2x6x1xi32>
  return %18 : tensor<2x6x1xi32>
}

by itself gets fused:

iree-opt --pass-pipeline="builtin.module(func.func(iree-flow-fusion-of-tensor-ops))" repro.mlir 
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @test() -> tensor<2x6x1xi32> {
    %0 = tensor.empty() : tensor<2x6x1xi32>
    %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%0 : tensor<2x6x1xi32>) {
    ^bb0(%out: i32):
      %2 = linalg.index 0 : index
      %3 = arith.index_cast %2 : index to i32
      linalg.yield %3 : i32
    } -> tensor<2x6x1xi32>
    return %1 : tensor<2x6x1xi32>
  }
}

So this might be an issue of multiple uses. #13747 will probably help (I do get a lot better IR with that); I still have to resolve issues with landing that PR.

@benvanik (Collaborator, Author) commented Jun 1, 2023

Ah yes, multiple uses would cause this for sure - in a model with multiple topks I bet the iota is getting CSEd.
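
A hypothetical shape of the problematic IR (hand-written to illustrate the multi-use case, not extracted from the model): one CSEd iota generic feeding two broadcast consumers, so a fusion heuristic that refuses to duplicate multi-use producers leaves the iota materialized:

    #map0 = affine_map<(d0) -> (d0)>
    #map1 = affine_map<(d0, d1) -> (d1)>
    #map2 = affine_map<(d0, d1) -> (d0, d1)>
    func.func @cse_iota_two_uses() -> (tensor<2x6xi32>, tensor<4x6xi32>) {
      %e = tensor.empty() : tensor<6xi32>
      // Single iota, materialized once after CSE...
      %iota = linalg.generic {indexing_maps = [#map0], iterator_types = ["parallel"]}
          outs(%e : tensor<6xi32>) {
      ^bb0(%out: i32):
        %i = linalg.index 0 : index
        %v = arith.index_cast %i : index to i32
        linalg.yield %v : i32
      } -> tensor<6xi32>
      // ...but consumed by two different broadcasts.
      %e0 = tensor.empty() : tensor<2x6xi32>
      %b0 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]}
          ins(%iota : tensor<6xi32>) outs(%e0 : tensor<2x6xi32>) {
      ^bb0(%in: i32, %out: i32):
        linalg.yield %in : i32
      } -> tensor<2x6xi32>
      %e1 = tensor.empty() : tensor<4x6xi32>
      %b1 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]}
          ins(%iota : tensor<6xi32>) outs(%e1 : tensor<4x6xi32>) {
      ^bb0(%in: i32, %out: i32):
        linalg.yield %in : i32
      } -> tensor<4x6xi32>
      return %b0, %b1 : tensor<2x6xi32>, tensor<4x6xi32>
    }

Fusing the iota into each consumer (which #14070 enables by treating it as copy-like) is cheap since the body reads nothing from memory.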

MaheshRavishankar pushed a commit to MaheshRavishankar/iree that referenced this issue Jun 12, 2023
Current fusion heuristics always fuse copy-like ops with their consumers. Iota ops are also copy-like ops (indeed, if the deprecated
`linalg.indexed_generic` were still around it would be a copy-like
op).

Fixes iree-org#13745
MaheshRavishankar added a commit that referenced this issue Jun 13, 2023
Current fusion heuristics always fuse copy-like ops with their consumers. Iota ops are also copy-like ops (indeed, if the deprecated linalg.indexed_generic were still around it would be a copy-like op).

Fixes #13745
nhasabni pushed a commit to plaidml/iree that referenced this issue Aug 24, 2023
Current fusion heuristics always fuse copy-like ops with their consumers. Iota ops are also copy-like ops (indeed, if the deprecated linalg.indexed_generic were still around it would be a copy-like op).

Fixes iree-org#13745