
iree_linalg_ext.attention dispatch fails to bufferize after TileAndDecomposeAttention #16421

Open
monorimet opened this issue Feb 15, 2024 · 26 comments · Fixed by #16525
Labels: bug 🐞 Something isn't working

@monorimet (Collaborator) commented Feb 15, 2024

What happened?

We have IR for SDXL that preserves the torch.aten._scaled_dot_product_attention op as tm_tensor.attention -> iree_linalg_ext.attention, but I'm seeing issues trying to lower the resulting IR through the IREE compiler.

@rsuderman managed to narrow the issue down to comprehensive bufferization of the tiled and decomposed attention op.

It seems that a buffer/address is being reused many times in the inner loop, causing bufferization analysis to fail, though the error message only mentions a mismatched yield/iterArg pair.

The minimized IR (minimal_attn.mlir):

func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
  %empty = tensor.empty() : tensor<20x4096x64xf16>
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
  return %0 : tensor<20x4096x64xf16>
}

The CLI input:

iree-compile ./minimal_attn.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host --mlir-print-op-on-diagnostic=false --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-opt-strip-assertions=true --verify=true -o minimal_attn_cpu.vmfb --mlir-print-ir-after-all 2> out.txt

The error message:

./minimal_attn.mlir:3:8: error: Yield operand #1 is not equivalent to the corresponding iter bbArg
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       ^
./minimal_attn.mlir:1:1: note: called from
func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
^
./minimal_attn.mlir:3:8: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,+mwaitx,-lwp,+lzcnt,+sha,-movdir64b,+wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "default"}>
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       ^
./minimal_attn.mlir:1:1: note: called from
func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
^
./minimal_attn.mlir:3:8: error: failed to translate executables
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
       ^
./minimal_attn.mlir:1:1: note: called from
func.func @main(%arg0 : tensor<20x4096x64xf16>, %arg1 : tensor<20x4096x64xf16>, %arg2 : tensor<20x4096x64xf16>, %arg3 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
^

The full log with --mlir-print-ir-after-all:

out.txt

What component(s) does this issue relate to?

iree-compiler, one-shot bufferization/analysis

Version information

Reproduced on latest IREE (0c61f77) and on a source build from contents of this PR: #16416

Additional context

No response

monorimet added the bug 🐞 Something isn't working label Feb 15, 2024
@monorimet (Collaborator, Author)

The attached out.txt shows that TileAndDecomposeAttention was the last successful pass, and EliminateEmptyTensors fails first.

The IR after TileAndDecomposeAttention shows an address being reused many times in the innermost decomposed block of computation, which might be triggering a failure of OneShotAnalysis:

// -----// IR Dump After TileAndDecomposeWinogradTransform (iree-linalg-ext-tile-and-decompose-winograd) //----- //
func.func @main_dispatch_0_attention_20x4096x64xf16() {
  %c4096 = arith.constant 4096 : index
  %c20 = arith.constant 20 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
  %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %4 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_id_y]
  %5 = affine.apply affine_map<()[s0] -> (s0 * 20)>()[%workgroup_count_y]
  scf.for %arg0 = %4 to %c20 step %5 {
    %6 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
    %7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
    scf.for %arg1 = %6 to %c4096 step %7 {
      %8 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %9 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x64x64xf16>
      %10 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %11 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
      %12 = tensor.empty() : tensor<64x64xf32>
      %extracted_slice = tensor.extract_slice %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
      %cst = arith.constant 0.000000e+00 : f32
      %13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<64x64xf32>) -> tensor<64x64xf32>
      %cst_0 = arith.constant -1.000000e+30 : f32
      %14 = tensor.empty() : tensor<64xf32>
      %15 = linalg.fill ins(%cst_0 : f32) outs(%14 : tensor<64xf32>) -> tensor<64xf32>
      %16 = tensor.empty() : tensor<64xf32>
      %17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<64xf32>) -> tensor<64xf32>
      %c0_1 = arith.constant 0 : index
      %c4096_2 = arith.constant 4096 : index
      %c64 = arith.constant 64 : index
      %18:3 = scf.for %arg2 = %c0_1 to %c4096_2 step %c64 iter_args(%arg3 = %13, %arg4 = %15, %arg5 = %17) -> (tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>) {
        %extracted_slice_3 = tensor.extract_slice %10[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_4 = tensor.extract_slice %11[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<20x4096x64xf16> to tensor<64x64xf16>
        %extracted_slice_5 = tensor.extract_slice %9[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<20x64x64xf16> to tensor<64x64xf16>
        %cst_6 = arith.constant 0.000000e+00 : f32
        %21 = tensor.empty() : tensor<64x64xf32>
        %22 = linalg.fill ins(%cst_6 : f32) outs(%21 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %23 = linalg.matmul_transpose_b ins(%extracted_slice_5, %extracted_slice_3 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%22 : tensor<64x64xf32>) -> tensor<64x64xf32>
        %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.maximumf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>
        %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%24 : tensor<64xf32>) outs(%23 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.subf %out, %in : f32
          %34 = math.exp2 %33 : f32
          linalg.yield %34 : f32
        } -> tensor<64x64xf32>
        %26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.subf %out, %in : f32
          %34 = math.exp2 %33 : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>
        %27 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%26 : tensor<64xf32>) outs(%arg5 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.mulf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>
        %28 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%25 : tensor<64x64xf32>) outs(%27 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.addf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>
        %29 = tensor.empty() : tensor<64x64xf16>
        %30 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%25 : tensor<64x64xf32>) outs(%29 : tensor<64x64xf16>) {
        ^bb0(%in: f32, %out: f16):
          %33 = arith.truncf %in : f32 to f16
          linalg.yield %33 : f16
        } -> tensor<64x64xf16>
        %31 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%26 : tensor<64xf32>) outs(%arg3 : tensor<64x64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.mulf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64x64xf32>
        %32 = linalg.matmul ins(%30, %extracted_slice_4 : tensor<64x64xf16>, tensor<64x64xf16>) outs(%31 : tensor<64x64xf32>) -> tensor<64x64xf32>
        scf.yield %32, %24, %28 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>
      }
      %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18#2 : tensor<64xf32>) outs(%18#0 : tensor<64x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %cst_3 = arith.constant 1.000000e+00 : f32
        %21 = arith.divf %cst_3, %in : f32
        %22 = arith.mulf %21, %out : f32
        linalg.yield %22 : f32
      } -> tensor<64x64xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%19 : tensor<64x64xf32>) outs(%extracted_slice : tensor<64x64xf16>) {
      ^bb0(%in: f32, %out: f16):
        %21 = arith.truncf %in : f32 to f16
        linalg.yield %21 : f16
      } -> tensor<64x64xf16>
      %inserted_slice = tensor.insert_slice %20 into %8[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<64x64xf16> into tensor<20x64x64xf16>
      flow.dispatch.tensor.store %inserted_slice, %3, offsets = [%arg0, %arg1, 0], sizes = [20, 64, 64], strides = [1, 1, 1] : tensor<20x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
    }
  }
  return
}

monorimet changed the title from "iree_linalg_ext.attention op fails to bufferize after TileAndDecomposeAttention" to "iree_linalg_ext.attention dispatch fails to bufferize after TileAndDecomposeAttention" Feb 15, 2024
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this issue Feb 16, 2024
The attention decomposition is failing during bufferization, specifically `--iree-eliminate-empty-tensors-pass`. This seems to be an upstream bug, but there are clearly no end-to-end tests for the attention operation, so the decomposition is disabled for now until the issue is fixed. Eventually this pass is probably needed only for backends that don't support attention.

Towards iree-org#16421.
@MaheshRavishankar (Contributor)

Turns out the attention decomposition is failing during bufferization, and we probably don't have any end-to-end tests exercising this (AFAIK all attention-op-related work has been done using the transform dialect, with correctness checks out of tree), so this is probably failing silently.
I disabled the decomposition for now. This will unblock other work while I look into this issue (most likely by filing an upstream bug).

MaheshRavishankar added a commit that referenced this issue Feb 16, 2024
The attention decomposition is failing during bufferization, specifically `--iree-eliminate-empty-tensors-pass`. This seems to be an upstream bug, but there are clearly no end-to-end tests for the attention operation, so the decomposition is disabled for now until the issue is fixed. Eventually this pass is probably needed only for backends that don't support attention.

Towards #16421.
@MaheshRavishankar (Contributor)

I am attaching the current decomposition (repro.mlir) and what I manually fixed (fixed.mlir). The latter fixes the failure in the eliminate-empty-tensors pass. The current decomposition is doing something funky with destinations and probably needs a look. It could also probably use some elementwise op fusion to make the IR more manageable. All in all, this needs a fresh look.

Tagging @harsh-nod: do you think we can iterate on this to get to a better final place?

fixed.mlir.txt
repro.mlir.txt

@harsh-nod (Contributor)

@MaheshRavishankar - definitely. I will be out next week, but @erman-gurses can help pick this up in the meantime. Some notes on the problem:

  1. I just tried top of master (with the pass uncommented) and did not see the tensor.casts (here is my output: https://gist.github.com/harsh-nod/c8a2a9da723dc5deea77fa6f8d5ff229).
  2. I can reproduce the error though and it does fail in eliminate empty tensors.
  3. The error has to do with the use of %arg4 (in the above gist), which I can see you fixed manually by adding a tensor.empty. I can comment a little on why %arg4 is used the way it is:
    %arg4 represents the initial value of the accumulator for the max value (used for softmax). You can see this here
%24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%23 : tensor<64x64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.maximumf %in, %out : f32
          linalg.yield %33 : f32
        } -> tensor<64xf32>

where we use the old max to compute the new max (%24). This is returned at the end of the loop

scf.yield %32, %24, %28 : tensor<64x64xf32>, tensor<64xf32>, tensor<64xf32>

The problem comes when we use %arg4 again here

%26 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%24 : tensor<64xf32>) outs(%arg4 : tensor<64xf32>) {
        ^bb0(%in: f32, %out: f32):
          %33 = arith.subf %out, %in : f32
          %34 = math.exp2 %33 : f32
          linalg.yield %34 : f32
        } -> tensor<64xf32>

The idea behind this second usage was the following:
Say we allocate some memory for %arg4 and then some additional memory for %24; when we do the second generic, we want to reuse the memory allocated for %arg4 rather than allocating more. The idea was to minimize the memory allocations inside the for loop, on the assumption that each tensor.empty() might materialize into an allocation.
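
For intuition, here is a minimal NumPy sketch of the streaming-softmax recurrence the decomposition implements (an editorial sketch, not IREE code; it uses exp2 as in the IR above and omits any score-scaling factors folded elsewhere). Note that the old running max is read twice per iteration, once to compute the new max and once to compute the rescale factor; that double use is exactly what reusing %arg4 as a destination tries to express:

import numpy as np

def tiled_attention(q, k, v, tile=64):
    # q: (n, d) queries; k, v: (m, d) keys/values.
    n, d = q.shape
    acc = np.zeros((n, d), dtype=np.float32)        # %arg3: running output accumulator
    row_max = np.full(n, -1e30, dtype=np.float32)   # %arg4: running row-wise max
    row_sum = np.zeros(n, dtype=np.float32)         # %arg5: running row-wise sum
    qf = q.astype(np.float32)
    for j in range(0, k.shape[0], tile):
        s = qf @ k[j:j + tile].astype(np.float32).T    # matmul_transpose_b (%23)
        new_max = np.maximum(row_max, s.max(axis=1))   # %24: first read of the old max
        scale = np.exp2(row_max - new_max)             # %26: second read of the old max
        p = np.exp2(s - new_max[:, None])              # %25
        row_sum = scale * row_sum + p.sum(axis=1)      # %27 and %28
        # %31 and %32 (the f16 truncation of p, %30, is omitted for brevity):
        acc = scale[:, None] * acc + p @ v[j:j + tile].astype(np.float32)
        row_max = new_max                              # only now is the old max dead
    return (acc / row_sum[:, None]).astype(np.float16)  # %19 and %20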

Another note is that the transform dialect bufferization op seems to bufferize the original graph, so it might be worthwhile to look into the differences between the TD bufferization op and what's in the pass pipeline. Specifically, here: https://github.com/openxla/iree/blob/6560f8600c6e6a79eaed3ca3b26af2f8e2b900ae/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp#L760, where some additional patterns such as EmptyTensorLoweringPattern are being used.

@MaheshRavishankar (Contributor)

First thing: look at the output of this command and see if it avoids additional stack allocations

iree-opt --iree-eliminate-empty-tensors --empty-tensor-to-alloc-tensor --iree-codegen-iree-comprehensive-bufferize fixed.mlir

i.e., without trying to manually reuse the buffers, let the analysis eliminate them for you. If so, then just changing the tile-and-decompose-attention pass to not reuse buffers this way would fix the issue.

@monorimet (Collaborator, Author)

@MaheshRavishankar FWIW:

Followed instructions to run iree-opt --iree-eliminate-empty-tensors --empty-tensor-to-alloc-tensor --iree-codegen-iree-comprehensive-bufferize fixed.mlir

output of iree-opt (above command):
minimal_attn_elim.mlir.txt

iree-compile .\minimal_attn_elim.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=C:\V\iree\build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-llvmcpu-enable-ukernels=all -o minimal_attn.vmfb
Assertion failed: index < size() && "invalid index into type range", file C:\V\iree\third_party\llvm-project\mlir\include\mlir/IR/TypeRange.h, line 140
.\minimal_attn_elim.mlir:6:1: error: Failures have been detected while processing an MLIR pass pipeline
module {
^
.\minimal_attn_elim.mlir:6:1: note: Pipeline failed while executing [`FormDispatchRegions` on 'util.func' operation: @main_dispatch_0_attention_20x4096x64xf16]: reproducer generated at `./shark_tmp/core-reproducer.mlir`

@monorimet (Collaborator, Author)

Noting that I've registered the empty-tensor-to-alloc-tensor pass ad hoc; this is my diff to init_mlir_passes.h:

diff --git a/compiler/src/iree/compiler/Tools/init_mlir_passes.h b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
index 4ebbdafe0..2d1d17897 100644
--- a/compiler/src/iree/compiler/Tools/init_mlir_passes.h
+++ b/compiler/src/iree/compiler/Tools/init_mlir_passes.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -59,6 +60,9 @@ inline void registerMlirPasses() {
   // Arm SME
   arm_sme::registerArmSMEPasses();

+  // Bufferization
+  bufferization::registerEmptyTensorToAllocTensor();
+
   // Linalg
   registerLinalgPasses();

hanhanW self-assigned this Feb 21, 2024
erman-gurses self-assigned this Feb 21, 2024
@hanhanW (Contributor) commented Feb 21, 2024

I just verified that fixed.mlir passes bufferization, though there are some big buffers. With #16524, we generate the IR below. There are some stack allocations, bounded by 128 and 64; that is okay if they come from tile sizes. Without the fix (i.e., using repro.mlir.txt), I still see the error, so we probably want to fix the decomposition logic.

#map = affine_map<()[s0] -> (s0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> (d0)>
#map3 = affine_map<(d0) -> (d0)>
module {
  func.func @main_dispatch_0_attention_20x4096x64xf16() {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant -1.000000e+30 : f32
    %c4096 = arith.constant 4096 : index
    %cst_1 = arith.constant 1.000000e+00 : f32
    %c128 = arith.constant 128 : index
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<128x128xf16>
    %alloca_2 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_3 = memref.alloca() {alignment = 64 : i64} : memref<128x128xf32>
    %alloca_4 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_5 = memref.alloca() {alignment = 64 : i64} : memref<128xf32>
    %alloca_6 = memref.alloca() {alignment = 64 : i64} : memref<128x64xf32>
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %0, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %1, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %2, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    memref.assume_alignment %3, 64 : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>>
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %4 = affine.apply #map()[%workgroup_id_x]
    %subview = memref.subview %3[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_7 = memref.subview %0[%workgroup_id_y, %4, 0] [1, 128, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_8 = memref.subview %1[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_9 = memref.subview %2[%workgroup_id_y, 0, 0] [1, 4096, 64] [1, 1, 1] : memref<20x4096x64xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_10 = memref.subview %subview[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    linalg.fill ins(%cst : f32) outs(%alloca_6 : memref<128x64xf32>)
    linalg.fill ins(%cst_0 : f32) outs(%alloca_4 : memref<128xf32>)
    linalg.fill ins(%cst : f32) outs(%alloca_5 : memref<128xf32>)
    scf.for %arg0 = %c0 to %c4096 step %c128 {
      %subview_11 = memref.subview %subview_8[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.fill ins(%cst : f32) outs(%alloca_3 : memref<128x128xf32>)
      %subview_12 = memref.subview %subview_7[0, 0, 0] [1, 128, 64] [1, 1, 1] : memref<1x128x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.matmul_transpose_b ins(%subview_12, %subview_11 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_3 : memref<128x128xf32>)
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_4 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.maximumf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_3 : memref<128x128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      }
      linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_4 : memref<128xf32>) outs(%alloca_2 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      }
      linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_5 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.mulf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "reduction"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca_5 : memref<128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.addf %in, %out : f32
        linalg.yield %5 : f32
      }
      linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_3 : memref<128x128xf32>) outs(%alloca : memref<128x128xf16>) {
      ^bb0(%in: f32, %out: f16):
        %5 = arith.truncf %in : f32 to f16
        linalg.yield %5 : f16
      }
      linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_2 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.mulf %in, %out : f32
        linalg.yield %5 : f32
      }
      %subview_13 = memref.subview %subview_9[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x4096x64xf16, strided<[262144, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      linalg.matmul ins(%alloca, %subview_13 : memref<128x128xf16>, memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%alloca_6 : memref<128x64xf32>)
    }
    linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_5 : memref<128xf32>) outs(%alloca_6 : memref<128x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = arith.divf %cst_1, %in : f32
      %6 = arith.mulf %5, %out : f32
      linalg.yield %6 : f32
    }
    linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloca_6 : memref<128x64xf32>) outs(%subview_10 : memref<128x64xf16, strided<[64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) {
    ^bb0(%in: f32, %out: f16):
      %5 = arith.truncf %in : f32 to f16
      linalg.yield %5 : f16
    }
    return
  }
}

@hanhanW (Contributor) commented Feb 21, 2024

I have a fix (which generates the IR that @MaheshRavishankar suggested) for attention on the CPU side. Let me prepare a PR.

@hanhanW (Contributor) commented Feb 21, 2024

There are big stack allocation issues, but you can bypass them with --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false. Otherwise, you'll see the error below. I think the allocations are bounded by the distribution tile sizes (tile_sizes = [[20, 64]] in this case), so it is fine for now.

/Users/hanchung/z.mlir:3:8: error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 41728 bytes
  %0 = iree_linalg_ext.attention ins(%arg0, %arg1, %arg2 : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>) outs(%empty : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
...
  %24 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf16>
  %25 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
  %26 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf32>
  %27 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
  %28 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64xf32>
  %29 = "memref.alloca"() <{alignment = 64 : i64, operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<64x64xf32>

I'm able to compile the op e2e on CPU. To repro, run iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/z.mlir -o /tmp/a.vmfb --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false.
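
As a sanity check on the reported number (an editorial note): assuming 2 bytes per f16 element and 4 per f32, the six allocas in the error above add up to exactly 41728 bytes.

# Byte sizes of the six allocas listed in the error message.
allocas = (
    64 * 64 * 2,      # memref<64x64xf16>
    64 * 4,           # memref<64xf32>
    64 * 64 * 4,      # memref<64x64xf32>
    64 * 4,           # memref<64xf32>
    64 * 4,           # memref<64xf32>
    64 * 64 * 4,      # memref<64x64xf32>
)
total = sum(allocas)
print(total)          # 41728
print(total > 32768)  # True: over the 32768-byte stack limit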

@MaheshRavishankar (Contributor)

There are big stack allocation issues, but you can bypass them with --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false. [...] I think the allocations are bounded by the distribution tile sizes (tile_sizes = [[20, 64]] in this case), so it is fine for now.

Can we make it [[20, 32]] so it fits within the limit?

@hanhanW (Contributor) commented Feb 22, 2024

Can we make it [[20, 32]] so it fits within the limit?

Yes, we can do it with the --iree-llvmcpu-distribution-size=32 flag. I think we need a specialized setRootConfig entry function for the op (or for all the LinalgExt ops), because all of them currently go through the CPUDefault pipeline, which only applies distribution and bufferization.
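
A rough footprint model suggests why 32 fits (an editorial sketch; it assumes the alloca shapes simply scale with the distribution tile size, with the head dimension fixed at 64, matching the allocas seen above):

def attention_stack_bytes(tile, head_dim=64):
    # One tile x tile f16 buffer (exp tile), three tile-length f32 vectors
    # (max, sum, scale), one tile x tile f32 buffer (scores), and one
    # tile x head_dim f32 accumulator.
    return (tile * tile * 2
            + 3 * tile * 4
            + tile * tile * 4
            + tile * head_dim * 4)

print(attention_stack_bytes(64))  # 41728 -> exceeds the 32768-byte limit
print(attention_stack_bytes(32))  # 14720 -> fits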

@MaheshRavishankar (Contributor)

@monorimet: #16421 (comment) should unblock running the model on CPU backends (using the flag to disable the failure on exceeding the stack allocation limit). Can you verify? (Though on GPU, if it isn't within the limit, it won't run.)

@erman-gurses: after Hanhan's fixes, can you add an e2e test in IREE with the attention op, so that we can catch this issue?

@erman-gurses (Contributor) commented Feb 22, 2024

@erman-gurses: after Hanhan's fixes, can you add an e2e test in IREE with the attention op, so that we can catch this issue?
@MaheshRavishankar, Sure, I can work on that.

@hanhanW (Contributor) commented Feb 22, 2024

@erman-gurses: after Hanhan's fixes, can you add an e2e test in IREE with the attention op, so that we can catch this issue?
@MaheshRavishankar, sure, I can work on that.

Thanks, please add e2e tests to https://github.com/openxla/iree/tree/main/tests/e2e/linalg_ext_ops

@monorimet (Collaborator, Author)

The following command successfully finishes compilation for me with these commits cherry-picked:

iree-compile .\minimal_attn.mlir --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-embedded-linker-path=C:\V\iree\build\compiler\bindings\python\iree\compiler\tools\..\_mlir_libs\iree-lld.exe --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=./shark_tmp/core-reproducer.mlir --iree-input-type=torch --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-enable-ukernels=all --iree-llvmcpu-distribution-size=32 -o minimal_attn.vmfb

Thank you, @hanhanW!

I will try the other backends as well.

hanhanW reopened this Feb 22, 2024
@hanhanW (Contributor) commented Feb 22, 2024

Keeping this issue open until we verify that it's fixed for the other backends as well.

@monorimet (Collaborator, Author)

ROCm runs into the shared memory allocation limit with attention tiled and decomposed: #16538

@monorimet (Collaborator, Author)

Vulkan is tricky.

Stumbling around the SPIR-V KernelConfig, it seems we just don't have a good pipeline for this decomposition. I haven't had any luck dropping LinalgExt::TileAndDecomposeAttentionPass anywhere into addTileAndDistributeToWorkgroupsPasses to mimic the LLVMCPU/LLVMGPU implementation; I mostly run into issues turning the result into valid MMA subgroup compute ops. Perhaps a different or slightly more bespoke pipeline needs to be set for these attention op dispatches. Very open to suggestions here.

@erman-gurses @Eliasj42, let me know if you two have found anything workable for this (we will see it in Vulkan + SDXL UNet and VAE).

@hanhanW (Contributor) commented Feb 23, 2024

It was a wrong fix, and it introduces numerical issues. I created a revert: #16559.

I will figure out how to fix it correctly.

@hanhanW (Contributor) commented Feb 26, 2024

I just realized that my fix is still wrong... it does not handle the max-update slice correctly. There are two issues with attention:

  1. We need to teach IREE bufferization about the decomposed ops.
  2. We need to implement LinalgExtToLoops for the op, so we always have a scalar fallback (see the sketch below).

I'm going to take a look at (1) and can probably implement (2). If anyone can help with (2), that would be great; I can point out where to add the code.
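
For reference, a scalar fallback for (2) would amount to a loop nest like the following Python sketch (hypothetical, not the actual LinalgExtToLoops output; any score scaling the op applies is omitted):

import math

def attention_scalar(q, k, v):
    # q: n x d, k and v: m x d (lists of lists); computes softmax(q @ k^T) @ v.
    n, d, m = len(q), len(q[0]), len(k)
    out = [[0.0] * d for _ in range(n)]
    for i in range(n):
        scores = [sum(q[i][t] * k[j][t] for t in range(d)) for j in range(m)]
        row_max = max(scores)  # subtract the row max for numerical stability
        exps = [math.exp(s - row_max) for s in scores]
        denom = sum(exps)
        for j in range(m):
            w = exps[j] / denom
            for t in range(d):
                out[i][t] += w * v[j][t]
    return out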

@hanhanW (Contributor) commented Feb 26, 2024

(2) makes sure that we have basic coverage for all the backends, including VMVX.

@hanhanW (Contributor) commented Feb 26, 2024

I noticed that we can pass bufferization if we vectorize all the operations. I think the final goal is to have vectorization working for all the dispatches, so I'd like to create a new pipeline for the attention op on the CPU side: #16577

@MaheshRavishankar (Contributor)

@harsh-nod, looking at this:

%0 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<128xf32>) outs(%3 : tensor<128x128xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.subf %out, %in : f32
        %6 = math.exp2 %5 : f32
        linalg.yield %6 : f32
      } -> tensor<128x128xf32>

This kind of linalg.generic is something of an anti-pattern: ideally the %out value is never "read" within the body when all the iterators are marked parallel. It isn't wrong per se, but such uses are better avoided. I understand why you are doing it; a better way to reuse memory would be to use bufferization.to_tensor or similar operations.

@pashu123 (Contributor)

[...] I'm going to take a look at (1) and can probably implement (2). If anyone can help with (2), that would be great; I can point out where to add the code.

@hanhanW I can work on the scalar fallback solution. I think we have a similar implementation here: https://github.com/gpetters94/mlir-npcomp/blob/3e30bb06c0cd725a32f1091552d4824bd796a2d6/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp#L179

hanhanW added a commit that referenced this issue Feb 27, 2024
The revision adds a new pipeline for LinalgExt ops. It is an experimental pipeline and should eventually be merged into MultiTilingPipeline.

The new pipeline introduces a vector level of tiling for LinalgExt ops, plus vectorization. Some dimensions of the attention op cannot be tiled at the moment, and the reduction dimension of the matmuls is not tiled, so we set all the vector tile sizes to 1 to avoid huge vectors. Here is a selected IR dump:
https://gist.githubusercontent.com/hanhanW/db4511da681d4932cb81dd68cc98976f/raw/08c3cc42c9d7fb86b769f60dc712fecb9fb10700/dump.mlir

Towards #16421
@hanhanW (Contributor) commented Feb 27, 2024

@hanhanW I can work on the scalar fallback solution. [...]

Thanks for offering the help, that's very helpful!
