diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp
index bdb16e838e91..d22bd206b053 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUChainedMatmulPass.cpp
@@ -4,13 +4,17 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include <numeric>
 #include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 namespace mlir::iree_compiler {
 
+using VectorValue = TypedValue<VectorType>;
+
 namespace {
 
 /// Let's assume that we only have vector.contract with the standard indexing
@@ -62,6 +66,22 @@ struct AMDGPUPrepareForChainedMatmulPass
     registry.insert<vector::VectorDialect>();
   }
 
+  VectorValue swapDims(RewriterBase &rewriter, VectorValue val, int64_t dimA,
+                       int64_t dimB) const {
+    ArrayRef<int64_t> shape = val.getType().getShape();
+    SmallVector<int64_t> perm(shape.size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[dimA], perm[dimB]);
+    return rewriter.create<vector::TransposeOp>(val.getLoc(), val, perm);
+  }
+
+  AffineMap swapDimsInMap(AffineMap map, int64_t dimA, int64_t dimB) const {
+    SmallVector<AffineExpr> results(map.getResults());
+    std::swap(results[dimA], results[dimB]);
+    return AffineMap::get(map.getNumDims(), map.getNumSymbols(), results,
+                          map.getContext());
+  }
+
   /// Given a vector contract of the form
   /// %output = vector.contract %lhs, %rhs, %acc
   /// this function swaps the operands (%rhs, %lhs),
@@ -86,29 +106,45 @@ struct AMDGPUPrepareForChainedMatmulPass
   /// simply swap the operands without transposing them.
   void swapOperandsAndTranspose(RewriterBase &rewriter,
                                 vector::ContractionOp contractOp) const {
-    Value lhs = contractOp.getLhs();
-    Value rhs = contractOp.getRhs();
-    Value acc = contractOp.getAcc();
+    VectorContractOpInfo opInfo(contractOp);
+    auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
+    auto [lhsK, rhsK] = opInfo.getOperandKIndex();
+    auto [accM, accN] = opInfo.getResultMNIndex();
+    VectorValue lhs = contractOp.getLhs();
+    VectorValue rhs = contractOp.getRhs();
+    VectorValue acc = cast<VectorValue>(contractOp.getAcc());
     rewriter.setInsertionPoint(contractOp);
-    acc = rewriter.create<vector::TransposeOp>(contractOp.getLoc(), acc,
-                                               SmallVector<int64_t>{1, 0});
+
+    SmallVector<AffineMap> maps = contractOp.getIndexingMapsArray();
+    AffineMap lhsMap = maps[0];
+    AffineMap rhsMap = maps[1];
+    AffineMap accMap = maps[2];
+
+    acc = swapDims(rewriter, acc, accN, accM);
+    accMap = swapDimsInMap(accMap, accN, accM);
 
     if (!isOperandSwapInvariant(contractOp)) {
-      lhs = rewriter.create<vector::TransposeOp>(contractOp.getLoc(), lhs,
-                                                 SmallVector<int64_t>{1, 0});
-      rhs = rewriter.create<vector::TransposeOp>(contractOp.getLoc(), rhs,
-                                                 SmallVector<int64_t>{1, 0});
+      lhs = swapDims(rewriter, lhs, lhsK, lhsM);
+      rhs = swapDims(rewriter, rhs, rhsK, rhsN);
+      lhsMap = swapDimsInMap(lhsMap, lhsK, lhsM);
+      rhsMap = swapDimsInMap(rhsMap, rhsK, rhsN);
     }
 
-    vector::ContractionOp swappedOp = rewriter.create<vector::ContractionOp>(
-        contractOp.getLoc(), rhs, lhs, acc, contractOp.getIndexingMaps(),
+    auto swappedOp = rewriter.create<vector::ContractionOp>(
+        contractOp.getLoc(), rhs, lhs, acc,
+        rewriter.getAffineMapArrayAttr({rhsMap, lhsMap, accMap}),
         contractOp.getIteratorTypesAttr());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
-        contractOp, swappedOp.getResult(), SmallVector<int64_t>{1, 0});
+
+    acc = cast<VectorValue>(swappedOp.getResult());
+    acc = swapDims(rewriter, acc, accN, accM);
+
+    rewriter.replaceOp(contractOp, acc);
   }
 
-  /// For a matmul_transpose_b, this transformation boils down to an operand
-  /// swap and result transpose:
+  /// If one of the operands is transposed, while the other isn't, the
+  /// transformation boils down to an operand swap and result transpose. This
+  /// happens because transposing and swapping both operands preserves the
+  /// structure of the contraction. For example:
   ///
   /// def matmul_transpose_b(A, B):
   ///   B.T = transpose(B)
@@ -124,7 +160,7 @@ struct AMDGPUPrepareForChainedMatmulPass
   /// matmul_transpose_b(B, A) = matmul_transpose_b_swapped(B, A).T
   ///
   /// For the sake of completeness, we also show that this does not hold
-  /// for normal matmul:
+  /// when no operands are transposed, or both operands are transposed:
   ///
   /// def matmul(A, B):
   ///   C = A @ B
@@ -135,18 +171,15 @@ struct AMDGPUPrepareForChainedMatmulPass
   ///   B.T = transpose(B)
   ///   C.T = B.T @ A.T
   ///   C = transpose(C.T)
-  ///
-  /// TODO: This check applies more generally when one of the operands in the
-  /// function is transposed compared to what "@" expects.
   bool isOperandSwapInvariant(vector::ContractionOp contractOp) const {
-    AffineExpr m, n, k;
-    bindDims(contractOp.getContext(), m, n, k);
-    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    auto infer = [&](MapList m) {
-      return AffineMap::inferFromExprList(m, contractOp.getContext());
-    };
-    SmallVector<AffineMap> newIndexingMaps = infer({{m, k}, {n, k}, {m, n}});
-    return newIndexingMaps == contractOp.getIndexingMapsArray();
+    // Check if the innermost m, n, k dimensions are in the order:
+    // lhs: (m, k), rhs: (n, k)
+    VectorContractOpInfo opInfo(contractOp);
+    auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
+    auto [lhsK, rhsK] = opInfo.getOperandKIndex();
+    bool isLhsTransposed = lhsM > lhsK;
+    bool isRhsTransposed = rhsN < rhsK;
+    return isLhsTransposed != isRhsTransposed;
   }
 
   /// Returns a vector.contract operation that this value was transitively
@@ -162,7 +195,9 @@ struct AMDGPUPrepareForChainedMatmulPass
   FailureOr<vector::ContractionOp>
   getTransitiveMatmulParent(vector::ContractionOp contractOp) const {
     SetVector<Operation *> backwardSlice;
-    getBackwardSlice(contractOp.getLhs(), &backwardSlice);
+    BackwardSliceOptions options;
+    options.inclusive = true;
+    getBackwardSlice(contractOp.getLhs(), &backwardSlice, options);
     vector::ContractionOp result;
     for (Operation *sliceOp : backwardSlice) {
       auto chainParent = dyn_cast<vector::ContractionOp>(sliceOp);
@@ -173,6 +208,9 @@ struct AMDGPUPrepareForChainedMatmulPass
       // For now, we only support transpose invariant matmuls. This is because
       // transposing the inputs may have a non-trivial cost which we need
       // to think about.
+      // TODO: We should probably enable it always. Currently, this is
+      // only useful in Flash Attention, where the first matmul is generally
+      // a transpose.
       if (!isOperandSwapInvariant(chainParent)) {
         continue;
       }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir
index 292f481c6f03..f1d666579302 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_chained_matmul.mlir
@@ -12,9 +12,9 @@
 }
 
 builtin.module {
-  // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-  // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-  // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+  // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+  // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+  // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
   func.func @chained_matmul(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>,
   // CHECK: func.func @chained_matmul(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
   // CHECK-SAME: %[[RHS2:.*]]: vector<8x16xf16>, %[[ACC2:.*]]: vector<32x8xf16>
@@ -115,10 +115,10 @@ builtin.module {
 
   func.func @chained_matmul_mmt_mm(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>,
-  // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-  // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-  // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-  // CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+  // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+  // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+  // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
+  // CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
   // CHECK: func.func @chained_matmul_mmt_mm(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
   // CHECK-SAME: %[[RHS2:.*]]: vector<16x8xf16>, %[[ACC2:.*]]: vector<32x8xf16>
   %rhs2 : vector<16x8xf16>, %acc2 : vector<32x8xf16>) -> vector<32x8xf16> {
 
@@ -141,3 +141,49 @@ builtin.module {
   func.return %result2 : vector<32x8xf16>
   }
 }
+
+// -----
+
+#accesses0 = [
+  affine_map<(b, m1, m2, n, k) -> (b, m2, m1, k)>,
+  affine_map<(b, m1, m2, n, k) -> (b, n, k)>,
+  affine_map<(b, m1, m2, n, k) -> (b, m2, m1, n)>
+]
+
+#trait0 = {
+  indexing_maps = #accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+}
+
+builtin.module {
+  // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+  // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>
+  // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d2)>
+  func.func @chained_matmul(%lhs : vector<17x64x32x8xf16>,
+                            %rhs : vector<17x16x8xf16>,
+                            %acc : vector<17x64x32x16xf16>,
+                            %rhs2 : vector<17x8x16xf16>,
+                            %acc2 : vector<17x64x32x8xf16>) -> vector<17x64x32x8xf16> {
+
+    // CHECK: vector.transpose
+    // CHECK-NOT: vector.transpose
+    // CHECK: vector.contract
+    // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+    %result = vector.contract #trait0 %lhs, %rhs, %acc
+      : vector<17x64x32x8xf16>, vector<17x16x8xf16> into vector<17x64x32x16xf16>
+
+    // The transpose of this result will fold with the transpose of the acc
+    // of the next contract.
+
+    // CHECK: vector.transpose
+    // CHECK: vector.transpose
+    // CHECK-NOT: vector.transpose
+    // CHECK: vector.contract
+    // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+    %result2 = vector.contract #trait0 %result, %rhs2, %acc2
+      : vector<17x64x32x16xf16>, vector<17x8x16xf16> into vector<17x64x32x8xf16>
+    // CHECK: vector.transpose
+
+    func.return %result2 : vector<17x64x32x8xf16>
+  }
+}
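
Aside: the swap-invariance identity that the pass comment relies on, matmul_transpose_b(A, B).T == matmul_transpose_b(B, A), while a plain matmul additionally needs both inputs transposed, can be sanity-checked numerically. A minimal NumPy sketch (not part of the patch; matmul_transpose_b mirrors the pseudo-code in the pass comment, everything else is illustrative):

    import numpy as np

    def matmul_transpose_b(A, B):
        # Contraction with a transposed RHS, as in the pass comment: C = A @ B.T
        return A @ B.T

    rng = np.random.default_rng(0)
    A = rng.standard_normal((32, 8))
    B = rng.standard_normal((16, 8))

    # Swap-invariant case: swapping the operands only transposes the result,
    # so the pass can swap (lhs, rhs) and repair with a single transpose.
    assert np.allclose(matmul_transpose_b(A, B).T, matmul_transpose_b(B, A))

    # Plain matmul is not swap-invariant: swapping the operands also requires
    # transposing both inputs, since (C @ D).T == D.T @ C.T.
    C = rng.standard_normal((32, 8))
    D = rng.standard_normal((8, 16))
    assert np.allclose((C @ D).T, D.T @ C.T)

This is why isOperandSwapInvariant returns true exactly when one operand is transposed and the other is not: in that case the swap costs only a single transpose of the accumulator, which the pass expects to fold away in the chained-matmul (Flash Attention style) pattern.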