
ScheduleAllocation crashes when compiling included sparse program #13729

Closed
rsuderman opened this issue May 22, 2023 · 22 comments · Fixed by #13748
Labels
integrations/stablehlo StableHLO (JAX/TensorFlow/etc) import and conversion

Comments

@rsuderman
Contributor

Segmentation fault inside of ScheduleAllocation (iree-stream-schedule-allocation) when compiling the following sparse program:

https://gist.github.com/rsuderman/28ad59022e46137903746557e6d05dd4

@rsuderman rsuderman added the integrations/stablehlo StableHLO (JAX/TensorFlow/etc) import and conversion label May 22, 2023
@benvanik
Collaborator

At head I get this StableHLO error on parsing - Do you know how/why the gather op changed so I can fix this input or have a copy that matches the new format?

D:\Dev\iree/../iree-tmp/13729.mlir:16:11: error: failed to legalize operation 'stablehlo.gather' that was explicitly marked illegal
    %13 = "stablehlo.gather"(%arg0, %12) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>} : (tensor<2x6x2xi32>, tensor<2x6x2xi32>) -> tensor<2x6x2xi32>
          ^
D:\Dev\iree/../iree-tmp/13729.mlir:16:11: note: see current operation: %83 = "stablehlo.gather"(%arg0, %82) {dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>} : (tensor<2x6x2xi32>, tensor<2x6x2xi32>) -> tensor<2x6x2xi32>

@benvanik
Collaborator

I'm not sure how this ever worked? Any ideas of recent regressions @kuhar? The gather is not getting converted to torch_index_select because of this condition:
https://github.com/openxla/iree/blob/main/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/GatherToTorchIndexSelect.cpp#L50-L55

I don't see any other GatherOp handling during conversion.
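For context, torch_index_select can only gather whole slices along a single axis, so the rewrite presumably has to reject gathers that index along multiple axes; the repro's gather uses start_index_map = [0, 1]. Here is a toy Python model of such a guard (an assumption about the linked C++ condition, not IREE's actual code):

```python
# Illustrative model (NOT IREE's actual C++ in GatherToTorchIndexSelect.cpp):
# a gather can become a torch_index_select only when it selects along one
# input axis. Dimension numbers mirror #stablehlo.gather<...> from the repro.

def can_lower_to_index_select(start_index_map, collapsed_slice_dims):
    # index_select picks whole slices along a single axis, so the gather must
    # map exactly one start index to that axis and collapse only that axis.
    return (len(start_index_map) == 1
            and list(collapsed_slice_dims) == list(start_index_map))

# The gather from the failing repro: start_index_map = [0, 1].
repro = dict(start_index_map=[0, 1], collapsed_slice_dims=[0, 1])
assert not can_lower_to_index_select(**repro)  # left as stablehlo.gather

# A gather that does fit the pattern (e.g. an embedding lookup along axis 0).
simple = dict(start_index_map=[0], collapsed_slice_dims=[0])
assert can_lower_to_index_select(**simple)
```

Under this model the repro's gather falls through the preprocessing pattern, which would explain why no other GatherOp handling picks it up during conversion.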

@kuhar
Member

kuhar commented May 23, 2023

Taking a look. @rsuderman, can you share the compilation command?

@benvanik
Collaborator

I was trying with --iree-input-type=stablehlo (which usually works for me). I saw this with another StableHLO model the other day too that others had apparently been successfully importing but I could not. Maybe this is a difference between OSS/google3? (I'm guessing @rsuderman / @phoenix-meadowlark are in google3?)

@benvanik
Collaborator

Yeah, the repro from @silvasean in #13637 (ir19 from #13543) also has this issue.

@stellaraccident
Collaborator

@burmako for tracking: another instance of version skew affecting development.

@burmako
Contributor

burmako commented May 23, 2023

@GleasonK ^

@kuhar
Member

kuhar commented May 23, 2023

@benvanik This input happens to exercise hlo canonicalizations pretty well. I found a few that are missing. Will check gather lowering once I have those in and the intermediate IR is more alike.

@benvanik
Collaborator

Cool! Any idea why everyone but me seems to be able to import these?

@kuhar
Member

kuhar commented May 23, 2023

Cool! Any idea why everyone but me seems to be able to import these?

I saw the same crash that @rsuderman reported when I set --iree-input-type=mhlo. My build was relwithdebinfo + assertions on linux.

@benvanik
Collaborator

oh, wait, should I not be using --iree-input-type=stablehlo for stablehlo?

@kuhar
Member

kuhar commented May 23, 2023

oh, wait, should I not be using --iree-input-type=stablehlo for stablehlo?

Right now you can use either, but =stablehlo is less mature. When you set =mhlo, it will first convert stablehlo to mhlo and then continue with its ingestion pipeline.

Also note that =mhlo is on track to go away in O(~weeks): https://groups.google.com/g/iree-discuss/c/s6dBpDtWhtk

@benvanik
Collaborator

Ok, my bad - I didn't realize mhlo should be used for stablehlo. Guess that reinforces the request for people to post the flags they're using to compile when they post a reproducer ;)

@benvanik
Copy link
Collaborator

benvanik commented May 23, 2023

Root cause is that a dispatch region (main_dispatch_3_sort_2x6xi32) is returning an unused result and (if I'm reading correctly) is incorrectly formed. Full dump here: https://gist.github.com/benvanik/aed334005561082f106fbf7287c70e56

IR before FormDispatchWorkgroups:

func.func @main(%arg0: !hal.buffer_view {mhlo.sharding = "{replicated}"}) -> (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
  %c0_i32 = arith.constant 0 : i32
  %c6_i32 = arith.constant 6 : i32
  %true = arith.constant true
  %cst = arith.constant dense<1> : tensor<1x1xi32>
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %false = arith.constant false
  %c1 = arith.constant 1 : index
  %0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<2x6x2xi32>
  %1 = tensor.empty() : tensor<2x2x6xi32>
  %2 = flow.dispatch.region -> (tensor<2x2x6xi32>) {
    %21 = tensor.empty() : tensor<2x2x6xi32>
    %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<2x6x2xi32>) outs(%21 : tensor<2x2x6xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
    } -> tensor<2x2x6xi32>
    flow.return %22 : tensor<2x2x6xi32>
  }
  %3 = flow.dispatch.region -> (tensor<2x2x6xi32>) {
    %21 = tensor.empty() : tensor<2x2x6xi32>
    %22 = iree_linalg_ext.reverse dimensions(dense<1> : tensor<1xi64>) ins(%2 : tensor<2x2x6xi32>) outs(%21 : tensor<2x2x6xi32>) : tensor<2x2x6xi32>
    flow.return %22 : tensor<2x2x6xi32>
  }
  %extracted_slice = tensor.extract_slice %3[0, 0, 0] [2, 1, 6] [1, 1, 1] : tensor<2x2x6xi32> to tensor<2x6xi32>
  %extracted_slice_0 = tensor.extract_slice %3[0, 1, 0] [2, 1, 6] [1, 1, 1] : tensor<2x2x6xi32> to tensor<2x6xi32>
  %4 = tensor.empty() : tensor<2x6xi32>
  %5 = flow.dispatch.region -> (tensor<2x6xi32>) {
    %21 = tensor.empty() : tensor<2x6xi32>
    %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%21 : tensor<2x6xi32>) {
    ^bb0(%out: i32):
      %23 = linalg.index 1 : index
      %24 = arith.index_cast %23 : index to i32
      linalg.yield %24 : i32
    } -> tensor<2x6xi32>
    flow.return %22 : tensor<2x6xi32>
  }
  %6:3 = flow.dispatch.region -> (tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>) {
    %extracted_slice_9 = tensor.extract_slice %3[0, 0, 0] [2, 1, 6] [1, 1, 1] : tensor<2x2x6xi32> to tensor<2x6xi32>
    %extracted_slice_10 = tensor.extract_slice %3[0, 1, 0] [2, 1, 6] [1, 1, 1] : tensor<2x2x6xi32> to tensor<2x6xi32>
    %21:3 = iree_linalg_ext.sort dimension(1) outs(%extracted_slice_10, %extracted_slice_9, %5 : tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>) {
    ^bb0(%arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32):
      %22 = arith.cmpi slt, %arg3, %arg4 : i32
      %23 = arith.cmpi slt, %arg1, %arg2 : i32
      %24 = arith.cmpi eq, %arg1, %arg2 : i32
      %25 = arith.andi %24, %22 : i1
      %26 = arith.ori %23, %25 : i1
      iree_linalg_ext.yield %26 : i1
    } -> tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>
    flow.return %21#0, %21#1, %21#2 : tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>
  }
  %7 = tensor.empty() : tensor<2x6xi1>
  %collapsed = tensor.collapse_shape %6#2 [[0, 1]] : tensor<2x6xi32> into tensor<12xi32>
  %8 = tensor.empty() : tensor<12xi32>
  %9 = flow.dispatch.region -> (tensor<12xi32>) {
    %21 = tensor.empty() : tensor<12xi32>
    %c0_i32_9 = arith.constant 0 : i32
    %c6_i32_10 = arith.constant 6 : i32
    %22 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%collapsed : tensor<12xi32>) outs(%21 : tensor<12xi32>) {
    ^bb0(%in: i32, %out: i32):
      %23 = arith.addi %in, %c6_i32_10 : i32
      %24 = arith.cmpi slt, %in, %c0_i32_9 : i32
      %25 = arith.select %24, %23, %in : i32
      linalg.yield %25 : i32
    } -> tensor<12xi32>
    flow.return %22 : tensor<12xi32>
  }
  %expanded = tensor.expand_shape %9 [[0, 1]] : tensor<12xi32> into tensor<2x6xi32>
  %10 = flow.dispatch.region -> (tensor<2x6xi32>) {
    %21 = tensor.empty() : tensor<2x6xi32>
    %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%21 : tensor<2x6xi32>) {
    ^bb0(%out: i32):
      %23 = linalg.index 0 : index
      %24 = arith.index_cast %23 : index to i32
      linalg.yield %24 : i32
    } -> tensor<2x6xi32>
    flow.return %22 : tensor<2x6xi32>
  }
  %11 = tensor.empty() : tensor<2x6x2xi32>
  %inserted_slice = tensor.insert_slice %10 into %11[0, 0, 0] [2, 6, 1] [1, 1, 1] : tensor<2x6xi32> into tensor<2x6x2xi32>
  %inserted_slice_1 = tensor.insert_slice %expanded into %inserted_slice[0, 0, 1] [2, 6, 1] [1, 1, 1] : tensor<2x6xi32> into tensor<2x6x2xi32>
  %12 = flow.dispatch.region -> (tensor<2x6x2xi32>) {
    %21 = tensor.empty() : tensor<2x6x2xi32>
    %c5_9 = arith.constant 5 : index
    %c1_10 = arith.constant 1 : index
    %c0_11 = arith.constant 0 : index
    %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%21 : tensor<2x6x2xi32>) {
    ^bb0(%out: i32):
      %23 = linalg.index 0 : index
      %24 = linalg.index 1 : index
      %25 = linalg.index 2 : index
      %extracted = tensor.extract %inserted_slice_1[%23, %24, %c0_11] : tensor<2x6x2xi32>
      %26 = arith.index_cast %extracted : i32 to index
      %extracted_12 = tensor.extract %inserted_slice_1[%23, %24, %c1_10] : tensor<2x6x2xi32>
      %27 = arith.index_cast %extracted_12 : i32 to index
      %28 = arith.maxsi %26, %c0_11 : index
      %29 = arith.minsi %28, %c1_10 : index
      %30 = arith.maxsi %27, %c0_11 : index
      %31 = arith.minsi %30, %c5_9 : index
      %extracted_13 = tensor.extract %0[%29, %31, %25] : tensor<2x6x2xi32>
      linalg.yield %extracted_13 : i32
    } -> tensor<2x6x2xi32>
    flow.return %22 : tensor<2x6x2xi32>
  }
  %extracted_slice_2 = tensor.extract_slice %12[0, 1, 0] [2, 5, 2] [1, 1, 1] : tensor<2x6x2xi32> to tensor<2x5x2xi32>
  %extracted_slice_3 = tensor.extract_slice %12[0, 0, 0] [2, 5, 2] [1, 1, 1] : tensor<2x6x2xi32> to tensor<2x5x2xi32>
  %13 = tensor.empty() : tensor<2x5xi1>
  %14 = linalg.fill ins(%false : i1) outs(%13 : tensor<2x5xi1>) -> tensor<2x5xi1>
  %collapsed_4 = tensor.collapse_shape %extracted_slice_2 [[0, 1], [2]] : tensor<2x5x2xi32> into tensor<10x2xi32>
  %collapsed_5 = tensor.collapse_shape %extracted_slice_3 [[0, 1], [2]] : tensor<2x5x2xi32> into tensor<10x2xi32>
  %collapsed_6 = tensor.collapse_shape %14 [[0, 1]] : tensor<2x5xi1> into tensor<10xi1>
  %15 = flow.dispatch.region -> (tensor<10xi1>) {
    %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed_4, %collapsed_5 : tensor<10x2xi32>, tensor<10x2xi32>) outs(%collapsed_6 : tensor<10xi1>) {
    ^bb0(%in: i32, %in_9: i32, %out: i1):
      %22 = arith.cmpi ne, %in, %in_9 : i32
      %23 = arith.ori %out, %22 : i1
      linalg.yield %23 : i1
    } -> tensor<10xi1>
    flow.return %21 : tensor<10xi1>
  }
  %expanded_7 = tensor.expand_shape %15 [[0, 1]] : tensor<10xi1> into tensor<2x5xi1>
  %expanded_8 = tensor.expand_shape %expanded_7 [[0, 1], [2]] : tensor<2x5xi1> into tensor<1x2x5xi1>
  %16 = linalg.fill ins(%true : i1) outs(%7 : tensor<2x6xi1>) -> tensor<2x6xi1>
  %17 = flow.dispatch.region -> (tensor<2x6xi1>) {
    %21 = tensor.empty() : tensor<2x6xi1>
    %true_9 = arith.constant true
    %cst_10 = arith.constant dense<1> : tensor<1x1xi32>
    %22 = linalg.fill ins(%true_9 : i1) outs(%21 : tensor<2x6xi1>) -> tensor<2x6xi1>
    %23 = iree_linalg_ext.scatter dimension_map = [1] unique_indices(true) ins(%expanded_8, %cst_10 : tensor<1x2x5xi1>, tensor<1x1xi32>) outs(%22 : tensor<2x6xi1>) {
    ^bb0(%arg1: i1, %arg2: i1):
      iree_linalg_ext.yield %arg1 : i1
    } -> tensor<2x6xi1>
    flow.return %23 : tensor<2x6xi1>
  }
  %18 = hal.tensor.export %12 "output 0" : tensor<2x6x2xi32> -> !hal.buffer_view
  %19 = hal.tensor.export %17 "output 1" : tensor<2x6xi1> -> !hal.buffer_view
  %20 = hal.tensor.export %6#2 "output 2" : tensor<2x6xi32> -> !hal.buffer_view
  return %18, %19, %20 : !hal.buffer_view, !hal.buffer_view, !hal.buffer_view
}

Here's the sort:

  %6:3 = flow.dispatch.region -> (tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>) {
    %extracted_slice_9 = tensor.extract_slice %3[0, 0, 0] [2, 1, 6] [1, 1, 1] : tensor<2x2x6xi32> to tensor<2x6xi32>
    %extracted_slice_10 = tensor.extract_slice %3[0, 1, 0] [2, 1, 6] [1, 1, 1] : tensor<2x2x6xi32> to tensor<2x6xi32>
    %21:3 = iree_linalg_ext.sort dimension(1) outs(%extracted_slice_10, %extracted_slice_9, %5 : tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>) {
    ^bb0(%arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32):
      %22 = arith.cmpi slt, %arg3, %arg4 : i32
      %23 = arith.cmpi slt, %arg1, %arg2 : i32
      %24 = arith.cmpi eq, %arg1, %arg2 : i32
      %25 = arith.andi %24, %22 : i1
      %26 = arith.ori %23, %25 : i1
      iree_linalg_ext.yield %26 : i1
    } -> tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>
    flow.return %21#0, %21#1, %21#2 : tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>
  }

Of the 3 sort results, only #2 is used; #0 and #1 are unused and should not be returned here.

And this is the formed region:

  %4:2 = flow.dispatch.workgroups(%2, %3) : (tensor<2x2x6xi32>, tensor<2x6xi32>) -> (tensor<2x6xi32>, %3) =
      (%arg1: !flow.dispatch.tensor<readonly:tensor<2x2x6xi32>>, %arg2: !flow.dispatch.tensor<readwrite:tensor<2x6xi32>>, %arg3: !flow.dispatch.tensor<writeonly:tensor<2x6xi32>>) {
    %25 = flow.dispatch.tensor.load %arg2, offsets = [0, 0], sizes = [2, 6], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<2x6xi32>> -> tensor<2x6xi32>
    %26 = flow.dispatch.tensor.load %arg1, offsets = [0, 0, 0], sizes = [2, 1, 6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2x6xi32>> -> tensor<2x6xi32>
    %27 = flow.dispatch.tensor.load %arg1, offsets = [0, 1, 0], sizes = [2, 1, 6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x2x6xi32>> -> tensor<2x6xi32>
    %28:3 = iree_linalg_ext.sort dimension(1) outs(%27, %26, %25 : tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>) {
    ^bb0(%arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32):
      %29 = arith.cmpi slt, %arg6, %arg7 : i32
      %30 = arith.cmpi slt, %arg4, %arg5 : i32
      %31 = arith.cmpi eq, %arg4, %arg5 : i32
      %32 = arith.andi %31, %29 : i1
      %33 = arith.ori %30, %32 : i1
      iree_linalg_ext.yield %33 : i1
    } -> tensor<2x6xi32>, tensor<2x6xi32>, tensor<2x6xi32>
    flow.dispatch.tensor.store %28#0, %arg3, offsets = [0, 0], sizes = [2, 6], strides = [1, 1] : tensor<2x6xi32> -> !flow.dispatch.tensor<writeonly:tensor<2x6xi32>>
    flow.dispatch.tensor.store %28#2, %arg2, offsets = [0, 0], sizes = [2, 6], strides = [1, 1] : tensor<2x6xi32> -> !flow.dispatch.tensor<readwrite:tensor<2x6xi32>>
    flow.return
  } count() -> (index, index, index) {
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
    flow.return %x, %y, %z : index, index, index
  }

%4#0 is never used, so it's weird that it's escaping while the other unused result (%28#1) is dropped correctly.

The weirder thing, though, is the signature:
(%2, %3) : (tensor<2x2x6xi32>, tensor<2x6xi32>) -> (tensor<2x6xi32>, %3)

(nevermind, remembering why I hate the syntax of this op, it may be ok but wasteful - looking more)
(yeah signature is fine - I've still got a TODO to fix the flow.dispatch.workgroups syntax to match what stream.async.execute does)

@benvanik
Collaborator

Ok so quick fix here is to figure out why FormDispatchWorkgroups is returning the unused write-only result instead of dropping it.
I'll see if I can make ScheduleAllocations allocate outputs for unused write-only results but in a well-formed program we should never end up with those.

@benvanik
Collaborator

Ah yes ScheduleAllocations does its allocation based on liveness and these unused values nested in concurrent regions don't count as live. May require some non-trivial rejiggering.

@MaheshRavishankar / @hanhanW any ideas as to why one unused result of a linalg sort is dropped during dispatch region formation but the other unused result isn't?
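To make the failure mode concrete, here is a toy Python model of the liveness-driven allocation described above (hypothetical names; the real pass is C++ in iree-stream-schedule-allocation). An unused result escaping from a nested concurrency region never gets a buffer, so the later binding lookup blows up; the workaround is to allocate even unused write-only results:

```python
# Toy model (NOT the actual ScheduleAllocation pass): allocation only covers
# results that are live (have users); an unused escaping result gets nothing.

def allocate(results, allocate_unused_writeonly=False):
    """Map result name -> buffer for every result that gets memory."""
    buffers = {}
    for name, info in results.items():
        if info["has_users"]:
            buffers[name] = f"buffer_for_{name}"
        elif allocate_unused_writeonly and info["access"] == "writeonly":
            # Workaround: still allocate (wasteful, effectively /dev/null)
            # so downstream binding lookups never miss.
            buffers[name] = f"unused_buffer_for_{name}"
    return buffers

def bind(results, buffers):
    # Every escaping result needs a buffer, used or not; a missing entry
    # (KeyError) models the crash on the unused sort result.
    return [buffers[name] for name in results]

results = {
    "r0": {"has_users": False, "access": "writeonly"},  # unused sort result
    "r2": {"has_users": True, "access": "readwrite"},
}

try:
    bind(results, allocate(results))          # crashes: r0 has no buffer
except KeyError:
    pass
bind(results, allocate(results, allocate_unused_writeonly=True))  # ok
```

In this model the "non-trivial rejiggering" amounts to treating unused write-only results as needing placeholder allocations even though liveness says they're dead.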

@benvanik
Collaborator

I suspect one place this could be fixed is here:
https://github.com/openxla/iree/blob/6a46afd4f82715979b54c57fb41a8505466cd68f/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp#L388-L397

That's saying if there's any use of any result then add all results instead of just the results with escaping uses. I'm not sure if that's load-bearing or not - doesn't feel like it should be.
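The all-or-nothing behavior described above can be sketched as follows (illustrative Python, not the actual RegionOpUtils C++; function names are made up):

```python
# Model of return-value selection during dispatch region formation.
# The cited code keeps ALL results if ANY result is used outside the region.

def escaping_results_all_or_nothing(used_outside):
    if any(used_outside):
        return list(range(len(used_outside)))  # keep everything
    return []

def escaping_results_filtered(used_outside):
    # Per-result filtering: only return results with external uses.
    return [i for i, used in enumerate(used_outside) if used]

# The 3-result sort from the repro: only result #2 is used outside.
uses = [False, False, True]
assert escaping_results_all_or_nothing(uses) == [0, 1, 2]  # unused escapes
assert escaping_results_filtered(uses) == [2]              # what we'd want
```

Under the filtered variant the unused write-only results would never escape the dispatch in the first place, which is the direction #13711 moves in.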

@MaheshRavishankar
Contributor

MaheshRavishankar commented May 23, 2023

Ok, as part of #13711 (here) I did change that to only return the results that are used...

That can have some issues, though. It might fix the issue here, but sometimes the operation semantics require it to have multiple results (there are some argmax-kind ops that have this behavior). You need to have multiple results even if not all of them are used (because they are effectively induction variables that take an init and, for consistency, need to return a result).
So this does happen. There are two options here

  1. There is an allocation for it that is done outside the dispatch, but that is wasteful
  2. If there is no allocation for it outside the dispatch, the op semantics will still have multiple results... in the backend this would result in a stack allocation. For these ops, we have been able to keep the stack allocation within the limit...

For a first step, I can split out the NFC parts of #13711 and send that out separately. That might unblock things here.

@MaheshRavishankar
Contributor

Ah yes ScheduleAllocations does its allocation based on liveness and these unused values nested in concurrent regions don't count as live. May require some non-trivial rejiggering.

@MaheshRavishankar / @hanhanW any ideas as to why one unused result of a linalg sort is dropped during dispatch region formation but the other unused result isn't?

No, I don't know off the top of my head. I need to look further.

@benvanik
Collaborator

Gotcha - sounds great! I'll keep trying to make this work but that'd be a nice improvement overall.

It's ok to have unused read/write results, but unused write-only results should never be needed: they're declared as never readable by anything and are effectively /dev/null - the dispatch isn't correct if it tries to read from a write-only binding. Conceptually you can think of a read-only binding as a stream input, a write-only binding as a stream output, and read/write as random access.

For argmax-like things we should be able to do much better than a globally-sized heap allocation (eventually) - as we're regularly out-of-memory'ing on LLM's I'd err on the side of assuming we'll have to some day eliminate all such cases of globally-sized heap allocations that can be done any other way :)

@MaheshRavishankar
Contributor

For argmax-like things we should be able to do much better than a globally-sized heap allocation (eventually) - as we're regularly out-of-memory'ing on LLM's I'd err on the side of assuming we'll have to some day eliminate all such cases of globally-sized heap allocations that can be done any other way :)

Argmax ops are handled today without globally-sized heap allocations... they are handled through bounded stack allocations.

benvanik added a commit that referenced this issue May 23, 2023
The stream allocation pass is currently just looking at the top level
of each execution region and values nested under concurrency regions
with no users don't get allocated any memory. Since there are cases where
dispatch region formation can produce unused write-only results we need
to handle these even if suboptimal (wasting an allocation for /dev/null).

Fixes #13729.
@benvanik
Collaborator

I've got a workaround in #13748 that should unblock this by allocating memory for the unused results. @MaheshRavishankar's fix for not returning unused results will be a good memory reduction, though! 🥳

benvanik added a commit that referenced this issue May 23, 2023
The stream allocation pass is currently just looking at the top level of
each execution region and values nested under concurrency regions with
no users don't get allocated any memory. Since there are cases where
dispatch region formation can produce unused write-only results we need
to handle these even if suboptimal (wasting an allocation for
/dev/null).

Fixes #13729.
kuhar added a commit to kuhar/iree that referenced this issue May 25, 2023
This pattern fell through the cracks during the initial porting of
hlo-to-linalg lowering in iree-org#12957.

With this pattern and the most recent canon patterns, we produce the
same code as the mhlo input conversion pipeline on the input from
iree-org#13729.

Issue: iree-org#12678
kuhar added a commit that referenced this issue May 25, 2023
This pattern fell through the cracks during the initial porting of
hlo-to-linalg lowering in #12957.

With this pattern and the most recent canon patterns, we produce the
same code as the mhlo input conversion pipeline on the input from
#13729.

Also fixed issues with undefined FileCheck variables in tests.

Issue: #12678
NatashaKnk pushed a commit to NatashaKnk/iree that referenced this issue Jul 6, 2023
…rg#13748)

The stream allocation pass is currently just looking at the top level of
each execution region and values nested under concurrency regions with
no users don't get allocated any memory. Since there are cases where
dispatch region formation can produce unused write-only results we need
to handle these even if suboptimal (wasting an allocation for
/dev/null).

Fixes iree-org#13729.
NatashaKnk pushed a commit to NatashaKnk/iree that referenced this issue Jul 6, 2023
This pattern fell through the cracks during the initial porting of
hlo-to-linalg lowering in iree-org#12957.

With this pattern and the most recent canon patterns, we produce the
same code as the mhlo input conversion pipeline on the input from
iree-org#13729.

Also fixed issues with undefined FileCheck variables in tests.

Issue: iree-org#12678
nhasabni pushed a commit to plaidml/iree that referenced this issue Aug 24, 2023
…rg#13748)

The stream allocation pass is currently just looking at the top level of
each execution region and values nested under concurrency regions with
no users don't get allocated any memory. Since there are cases where
dispatch region formation can produce unused write-only results we need
to handle these even if suboptimal (wasting an allocation for
/dev/null).

Fixes iree-org#13729.
nhasabni pushed a commit to plaidml/iree that referenced this issue Aug 24, 2023
This pattern fell through the cracks during the initial porting of
hlo-to-linalg lowering in iree-org#12957.

With this pattern and the most recent canon patterns, we produce the
same code as the mhlo input conversion pipeline on the input from
iree-org#13729.

Also fixed issues with undefined FileCheck variables in tests.

Issue: iree-org#12678
6 participants