From 9d2d7668420ebb7f060b8c0a8f3b770fd97ab97c Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Thu, 11 Jul 2024 12:44:21 -0700
Subject: [PATCH] [LinalgExt] Adding IndexingMaps to linalg_ext.attentionOp
 (#17864)

To make fusion with other generics, transposes in particular, easier, we
introduce affine maps (indexing maps) on linalg_ext.attentionOp. With
that, we also enforce the number and types of its DPS inputs. We also
remove the "transpose_V" attribute in favor of inferring transposition
from the indexing maps.

---------

Co-authored-by: Kunwar Grover
---
 .../ConvertTMTensorToLinalgExt.cpp            | 105 +++++---------
 .../Torch/InputConversion/test/attention.mlir |  37 +++--
 .../test/select_x86_64_lowering_strategy.mlir |   7 +-
 .../Codegen/LLVMGPU/test/attention.mlir       |   6 +-
 .../Codegen/LLVMGPU/test/attention_mfma.mlir  |   6 +-
 .../Dialect/LinalgExt/IR/LinalgExtOps.cpp     | 113 +++-------
 .../Dialect/LinalgExt/IR/LinalgExtOps.td      | 137 +++++++++---------
 .../Dialect/LinalgExt/IR/test/invalid.mlir    |  22 ++-
 .../Dialect/LinalgExt/IR/test/roundtrip.mlir  |  63 +++++++-
 .../Transforms/DecomposeAttention.cpp         |   2 +-
 .../LinalgExt/Transforms/TileAttention.cpp    |  54 ++++++-
 .../Transforms/test/decompose_attention.mlir  |  26 +++-
 .../Transforms/test/tile_attention.mlir       |  34 ++++-
 .../LinalgExt/Transforms/test/tiling.mlir     |  30 +++-
 tests/e2e/linalg_ext_ops/attention.mlir       |  18 ++-
 15 files changed, 368 insertions(+), 292 deletions(-)

diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index 7a86812a2017..f88e4daf433b 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -71,80 +71,37 @@ struct ScatterOpConversion
 };
 } // namespace
 
-static Value collapseBatches(PatternRewriter &rewriter, Location loc,
-                             Value val) {
-  auto valSizes = cast<ShapedType>(val.getType()).getShape();
-  int64_t newBatch =
-      std::accumulate(valSizes.begin(), valSizes.end() - 2, 1,
-                      [](int64_t x, int64_t y) { return x * y; });
-  Type elementType = cast<ShapedType>(val.getType()).getElementType();
-  SmallVector<int64_t> newSizes{newBatch};
-  newSizes.append(valSizes.end() - 2, valSizes.end());
-  Type newType = RankedTensorType::get(newSizes, elementType);
-
-  auto rank = valSizes.size();
-  SmallVector<int64_t> collapsed;
-  for (auto i = 0; i < rank - 2; i++)
-    collapsed.push_back(i);
-
-  SmallVector<ReassociationIndices> reassociation(3);
-  reassociation[0].append(collapsed);
-  reassociation[1].push_back(rank - 2);
-  reassociation[2].push_back(rank - 1);
-
-  return rewriter
-      .create<tensor::CollapseShapeOp>(loc, newType, val, reassociation)
-      .getResult();
-}
-static Value expandBatches(PatternRewriter &rewriter, Location loc,
-                           SmallVector<int64_t> batchSizes, Value val) {
-  auto valSizes = cast<ShapedType>(val.getType()).getShape();
-  Type elementType = cast<ShapedType>(val.getType()).getElementType();
-  SmallVector<int64_t> newSizes(batchSizes);
-  newSizes.append(valSizes.end() - 2, valSizes.end());
-  auto rank = newSizes.size();
-  Type newType = RankedTensorType::get(newSizes, elementType);
-
-  SmallVector<ReassociationIndices> reassociation(3);
-  for (auto i = 0; i < batchSizes.size(); i++)
-    reassociation[0].push_back(i);
-  reassociation[1].push_back(rank - 2);
-  reassociation[2].push_back(rank - 1);
-
-  return rewriter
-      .create<tensor::ExpandShapeOp>(loc, newType, val, reassociation)
-      .getResult();
+static SmallVector<AffineMap>
+getStandardAttentionIndexingMaps(MLIRContext *ctx) {
+  AffineExpr m, k1, k2, n;
+  bindDims(ctx, m, k1, k2, n);
+
+  AffineMap qMap =
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
+  AffineMap kMap =
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
+  AffineMap vMap =
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
+  AffineMap rMap =
+      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);
+
+  return {qMap, kMap, vMap, rMap};
 }
+
 struct AttentionOpConversion
    : public OpRewritePattern<mlir::torch::TMTensor::AttentionOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(mlir::torch::TMTensor::AttentionOp op,
                                 PatternRewriter &rewriter) const override {
+    MLIRContext *ctx = getContext();
     Location loc = op->getLoc();
     Value query = op.getQuery();
     Value key = op.getKey();
     Value value = op.getValue();
-    auto sizes = cast<ShapedType>(query.getType()).getShape();
-    SmallVector<int64_t> batchSizes(sizes.begin(), sizes.end() - 2);
-
-    if (sizes.size() > 3) {
-      query = collapseBatches(rewriter, loc, query);
-      key = collapseBatches(rewriter, loc, key);
-      value = collapseBatches(rewriter, loc, value);
-    }
-    SmallVector<int64_t> resultShape(
-        cast<ShapedType>(op->getResultTypes()[0]).getShape());
-    SmallVector<int64_t> collapsedResultShape;
-    collapsedResultShape.push_back(
-        std::accumulate(resultShape.begin(), resultShape.end() - 2, 1,
-                        [](int64_t x, int64_t y) { return x * y; }));
-    collapsedResultShape.append(resultShape.end() - 2, resultShape.end());
-    Type elementType = cast<ShapedType>(query.getType()).getElementType();
-    auto collapsedResultType =
-        RankedTensorType::get(collapsedResultShape, elementType);
-    Value collapsedResult = rewriter.create<tensor::EmptyOp>(
-        loc, collapsedResultShape, elementType);
+    ShapedType outputType = op.getOutputType();
+    Value result = rewriter.create<tensor::EmptyOp>(
+        loc, outputType.getShape(), outputType.getElementType());
 
     // TODO: This is a hack. This should be replaced with a simple getScale()
     // when support for scaling is plumbed to TMTensor on the torch-mlir side.
@@ -174,15 +131,21 @@ struct AttentionOpConversion
     Value scale = rewriter.create<arith::ConstantOp>(
         loc, targetType, rewriter.getFloatAttr(targetType, dk));
 
+    // Add batches to standard attention indexing maps.
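+    // Editorial sketch (not part of the original patch): the loop below
+    // first shifts every standard map up by numBatches dims and then
+    // prepends the batch dims as leading results. For a rank-4 query,
+    // numBatches = 2, so the Q map
+    //   (m, k1, k2, n) -> (m, k1)
+    // becomes (b0, b1, m, k1, k2, n) -> (m, k1) after shiftDims, and then
+    //   (b0, b1, m, k1, k2, n) -> (b0, b1, m, k1)
+    // after the insertResult loop, i.e.
+    // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, matching the
+    // CHECK'd maps in the attention.mlir test below.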
+ SmallVector indexingMaps = getStandardAttentionIndexingMaps(ctx); + int64_t numBatches = op.getQueryType().getRank() - 2; + for (AffineMap &map : indexingMaps) { + map = map.shiftDims(numBatches); + for (int batch : llvm::seq(numBatches)) { + map = map.insertResult(rewriter.getAffineDimExpr(batch), batch); + } + } + auto attention = rewriter.create( - loc, collapsedResultType, SmallVector{query, key, value, scale}, - collapsedResult); - - if (sizes.size() > 3) - rewriter.replaceOp( - op, expandBatches(rewriter, loc, batchSizes, attention.getResult(0))); - else - rewriter.replaceOp(op, attention.getResult(0)); + loc, result.getType(), query, key, value, scale, result, + rewriter.getAffineMapArrayAttr(indexingMaps)); + + rewriter.replaceOp(op, attention.getResult(0)); return success(); } }; diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir index 17000c051a24..360308829894 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir +++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir @@ -5,17 +5,18 @@ func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar return %0 : tensor<5x2x3x4xf32> } +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> + // CHECK-LABEL: func.func @attention( // CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>, // CHECK: %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> { // CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32 -// CHECK: %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32> -// CHECK: %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32> -// CHECK: %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32> -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x3x4xf32> -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]], %[[SCALE]] : tensor<10x3x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<10x3x4xf32>) -> tensor<10x3x4xf32> -// CHECK: %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x3x4xf32> into tensor<5x2x3x4xf32> -// CHECK: return %[[RET]] : tensor<5x2x3x4xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32> +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> +// CHECK: return %[[ATTN]] : tensor<5x2x3x4xf32> // ----- func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %arg2: tensor<5x2x3x4xf32>, %arg3: tensor<5x2x8x4xf32>) -> (tensor<5x2x8x4xf32>) { @@ -23,17 +24,18 @@ func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar return %0 : tensor<5x2x8x4xf32> } +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, 
d5) -> (d0, d1, d4, d3)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)> + // CHECK-LABEL: func.func @attention( // CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>, // CHECK: %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> { // CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32 -// CHECK: %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x8x4xf32> into tensor<10x8x4xf32> -// CHECK: %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32> -// CHECK: %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32> -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x8x4xf32> -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]], %[[SCALE]] : tensor<10x8x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<10x8x4xf32>) -> tensor<10x8x4xf32> -// CHECK: %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x8x4xf32> into tensor<5x2x8x4xf32> -// CHECK: return %[[RET]] : tensor<5x2x8x4xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32> +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> +// CHECK: return %[[ATTN]] : tensor<5x2x8x4xf32> // ----- func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2: tensor<1x3x4xf32>, %arg3: tensor<1x3x4xf32>) -> (tensor<1x3x4xf32>) { @@ -41,10 +43,15 @@ func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2: return %0 : tensor<1x3x4xf32> } +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK-LABEL: func.func @attention( // CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>, // CHECK: %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> { // CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32> -// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32> // CHECK: return %[[ATTN]] : tensor<1x3x4xf32> diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir index f8b8412d131e..ed06e0bfdb3c 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir +++ 
b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_x86_64_lowering_strategy.mlir @@ -1606,7 +1606,10 @@ module { %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<20x4096x64xf16> %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<20x4096x64xf16> %7 = tensor.empty() : tensor<20x4096x64xf16> - %8 = iree_linalg_ext.attention + %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor> @@ -1618,7 +1621,7 @@ module { // CHECK: func.func @attention() // CHECK-SAME: translation_info = #[[TRANSLATION]] // CHECK: iree_linalg_ext.attention -// CHECK-SAME: {lowering_config = #[[CONFIG]]} +// CHECK-SAME: lowering_config = #[[CONFIG]] // ----- diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir index ffc3ef96a0a7..d76070a41129 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention.mlir @@ -14,7 +14,11 @@ module { %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<192x1024x64xf16> %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<192x1024x64xf16> %7 = tensor.empty() : tensor<192x1024x64xf16> - %8 = iree_linalg_ext.attention ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16> + %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor> return } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir index 3732d968c348..ba61ff0ee476 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/attention_mfma.mlir @@ -14,7 +14,11 @@ module { %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<16x16384x128xf16> %6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<16x16384x128xf16> %7 = tensor.empty() : tensor<16x16384x128xf16> - %8 = 
iree_linalg_ext.attention ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16> + %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16> flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor> return } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index c7ca00f5a5f7..360ee7e65281 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -1266,52 +1266,29 @@ LogicalResult WinogradOutputTransformOp::reifyResultShapes( //===----------------------------------------------------------------------===// LogicalResult AttentionOp::verify() { - Operation *op = getOperation(); + AttentionOp attnOp = *this; - int numInputs = getNumDpsInputs(); int numOutputs = getNumDpsInits(); - - if (numInputs != 4) { - return op->emitOpError( - "expected 4 input operands: Query, Key, Value and Scale"); - } - if (numOutputs != 1 && numOutputs != 3) { - return op->emitOpError( + return attnOp->emitOpError( "expected 1 or 3 output operands: Output, [Max and Sum]"); } - bool isTiled = numOutputs == 3; - if (!llvm::all_of(llvm::drop_end(getDpsInputs()), [](Value input) { - return isa(input.getType()); - })) { - return op->emitOpError( - "expected Query, Key, Value inputs to be of shaped type"); - } - - ShapedType queryType = getQueryType(); - ShapedType keyType = getKeyType(); - ShapedType valueType = getValueType(); - ShapedType outputType = getOutputType(); - Type queryElementType = queryType.getElementType(); - Type keyElementType = keyType.getElementType(); - Type valueElementType = valueType.getElementType(); - Type outputElementType = outputType.getElementType(); + SmallVector indexingMaps = attnOp.getIndexingMapsArray(); - FloatType scaleElementType = dyn_cast(getScale().getType()); - if (!scaleElementType) { - return op->emitOpError("expected scale to be of floating point type"); - } + // Check if indexing maps can represent attention. + FailureOr maybeOpInfo = + AttentionOpDetail::get(indexingMaps); // Check shape compatibility based on indexing maps. SmallVector shape(getIterationDomainRank()); SmallVector foundDims(getIterationDomainRank(), false); auto checkShape = [&shape, &foundDims, - &op](StringRef operandName, ArrayRef valShape, - AffineMap indexingMap) -> LogicalResult { + &attnOp](StringRef operandName, ArrayRef valShape, + AffineMap indexingMap) -> LogicalResult { if (indexingMap.getNumResults() != valShape.size()) { - return op->emitError("Rank Mismatch for ") + return attnOp->emitError("Rank Mismatch for ") << operandName << ". 
Expected: " << indexingMap.getNumResults() << " Got: " << valShape.size(); } @@ -1326,7 +1303,7 @@ LogicalResult AttentionOp::verify() { shape[pos] = valShape[i]; } if (shape[pos] != valShape[i]) { - return op->emitError("Shape Mismatch for ") + return attnOp->emitError("Shape Mismatch for ") << operandName << ". Expected: " << shape[pos] << " Got: " << valShape[i]; } @@ -1336,32 +1313,20 @@ LogicalResult AttentionOp::verify() { if (failed(checkShape("Query", getQueryType().getShape(), getQueryMap())) || failed(checkShape("Key", getKeyType().getShape(), getKeyMap())) || - failed(checkShape("Value", getValueType().getShape(), getValueMap()))) { + failed(checkShape("Value", getValueType().getShape(), getValueMap())) || + failed( + checkShape("Output", getOutputType().getShape(), getOutputMap()))) { return failure(); } - if (queryElementType != keyElementType || - queryElementType != valueElementType || - queryElementType != scaleElementType) { - return op->emitOpError( - "element types of (Q)uery, (K)ey and (V)alue and scale should be " - "same"); - } - if (!isTiled) { - // Vanilla attention. - if (queryElementType != outputElementType) { - return op->emitOpError("expected element type for Output ") - << queryElementType << "but found " << outputElementType - << " instead"; - } - } if (isTiled) { // Tiled/Flash attention. Type maxElementType = getMaxType()->getElementType(); Type sumElementType = getSumType()->getElementType(); + Type outputElementType = getOutputType().getElementType(); if (outputElementType != maxElementType || maxElementType != sumElementType) { - return op->emitOpError( + return attnOp->emitOpError( "element types of tiled output, max and sum should be same"); } @@ -1385,52 +1350,8 @@ LogicalResult AttentionOp::reifyResultShapes( } SmallVector AttentionOp::getIndexingMapsArray() { - MLIRContext *ctx = getContext(); - - AffineExpr batch, m, k1, k2, n; - bindDims(ctx, batch, m, k1, k2, n); - - AffineMap qMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m, k1}, ctx); - AffineMap kMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, k2, k1}, ctx); - - AffineMap vMap; - if (getTransposeV()) { - vMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, n, k2}, ctx); - } else { - vMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, k2, n}, ctx); - } - - AffineMap resMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m, n}, ctx); - - SmallVector results = {qMap, kMap, vMap, resMap}; - - if (getMax()) { - AffineMap maxMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m}, ctx); - results.push_back(maxMap); - } - - if (getSum()) { - AffineMap sumMap = - AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, {batch, m}, ctx); - results.push_back(sumMap); - } - - // Remove batch dim for tiled operands. - // TODO: This is a weird expectation from TileAndDecomposeAttention. 
- bool isTiled = getNumResults() == 3; - if (isTiled) { - for (AffineMap &map : results) { - map = map.dropResult(0); - } - } - - return results; + return SmallVector( + getIndexingMaps().getAsValueRange()); } //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index fd907fe10352..993ab3bbb0d6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -511,8 +511,10 @@ def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[ // Attention //===----------------------------------------------------------------------===// -def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", - [DeclareOpInterfaceMethods, +def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", + [DeclareOpInterfaceMethods, + DestinationStyleOpInterface, LinalgExtInterface, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods (n, k2) instead + of (m, k1, k2, n) -> (k2, n), then we'd know that V is transposed (V.T), + and we can expect computation to look like: attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V.T - - TODO: We should be moving to using a indexing map like approach so we - can generalize which tensor is transposed and which is not. }]; - let arguments = (ins Variadic:$inputs, + let arguments = (ins AnyShaped:$query, + AnyShaped:$key, + AnyShaped:$value, + AnyFloat:$scale, Variadic:$outputs, - DefaultValuedOptionalAttr:$transpose_v + AffineMapArrayAttr:$indexing_maps ); - let builders = [ - OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>, - ]; - let results = (outs Variadic:$results); let hasFolder = 1; + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ attr-dict - `ins` `(` $inputs `:` type($inputs) `)` + `ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)` `outs` `(` $outputs `:` type($outputs) `)` (`->` type($results)^)? 
}]; - let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ - Value getQuery() { - return getDpsInputOperand(0)->get(); - } - Value getKey() { - return getDpsInputOperand(1)->get(); + let extraClassDeclaration = [{ + // Method to implement for specifying output range for + // DestinationStyleOpInterface + MutableOperandRange getDpsInitsMutable() { + return getOutputsMutable(); } - Value getValue() { - return getDpsInputOperand(2)->get(); + + SmallVector getIndexingMapsArray(); + + AffineMap getQueryMap() { + return cast(getIndexingMapsArray()[0]); } - Value getScale() { - return getDpsInputOperand(3)->get(); + AffineMap getKeyMap() { + return cast(getIndexingMapsArray()[1]); } - Value getOutput() { - return getDpsInitOperand(0)->get(); + AffineMap getValueMap() { + return cast(getIndexingMapsArray()[2]); } - std::optional getMax() { - if (getNumResults() < 2) - return std::nullopt; - return getDpsInitOperand(1)->get(); + AffineMap getOutputMap() { + return cast(getIndexingMapsArray()[3]); } - std::optional getSum() { - if (getNumResults() < 3) - return std::nullopt; - return getDpsInitOperand(2)->get(); + int64_t getIterationDomainRank() { + return getQueryMap().getNumDims(); } - ShapedType getQueryType() { return cast(getQuery().getType()); } @@ -608,6 +607,34 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", ShapedType getOutputType() { return cast(getOutput().getType()); } + + // Helper functions used for tile_and_decompose_attn. + // TODO: Remove once we completely move to + // online_attn decompose pipeline. + + std::optional getMaxMap() { + if (getNumResults() < 2) + return std::nullopt; + return cast(getIndexingMapsArray()[4]); + } + std::optional getSumMap() { + if (getNumResults() < 3) + return std::nullopt; + return cast(getIndexingMapsArray()[5]); + } + Value getOutput() { + return getDpsInitOperand(0)->get(); + } + std::optional getMax() { + if (getNumResults() < 2) + return std::nullopt; + return getDpsInitOperand(1)->get(); + } + std::optional getSum() { + if (getNumResults() < 3) + return std::nullopt; + return getDpsInitOperand(2)->get(); + } std::optional getMaxType() { std::optional maxVal = getMax(); if (!maxVal) return std::nullopt; @@ -642,39 +669,13 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", return sumType->getRank(); } - // Method to implement for specifying output range for - // DestinationStyleOpInterface - MutableOperandRange getDpsInitsMutable() { - return getOutputsMutable(); - } - - SmallVector getIndexingMapsArray(); - - AffineMap getQueryMap() { - return cast(getIndexingMapsArray()[0]); - } - AffineMap getKeyMap() { - return cast(getIndexingMapsArray()[1]); - } - AffineMap getValueMap() { - return cast(getIndexingMapsArray()[2]); - } - AffineMap getOutputMap() { - return cast(getIndexingMapsArray()[3]); - } - std::optional getMaxMap() { - if (getNumResults() < 2) - return std::nullopt; - return cast(getIndexingMapsArray()[4]); - } - std::optional getSumMap() { - if (getNumResults() < 3) - return std::nullopt; - return cast(getIndexingMapsArray()[5]); - } - - int64_t getIterationDomainRank() { - return getQueryMap().getNumDims(); + // Since we expect map_V to be (b, m, k1, k2, n) -> (b, k2, n), this function + // returns true if n is not the fastest dim. 
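+    // Editorial sketch (not part of the original patch): for the standard
+    //   map_V = (d0, d1, d2, d3, d4) -> (d0, d3, d4),
+    // the innermost result is d4 == n (n being dim getNumInputs() - 1 == 4),
+    // so isTransposeV() returns false. For the transposed
+    //   map_V = (d0, d1, d2, d3, d4) -> (d0, d4, d3),
+    // the innermost result is d3 != n, so isTransposeV() returns true.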
+ bool isTransposeV() { + AffineMap vMap = getValueMap(); + auto vFastDim = llvm::cast(vMap.getResults().back()); + int64_t nDim = vMap.getNumInputs() - 1; + return nDim != vFastDim.getPosition(); } }]; } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 5fdfdf5dd4bc..3363f1bf0cb4 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -741,11 +741,15 @@ func.func @illegal_winograd_filter_kernel_dimensions(%arg0: tensor<3x3x64x128xf3 // ----- -func.func @illegal_attention_inputs(%query: tensor<6x12x20x8xf32>, %key: tensor<6x12x20x8xf32>, %value: tensor<6x12x20x8xf32>) { +func.func @illegal_attention_inputs(%query: tensor<6x12x20x8xf32>, %key: tensor<6x12x20x8xf32>, %value: tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32> { %0 = tensor.empty() : tensor<6x12x20x8xf32> %scale = arith.constant 1.0 : f32 // expected-error @+1 {{Rank Mismatch for Query. Expected: 3 Got: 4}} - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, f32) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} + ins(%query, %key, %value, %scale : tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, tensor<6x12x20x8xf32>, f32) outs(%0 : tensor<6x12x20x8xf32>) -> tensor<6x12x20x8xf32> return %1 : tensor<6x12x20x8xf32> } @@ -757,7 +761,11 @@ func.func @illegal_flash_attention_inputs(%query: tensor<20xf32>, %key: tensor<2 %sum = tensor.empty() : tensor<8xf32> %scale = arith.constant 1.0 : f32 // expected-error @+1 {{Rank Mismatch for Query. 
Expected: 2 Got: 1}} - %1:3 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>, f32) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32> + %1:3 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d1)>, + affine_map<(d0, d1, d2, d3) -> (d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0)>]} + ins(%query, %key, %value, %scale : tensor<20xf32>, tensor<20x8xf32>, tensor<20x8xf32>, f32) outs(%result, %max, %sum : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32> return %1#0, %1#1, %1#2 : tensor<20x8xf32>, tensor<8xf32>, tensor<8xf32> } @@ -766,7 +774,11 @@ func.func @illegal_flash_attention_inputs(%query: tensor<20xf32>, %key: tensor<2 func.func @illegal_attention_inputs(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: f32) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - // expected-error @+1 {{expected Query, Key, Value inputs to be of shaped type}} - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + // expected-error @+5 {{custom op 'iree_linalg_ext.attention' invalid kind of type specified}} + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir index d01074ad86d2..88a1f2d52267 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir @@ -1162,15 +1162,27 @@ func.func @winograd_output_transform_nchw(%arg0: tensor<8x8x1x2x2x1280xf32>) -> func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> } + +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> 
(d0, d3, d4)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> // CHECK-SAME: { // CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[D1:.+]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : +// CHECK: %[[D1:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%[[D0]] : // CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> // CHECK: return %[[D1]] : tensor<192x1024x64xf32> @@ -1181,15 +1193,26 @@ func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf func.func @cross_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x2048x64xf32>, %value: tensor<192x2048x64xf32>) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> } +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK: func.func @cross_attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x2048x64xf32>) -> tensor<192x1024x64xf32> // CHECK-SAME: { // CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[D1:.+]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : +// CHECK: %[[D1:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x2048x64xf32>, f32) outs(%[[D0]] : // CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> // CHECK: return %[[D1]] : tensor<192x1024x64xf32> @@ -1197,18 +1220,31 @@ func.func @cross_attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x204 // ----- +// transpose_V is detected through indexingMap. 
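+// Editorial note (not part of the original test): %value below has type
+// tensor<192x64x2048xf32>, i.e. a (batch, N, K2) layout, and its indexing
+// map is affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>. Because the
+// innermost result is d3 (k2) rather than d4 (n), isTransposeV() reports
+// true with no transpose_v attribute needed.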
+ func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: tensor<192x2048x64xf32>, %value: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> } +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK: func.func @cross_attention_transposev(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: tensor<192x2048x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x64x2048xf32>) -> tensor<192x1024x64xf32> // CHECK-SAME: { // CHECK: %[[D0:.+]] = tensor.empty() : tensor<192x1024x64xf32> // CHECK: %[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : +// CHECK: %[[D1:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor<192x1024x64xf32>, tensor<192x2048x64xf32>, tensor<192x64x2048xf32>, f32) outs(%[[D0]] : // CHECK-SAME: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> // CHECK: return %[[D1]] : tensor<192x1024x64xf32> @@ -1218,14 +1254,25 @@ func.func @cross_attention_transposev(%query: tensor<192x1024x64xf32>, %key: ten func.func @cross_attention_transposev_dyn(%query: tensor, %key: tensor, %value: tensor, %init: tensor) -> tensor { %scale = arith.constant 1.0 : f32 - %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%init : tensor) -> tensor + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%init : tensor) -> tensor return %1 : tensor } +// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)> +// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK: func.func @cross_attention_transposev_dyn(%[[ARG0:[a-zA-Z0-9_]+]]: tensor, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: tensor, %[[ARG2:[a-zA-Z0-9_]+]]: tensor, %[[ARG3:[a-zA-Z0-9_]+]]: tensor) -> tensor // CHECK-SAME: { // CHECK: 
%[[SCALE:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[D1:.+]] = iree_linalg_ext.attention {transpose_v = true} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : +// CHECK: %[[D1:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : // CHECK-SAME: tensor, tensor, tensor, f32) outs(%[[ARG3]] : // CHECK-SAME: tensor) -> tensor // CHECK: return %[[D1]] : tensor diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index 2cd851dc58a0..daf50dcd2e30 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -331,7 +331,7 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp, auto [result, newMax, newSum] = createAttentionBody( keySlice, valueSlice, querySlice, tiledResult, max, sum, sequenceTileLength, keyValueTileLength, headDimension, elementType, ops, - tiledAttnOp.getTransposeV(), loc, rewriter); + tiledAttnOp.isTransposeV(), loc, rewriter); rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum}); } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index df0d87e13fb2..384ff558d8b7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -145,6 +145,46 @@ static Value insertOutputSlice(Value src, Value dst, headDimension, loc, builder); } +static SmallVector +getTileAttentionIndexingMaps(RewriterBase &rewriter, int64_t tiledInputRank, + bool transposeV) { + MLIRContext *ctx = rewriter.getContext(); + AffineExpr m, k1, k2, n; + bindDims(ctx, m, k1, k2, n); + + AffineMap qMap = + AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx); + AffineMap kMap = + AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx); + AffineMap vMap = + AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx); + AffineMap rMap = + AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx); + AffineMap maxMap = + AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m}, ctx); + AffineMap sumMap = + AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m}, ctx); + + if (transposeV) { + SmallVector vDims(vMap.getResults()); + std::swap(vDims[0], vDims[1]); + vMap = AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), vDims, ctx); + } + + SmallVector attentionMaps = {qMap, kMap, vMap, + rMap, maxMap, sumMap}; + // Add batches to standard attention indexing maps. 
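+  // Editorial sketch (not part of the original patch): the tiled op works
+  // on Q/K/V slices with the batch dims peeled off, e.g. tensor<1024x64xf16>
+  // in the tests below, so tiledInputRank = 2 and numBatches = 0, and the
+  // maps above keep their unbatched four-dim form (qMap stays
+  // (m, k1, k2, n) -> (m, k1), and maxMap/sumMap stay (m, k1, k2, n) -> (m)).
+  // A rank-3 slice would instead get one leading batch dim prepended to
+  // every map.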
+ int64_t numBatches = tiledInputRank - 2; + for (AffineMap &map : attentionMaps) { + map = map.shiftDims(numBatches); + for (int batch : llvm::seq(numBatches)) { + map = map.insertResult(rewriter.getAffineDimExpr(batch), batch); + } + } + + return attentionMaps; +} + struct TileAttentionPass : public TileAttentionBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert< @@ -257,20 +297,22 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp, headDimension, elementType, loc, rewriter); Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension, - elementType, loc, rewriter, attnOp.getTransposeV()); + elementType, loc, rewriter, attnOp.isTransposeV()); Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength, headDimension, elementType, loc, rewriter); Value scale = attnOp.getScale(); + int64_t tiledInputRank = cast(querySlice.getType()).getRank(); + SmallVector tiledIndexingMaps = getTileAttentionIndexingMaps( + rewriter, tiledInputRank, attnOp.isTransposeV()); + auto tiledAttentionOp = rewriter.create( attnOp.getLoc(), SmallVector{accumulatorF32.getType(), sum.getType(), max.getType()}, - SmallVector{querySlice, keySlice, valueSlice, scale}, - SmallVector{iterArgResult, iterArgMax, iterArgSum}); - - if (attnOp.getTransposeV()) - tiledAttentionOp.setTransposeVAttr(attnOp.getTransposeVAttr()); + querySlice, keySlice, valueSlice, scale, + SmallVector{iterArgResult, iterArgMax, iterArgSum}, + rewriter.getAffineMapArrayAttr(tiledIndexingMaps)); Value tiledResult = tiledAttentionOp.getResult(0); Value newMax = tiledAttentionOp.getResult(1); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir index ce7e384b730b..0344eb731491 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_attention.mlir @@ -3,7 +3,11 @@ func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> { %0 = tensor.empty() : tensor<1x1024x64xf32> %scale = arith.constant 0.05 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> return %1 : tensor<1x1024x64xf32> } @@ -101,7 +105,11 @@ func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, func.func @attention(%query: tensor, %key: tensor, %value: tensor, %dim0: index, %dim1: index, %dim2: index) -> tensor { %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor %scale = arith.constant 0.05 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%0 : tensor) -> tensor + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, 
d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%0 : tensor) -> tensor return %1 : tensor } @@ -202,7 +210,11 @@ func.func @attention(%query: tensor, %key: tensor, %value: func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> { %0 = tensor.empty() : tensor<1x1024x64xf16> %scale = arith.constant 0.05 : f16 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> } @@ -313,10 +325,16 @@ func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf // ----- +// transpose_V is detected through indexingMap. + func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> { %0 = tensor.empty() : tensor<1x1024x64xf16> %scale = arith.constant 0.05 : f16 - %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir index 51f1b6effb2a..be9c9da45760 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tile_attention.mlir @@ -5,7 +5,11 @@ func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, %value: tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> { %0 = tensor.empty() : tensor<1x1024x64xf32> %scale = arith.constant 0.05 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, tensor<1x1024x64xf32>, f32) outs(%0 : tensor<1x1024x64xf32>) -> tensor<1x1024x64xf32> return %1 : 
tensor<1x1024x64xf32> } @@ -54,7 +58,11 @@ func.func @attention(%query: tensor<1x1024x64xf32>, %key: tensor<1x1024x64xf32>, func.func @attention(%query: tensor, %key: tensor, %value: tensor, %dim0: index, %dim1: index, %dim2: index) -> tensor { %0 = tensor.empty(%dim0, %dim1, %dim2) : tensor %scale = arith.constant 0.05 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%0 : tensor) -> tensor + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor, tensor, tensor, f32) outs(%0 : tensor) -> tensor return %1 : tensor } @@ -83,7 +91,8 @@ func.func @attention(%query: tensor, %key: tensor, %value: // CHECK: %[[K_S:.+]] = tensor.extract_slice %[[KEY]][0, %[[ARG6]], 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1] // CHECK: %[[V_S:.+]] = tensor.extract_slice %[[VALUE]][0, %[[ARG6]], 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1] // CHECK: %[[Q_S:.+]] = tensor.extract_slice %[[QUERY]][0, 0, 0] [1, %[[DIM]], %[[DIM_0]]] [1, 1, 1] -// CHECK: %[[ATT:.+]]:3 = iree_linalg_ext.attention ins(%[[Q_S]], %[[K_S]], %[[V_S]], %{{[a-z0-1]+}} +// CHECK: %[[ATT:.+]]:3 = iree_linalg_ext.attention +// CHECK-SAME: ins(%[[Q_S]], %[[K_S]], %[[V_S]], %{{[a-z0-1]+}} // CHECK-SAME: outs(%[[ARG7]], %[[ARG8]], %[[ARG9]] // CHECK: scf.yield %[[ATT]]#0, %[[ATT]]#1, %[[ATT]]#2 // CHECK: } @@ -105,14 +114,19 @@ func.func @attention(%query: tensor, %key: tensor, %value: func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> { %0 = tensor.empty() : tensor<1x1024x64xf16> %scale = arith.constant 0.05 : f16 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> } // CHECK-LABEL: @attention_f16 // CHECK: scf.for -// CHECK: iree_linalg_ext.attention ins(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<1024x64xf16>, tensor<1024x64xf16>, tensor<1024x64xf16>, f16 +// CHECK: iree_linalg_ext.attention +// CHECK-SAME: ins(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<1024x64xf16>, tensor<1024x64xf16>, tensor<1024x64xf16>, f16 // CHECK-SAME: outs(%{{.*}}, %{{.*}}, %{{.*}} : // CHECK-SAME: -> tensor<1024x64xf32>, tensor<1024xf32>, tensor<1024xf32> // CHECK: scf.yield @@ -129,13 +143,19 @@ func.func @attention_f16(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf // ----- +// transpose_V is detected through indexingMap. 
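+// Editorial note (not part of the original test): %value below has type
+// tensor<1x64x1024xf16>, a (batch, N, K2) layout, with indexing map
+// affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>. The tiling pass queries
+// attnOp.isTransposeV() and swaps k2 and n in the tiled V map, instead of
+// forwarding a transpose_v attribute as before.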
+ func.func @attention_transpose_v(%query: tensor<1x1024x64xf16>, %key: tensor<1x1024x64xf16>, %value: tensor<1x64x1024xf16>) -> tensor<1x1024x64xf16> { %0 = tensor.empty() : tensor<1x1024x64xf16> %scale = arith.constant 0.05 : f16 - %1 = iree_linalg_ext.attention {transpose_v = true} ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x1024x64xf16>, tensor<1x1024x64xf16>, tensor<1x64x1024xf16>, f16) outs(%0 : tensor<1x1024x64xf16>) -> tensor<1x1024x64xf16> return %1 : tensor<1x1024x64xf16> } // CHECK-LABEL: func.func @attention_transpose_v // CHECK: scf.for -// CHECK: iree_linalg_ext.attention {transpose_v = true} +// CHECK: iree_linalg_ext.attention // CHECK: scf.yield diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index 0e23d94efacb..42e619326edd 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -1623,7 +1623,11 @@ module attributes { transform.with_named_sequence } { func.func @attention(%query: tensor<192x1024x64xf32>, %key: tensor<192x1024x64xf32>, %value: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> { %0 = tensor.empty() : tensor<192x1024x64xf32> %scale = arith.constant 1.0 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%0 : tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> return %1 : tensor<192x1024x64xf32> } module attributes { transform.with_named_sequence } { @@ -1635,6 +1639,11 @@ module attributes { transform.with_named_sequence } { } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)> +// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK: func.func @attention(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: tensor<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> // CHECK-SAME: { @@ -1659,7 +1668,9 @@ module attributes { transform.with_named_sequence } { // CHECK-SAME: 1, 1] : tensor<192x1024x64xf32> to tensor // CHECK: %[[EXTRACTED_SLICE_2:.+]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], %[[ARG5]], 0] [%[[D2]], // CHECK-SAME: %[[D4]], 64] [1, 1, 1] : 
tensor<192x1024x64xf32> to tensor -// CHECK: %[[D5:.+]] = iree_linalg_ext.attention ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]], +// CHECK: %[[D5:.+]] = iree_linalg_ext.attention +// CHECK-SAME: {indexing_maps = [#[[MAP_Q]], #[[MAP_K]], #[[MAP_V]], #[[MAP_O]]]} +// CHECK-SAME: ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]], // CHECK-SAME: %[[EXTRACTED_SLICE_1]], %[[C1_F32]] : tensor, tensor, tensor, f32) // CHECK-SAME: outs(%[[EXTRACTED_SLICE_2]] : tensor) -> tensor // CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D5]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0] @@ -1675,7 +1686,11 @@ module attributes { transform.with_named_sequence } { func.func @attention_memref(%query: memref<192x1024x64xf32>, %key: memref<192x1024x64xf32>, %value: memref<192x1024x64xf32>, %output: memref<192x1024x64xf32>) { %scale = arith.constant 1.0 : f32 - iree_linalg_ext.attention ins(%query, %key, %value, %scale : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>, f32) outs(%output : memref<192x1024x64xf32>) + iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : memref<192x1024x64xf32>, memref<192x1024x64xf32>, memref<192x1024x64xf32>, f32) outs(%output : memref<192x1024x64xf32>) return } module attributes { transform.with_named_sequence } { @@ -1687,6 +1702,11 @@ module attributes { transform.with_named_sequence } { } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (-d0 + 192, 10)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (-d0 + 1024, 30)> +// CHECK-DAG: #[[MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)> +// CHECK-DAG: #[[MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK-DAG: #[[MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)> + // CHECK: func.func @attention_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG1:[a-zA-Z0-9_]+]]: // CHECK-SAME: memref<192x1024x64xf32>, %[[ARG2:[a-zA-Z0-9_]+]]: memref<192x1024x64xf32>, %[[ARG3:[a-zA-Z0-9_]+]]: // CHECK-SAME: memref<192x1024x64xf32>) { @@ -1708,7 +1728,9 @@ module attributes { transform.with_named_sequence } { // CHECK-SAME: memref<192x1024x64xf32> to memref> // CHECK: %[[SUBVIEW_2:.+]] = memref.subview %[[ARG3]][%[[ARG4]], %[[ARG5]], 0] [%[[D0]], %[[D1]], 64] [1, 1, // CHECK-SAME: 1] : memref<192x1024x64xf32> to memref> -// CHECK: iree_linalg_ext.attention ins(%[[SUBVIEW]], %[[SUBVIEW_0]], %[[SUBVIEW_1]], %[[C1_F32]] : memref>, memref>, // CHECK-SAME: memref>, f32) outs(%[[SUBVIEW_2]] : // CHECK-SAME: memref>) diff --git a/tests/e2e/linalg_ext_ops/attention.mlir b/tests/e2e/linalg_ext_ops/attention.mlir index f5cea4de59d2..c418809eedd7 100644 --- a/tests/e2e/linalg_ext_ops/attention.mlir +++ b/tests/e2e/linalg_ext_ops/attention.mlir @@ -11,7 +11,11 @@ func.func @attention1x3x4() { [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]> : tensor<1x3x4xf32> %scale = arith.constant 0.5 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x3x4xf32>, + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : 
tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%init : tensor<1x3x4xf32>) -> tensor<1x3x4xf32> check.expect_almost_eq_const( %1, @@ -37,7 +41,11 @@ func.func @attention1x4x4() { [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]]]> : tensor<1x4x4xf32> %scale = arith.constant 0.5 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<1x4x4xf32>, + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<1x4x4xf32>, f32) outs(%init : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> check.expect_almost_eq_const( %1, @@ -79,7 +87,11 @@ func.func @attention3x3x4() { [-0.5962, -1.0055, 0.4285, 1.4761], [ 1.6103, -0.7040, -0.1853, -0.9962]]]> : tensor<3x3x4xf32> %scale = arith.constant 0.5 : f32 - %1 = iree_linalg_ext.attention ins(%query, %key, %value, %scale : tensor<3x3x4xf32>, + %1 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} + ins(%query, %key, %value, %scale : tensor<3x3x4xf32>, tensor<3x3x4xf32>, tensor<3x3x4xf32>, f32) outs(%init : tensor<3x3x4xf32>) -> tensor<3x3x4xf32> check.expect_almost_eq_const( %1,