iree_linalg_ext.attention dispatch fails to bufferize after TileAndDecomposeAttention #16421
Comments
The attached IR after TileAndDecomposeAttention shows an address being reused many times in the innermost decomposed block of computation, which might be triggering a failure in OneShotAnalysis:
|
The attention decomposition is failing during bufferization, specifically in `--iree-eliminate-empty-tensors-pass`. This seems to be an upstream bug, but there are clearly no end-to-end tests for the attention operation. So disabling the decomposition for now till the issue is fixed. Eventually this pass is probably needed only for backends that don't support attention. Towards iree-org#16421.
Turns out the attention decomposition is failing during bufferization, and we probably don't have any end-to-end tests exercising this (AFAIK all attention-op-related work has been done only using the transform dialect, with correctness checks out of tree). So this is probably failing silently. |
I am attaching the current decomposition (repro.mlir) and what I manually fixed (fixed.mlir). The latter fixes the failure in the eliminate-empty-tensors pass. The current decomposition is doing something funky with destinations and probably needs a look. It could also probably use some elementwise op fusion to make the IR more manageable. All in all this needs a new look. Tagging @harsh-nod, do you think we can iterate on this to get to a better final place? |
@MaheshRavishankar - definitely. I will be out next week, but @erman-gurses can help pick this up in the meantime. Some notes on the problem:
```mlir
%24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
  %33 = arith.maximumf %in, %out : f32
  linalg.yield %33 : f32
} -> tensor<64xf32>
```
where we use the old max to compute the new max (%24). This is returned at the end of the loop.
The problem comes when we use this again here:
```mlir
%26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 : tensor<64xf32>) {
^bb0(%in: f32, %out: f32):
  %33 = arith.subf %out, %in : f32
  %34 = math.exp2 %33 : f32
  linalg.yield %34 : f32
} -> tensor<64xf32>
```
The idea behind this second usage was the following: Another note is that the transform dialect bufferization operator seems to bufferize the original graph. So it might be worthwhile to look into the differences between the TD bufferization op and what's in the pass pipeline. Specifically, here https://github.com/openxla/iree/blob/6560f8600c6e6a79eaed3ca3b26af2f8e2b900ae/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp#L760 where some new passes are being used, such as EmptyTensorLoweringPattern etc. |
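The pattern above reads the old max from %arg4 while another op writes the new max into the same destination, which is what trips the aliasing analysis. A minimal NumPy sketch of the online-softmax step (my own illustration, not IREE's implementation, and using `exp` rather than the IR's `exp2` form) shows why the old and new max need distinct storage:

```python
import numpy as np

def online_softmax_step(s_block, m_old, d_old, acc):
    """One online-softmax (flash-attention-style) update step.

    s_block: (rows, cols) scores for the current K block
    m_old:   (rows,) running max from previous iterations
    d_old:   (rows,) running softmax denominator
    acc:     (rows, head_dim) running weighted-value accumulator
    """
    # The new running max must be computed out of place: the rescale
    # factor below still needs m_old after m_new exists. Writing m_new
    # over m_old's storage first (as the decomposed IR does with %arg4)
    # destroys the value the next op reads.
    m_new = np.maximum(m_old, s_block.max(axis=1))
    alpha = np.exp(m_old - m_new)          # rescales previous partial sums
    p = np.exp(s_block - m_new[:, None])   # unnormalized weights for this block
    d_new = alpha * d_old + p.sum(axis=1)
    acc_new = alpha[:, None] * acc         # caller then adds p @ v_block
    return m_new, d_new, acc_new, p
```

Keeping the old and new max as separate SSA values (and, after bufferization, separate buffers) makes the read-before-write ordering visible to one-shot analysis.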
First thing is to look at the output of this command and see if it avoids additional stack allocations
i.e. without trying to manually reuse the buffers, but letting the analysis eliminate them for you. If so, then just changing the TileAndDecomposeAttention pass to not reuse buffers this way would fix the issue. |
@MaheshRavishankar FWIW: I followed the instructions and ran the output of iree-opt (the above command):
|
noting I've included the
|
I just verified that
```mlir
#map = affine_map<()[s0] -> (s0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
#map3 = affine_map<(d0) -> (d0)>
module {
  func.func @main_dispatch_0_attention_20x4096x64xf16() {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant -1.000000e+30 : f32
    %c4096 = arith.constant 4096 : index
    %cst_1 = arith.constant 1.000000e+00 : f32
    %c128 = arith.constant 128 : index
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<128x128xf16>
    %alloca_2 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_3 = memref.alloca() {alignment = 64 : i64} : memref<128x128xf32>
    %alloca_4 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_5 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_6 = memref.alloca() {alignment = 64 : i64} : memref<128x64xf32>
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %0, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %1, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %2, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %3, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %4 = affine.apply #map()[%workgroup_id_x]
    %subview = memref.subview %3[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_7 = memref.subview %0[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_8 = memref.subview %1[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_9 = memref.subview %2[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_10 = memref.subview %subview[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    linalg.fill ins(%cst : f32) outs(%alloca_6 : memref<128x64xf32>)
    linalg.fill ins(%cst_0 : f32) outs(%alloca_4 : memref<128xf32>)
    linalg.fill ins(%cst : f32) outs(%alloca_5 : memref<128xf32>)
    scf.for %arg0 = %c0 to %c4096 step %c128 {
      %subview_11 = memref.subview %subview_8[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.fill ins(%cst : f32) outs(%alloca_3 : memref<128x128xf32>)
      %subview_12 = memref.subview %subview_7[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.matmul_transpose_b ins(%subview_12, %subview_11 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_3 : memref<128x128xf32>)
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_4 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.maximumf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_3 : memref<128x128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      }
      linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_2 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      }
      linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_5 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.mulf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_5 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.addf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca : memref<128x128xf16>) {
      ^bb0(%in: f32, %out: f16):
        %5 = arith.truncf %in : f32 to f16
        linalg.yield %5 : f16
      }
      linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.mulf %in, %out : f32
        linalg.yield %5 : f32
      }
      %subview_13 = memref.subview %subview_9[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.matmul ins(%alloca, %subview_13 : memref<128x128xf16>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_6 : memref<128x64xf32>)
    }
    linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_5 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.divf %cst_1, %in : f32
      %6 = arith.mulf %5, %out : f32
      linalg.yield %6 : f32
    }
    linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_6 : memref<128x64xf32>) outs(%subview_10 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
    ^bb0(%in: f32, %out: f16):
      %5 = arith.truncf %in : f32 to f16
      linalg.yield %5 : f16
    }
    return
  }
}
```
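As a sanity check on the tiled loop above, the same computation can be modeled in NumPy. This is a sketch of my own (names like `q`, `k`, `v`, and `tile` are assumed); like the IR, it uses exp2 and assumes any softmax scaling was folded into the scores beforehand:

```python
import numpy as np

def tiled_attention(q, k, v, tile=128):
    """NumPy model of the scf.for loop in the IR above.

    Assumes the softmax scale (including any log2 factor for exp2)
    was already folded into q, as the decomposition does.
    """
    n, d = q.shape
    acc = np.zeros((n, d), dtype=np.float32)     # %alloca_6: running P @ V
    m = np.full(n, -1e30, dtype=np.float32)      # %alloca_4: running max
    s_sum = np.zeros(n, dtype=np.float32)        # %alloca_5: running denominator
    for j in range(0, k.shape[0], tile):
        kb, vb = k[j:j + tile], v[j:j + tile]
        s = q @ kb.T                             # linalg.matmul_transpose_b
        m_new = np.maximum(m, s.max(axis=1))     # max reduction
        p = np.exp2(s - m_new[:, None])          # exp2 of shifted scores
        alpha = np.exp2(m - m_new)               # rescale old contributions
        s_sum = alpha * s_sum + p.sum(axis=1)
        acc = alpha[:, None] * acc + p @ vb      # linalg.matmul into %alloca_6
        m = m_new
    return acc / s_sum[:, None]                  # final divf/mulf generic
```

Comparing this against a direct (two-pass) softmax over the full score matrix is one way to check the decomposition's numerics.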
|
I have a fix (which generates the IR that @MaheshRavishankar suggested) for attention on CPU side. Let me prepare a PR. |
There are big stack allocation issues, but you can bypass it with
I'm able to compile the op e2e on CPU. To repro, run |
can we make it |
Yes, we can do it with |
@monorimet #16421 (comment) should unblock running the model on CPU backends (using the flag to disable failure on exceeding the stack allocation limit). Can you verify? (Though on GPU, if it isn't within the limit, it won't run.) @erman-gurses after Hanhan's fixes, can you add an e2e test in IREE with the attention op, so that we can catch this issue?
|
Thanks, please add e2e tests to https://github.com/openxla/iree/tree/main/tests/e2e/linalg_ext_ops |
The following command successfully finishes compilation for me with these commits cherry-picked:
Thank you @hanhanW, I will try it on other backends as well. |
Keeping the issue open until we verify that it's fixed for the other backends as well. |
ROCm runs into the shared memory allocation limit with attention tiled+decomposed: #16538 |
Vulkan is tricky. Stumbling around the SPIR-V KernelConfig, it seems that we just don't have a good pipeline for this decomposition -- I haven't had any luck dropping in. @erman-gurses @Eliasj42, let me know if you two found anything workable for this (we will see it in Vulkan + SDXL UNet, VAE). |
The fix was wrong, and it introduces numerical issues. I created a revert: #16559. I will figure out how to fix it correctly. |
I just realized that my fix is still wrong... it does not consider the updated max slice correctly. There are two issues with attention.
I'm gonna take a look at (1) and can probably implement (2). If anyone can help on (2), that would be great. I can point out where to add the code. |
(2) makes sure that we have basic coverage for all the backends, including VMVX. |
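Since the thread notes there were no end-to-end tests exercising the decomposition, here is a sketch of the kind of numeric check such coverage needs (all names and tolerances are mine, and the 1/sqrt(d) scale is omitted; `run_fn` stands in for invoking the compiled dispatch on a given backend):

```python
import numpy as np

def reference_attention(q, k, v):
    # Direct two-pass softmax(Q K^T) V in float64 as the oracle.
    s = (q @ k.T).astype(np.float64)
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v

def check_attention(run_fn, shapes=((8, 16, 4),), atol=1e-3):
    """Compare a backend/decomposition under test against the oracle.

    run_fn(q, k, v) is a placeholder for running the compiled module.
    """
    rng = np.random.default_rng(0)
    for n, kv, d in shapes:
        q = rng.standard_normal((n, d)).astype(np.float32)
        k = rng.standard_normal((kv, d)).astype(np.float32)
        v = rng.standard_normal((kv, d)).astype(np.float32)
        got = run_fn(q, k, v)
        assert got.shape == (n, d)
        assert np.allclose(got, reference_attention(q, k, v), atol=atol)
```

A check like this, run per backend, would have caught both the silent bufferization failure and the later numerical regression.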
I noticed that we can get past bufferization if we vectorize all the operations. I think the final goal is to have vectorization working for all the dispatches. So I'd like to create a new pipeline for the attention op on the CPU side: #16577 |
@harsh-nod looking at this
These kinds of |
@hanhanW I can work on the scalar fallback solution. I think we have a similar implementation here: https://github.com/gpetters94/mlir-npcomp/blob/3e30bb06c0cd725a32f1091552d4824bd796a2d6/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp#L179 |
The revision adds a new pipeline for LinalgExt ops. It is an experimental pipeline, and should eventually get merged into MultiTilingPipeline. The new pipeline introduces a vector level of tiling for LinalgExt ops, plus vectorization. Some dimensions of the attention op cannot be tiled at this moment, so we set all the tile sizes to 1, which avoids huge vectors because the reduction dimension of the matmuls is not tiled. Here is a selected IR dump: https://gist.githubusercontent.com/hanhanW/db4511da681d4932cb81dd68cc98976f/raw/08c3cc42c9d7fb86b769f60dc712fecb9fb10700/dump.mlir Towards #16421
Thanks for offering the help, that's very helpful! |
What happened?
We have IR for SDXL that preserves the torch.aten._scaled_dot_product_attention op as tm_tensor.attention -> iree_linalg_ext.attention, but I'm seeing issues trying to lower the resulting IR through the IREE compiler.
@rsuderman managed to narrow the issue down to comprehensive bufferization of the tiled and decomposed attention op.
It seems that a buffer/address is being reused many times in the inner loop, causing bufferization/analysis to fail, though the error message given only mentions a mismatched yield / iterArg pair.
The minimized IR (minimal_attn.mlir):
The CLI input:
The error message:
The full log with --mlir-print-ir-after-all:
out.txt
What component(s) does this issue relate to?
iree-compiler, one-shot bufferization/analysis
Version information
Reproduced on latest IREE (0c61f77) and on a source build from contents of this PR: #16416
Additional context
No response