diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td index b7760ea7f2b0..993ab3bbb0d6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td @@ -542,7 +542,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", statistics while processing the current tile. With indexing_maps we can also infer when variables are transposed. - for example, if V has indexing_map of (m, k1, k2, n) -> (n, k2) insted + for example, if V has indexing_map of (m, k1, k2, n) -> (n, k2) instead of (m, k1, k2, n) -> (k2, n), then we'd know that V is transposed (V.T), and we can expect computation to look like: @@ -669,8 +669,9 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention", return sumType->getRank(); } - // Use indexing_map to infer if V is transposed. - bool getTransposeV() { + // Since we expect map_V to be (b, m, k1, k2, n) -> (b, k2, n), this function + // returns true if n is not the fastest dim. + bool isTransposeV() { AffineMap vMap = getValueMap(); auto vFastDim = llvm::cast(vMap.getResults().back()); int64_t nDim = vMap.getNumInputs() - 1; diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp index 2cd851dc58a0..daf50dcd2e30 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/DecomposeAttention.cpp @@ -331,7 +331,7 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp, auto [result, newMax, newSum] = createAttentionBody( keySlice, valueSlice, querySlice, tiledResult, max, sum, sequenceTileLength, keyValueTileLength, headDimension, elementType, ops, - tiledAttnOp.getTransposeV(), loc, rewriter); + tiledAttnOp.isTransposeV(), loc, rewriter); rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum}); } diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp index 9c9d2ac31963..384ff558d8b7 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TileAttention.cpp @@ -297,7 +297,7 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp, headDimension, elementType, loc, rewriter); Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension, - elementType, loc, rewriter, attnOp.getTransposeV()); + elementType, loc, rewriter, attnOp.isTransposeV()); Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength, headDimension, elementType, loc, rewriter); @@ -305,7 +305,7 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp, int64_t tiledInputRank = cast(querySlice.getType()).getRank(); SmallVector tiledIndexingMaps = getTileAttentionIndexingMaps( - rewriter, tiledInputRank, attnOp.getTransposeV()); + rewriter, tiledInputRank, attnOp.isTransposeV()); auto tiledAttentionOp = rewriter.create( attnOp.getLoc(),