[LLVMGPU] Generalize AMDGPUChainedMatmul pass to multiple dimensions (iree-org#17684)

This patch generalizes the AMDGPUChainedMatmul pass to use
VectorContractOpInfo to query and transpose dims, instead of hardcoding
indexing maps.
Groverkss committed Jun 18, 2024
1 parent 3835c8b commit 1f954b2
Showing 2 changed files with 119 additions and 35 deletions.
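
The core of the change is that the pass now asks VectorContractOpInfo where the M, N, and K dimensions sit inside each operand, instead of assuming the fixed 2-D maps of a plain matmul. The Python sketch below illustrates the idea only; the function and field names are hypothetical and are not the actual VectorContractOpInfo API.

# Hypothetical illustration: given each operand's indexing map as a list of
# iteration-space dim ids, locate the innermost M, N and K dims per operand.
def operand_dim_positions(lhs_map, rhs_map, acc_map):
    # K dims appear in lhs and rhs but not in the result.
    k_dims = [d for d in lhs_map if d in rhs_map and d not in acc_map]
    # M dims appear in lhs and the result; N dims appear in rhs and the result.
    m_dims = [d for d in lhs_map if d in acc_map]
    n_dims = [d for d in rhs_map if d in acc_map]
    m, n, k = m_dims[-1], n_dims[-1], k_dims[-1]  # innermost of each group
    return {
        "lhsM": lhs_map.index(m), "lhsK": lhs_map.index(k),
        "rhsN": rhs_map.index(n), "rhsK": rhs_map.index(k),
        "accM": acc_map.index(m), "accN": acc_map.index(n),
    }

# matmul_transpose_b with dims (m, n, k) = (0, 1, 2):
# lhs = (m, k), rhs = (n, k), acc = (m, n).
print(operand_dim_positions([0, 2], [1, 2], [0, 1]))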
@@ -4,13 +4,17 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <numeric>
#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir::iree_compiler {

using VectorValue = TypedValue<VectorType>;

namespace {

/// Let's assume that we only have vector.contract with the standard indexing
@@ -62,6 +66,22 @@ struct AMDGPUPrepareForChainedMatmulPass
registry.insert<vector::VectorDialect>();
}

VectorValue swapDims(RewriterBase &rewriter, VectorValue val, int64_t dimA,
int64_t dimB) const {
ArrayRef<int64_t> shape = val.getType().getShape();
SmallVector<int64_t> perm(shape.size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[dimA], perm[dimB]);
return rewriter.create<vector::TransposeOp>(val.getLoc(), val, perm);
}

AffineMap swapDimsInMap(AffineMap map, int64_t dimA, int64_t dimB) const {
SmallVector<AffineExpr> results(map.getResults());
std::swap(results[dimA], results[dimB]);
return AffineMap::get(map.getNumDims(), map.getNumSymbols(), results,
map.getContext());
}
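
The two new helpers build an identity permutation, swap the two requested positions, and apply the permutation either to the vector (via vector.transpose) or to the results of an indexing map. A minimal NumPy sketch of the same permutation logic, for illustration only:

import numpy as np

def swap_dims(x, dim_a, dim_b):
    # Identity permutation with dim_a and dim_b exchanged, then transpose.
    perm = list(range(x.ndim))
    perm[dim_a], perm[dim_b] = perm[dim_b], perm[dim_a]
    return np.transpose(x, perm)

def swap_dims_in_map(results, dim_a, dim_b):
    # Swap two result expressions of an indexing map, modelled here as a list.
    results = list(results)
    results[dim_a], results[dim_b] = results[dim_b], results[dim_a]
    return results

x = np.arange(24).reshape(2, 3, 4)
assert swap_dims(x, 1, 2).shape == (2, 4, 3)
assert swap_dims_in_map(["b", "m", "k"], 1, 2) == ["b", "k", "m"]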

/// Given a vector contract of the form
/// %output = vector.contract %lhs, %rhs, %acc
/// this function swaps the operands (%rhs, %lhs),
@@ -86,29 +106,45 @@ struct AMDGPUPrepareForChainedMatmulPass
/// simply swap the operands without transposing them.
void swapOperandsAndTranspose(RewriterBase &rewriter,
vector::ContractionOp contractOp) const {
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
Value acc = contractOp.getAcc();
VectorContractOpInfo opInfo(contractOp);
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
VectorValue lhs = contractOp.getLhs();
VectorValue rhs = contractOp.getRhs();
VectorValue acc = cast<VectorValue>(contractOp.getAcc());
rewriter.setInsertionPoint(contractOp);
acc = rewriter.create<vector::TransposeOp>(contractOp.getLoc(), acc,
SmallVector<int64_t>{1, 0});

SmallVector<AffineMap> maps = contractOp.getIndexingMapsArray();
AffineMap lhsMap = maps[0];
AffineMap rhsMap = maps[1];
AffineMap accMap = maps[2];

acc = swapDims(rewriter, acc, accN, accM);
accMap = swapDimsInMap(accMap, accN, accM);

if (!isOperandSwapInvariant(contractOp)) {
lhs = rewriter.create<vector::TransposeOp>(contractOp.getLoc(), lhs,
SmallVector<int64_t>{1, 0});
rhs = rewriter.create<vector::TransposeOp>(contractOp.getLoc(), rhs,
SmallVector<int64_t>{1, 0});
lhs = swapDims(rewriter, lhs, lhsK, lhsM);
rhs = swapDims(rewriter, rhs, rhsK, rhsN);
lhsMap = swapDimsInMap(lhsMap, lhsK, lhsM);
rhsMap = swapDimsInMap(rhsMap, rhsK, rhsN);
}

vector::ContractionOp swappedOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), rhs, lhs, acc, contractOp.getIndexingMaps(),
auto swappedOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), rhs, lhs, acc,
rewriter.getAffineMapArrayAttr({rhsMap, lhsMap, accMap}),
contractOp.getIteratorTypesAttr());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
contractOp, swappedOp.getResult(), SmallVector<int64_t>{1, 0});

acc = cast<VectorValue>(swappedOp.getResult());
acc = swapDims(rewriter, acc, accN, accM);

rewriter.replaceOp(contractOp, acc);
}
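
A NumPy sketch of the full rewrite for the non-swap-invariant case (a plain (m, k) x (k, n) matmul), checking numerically that transposing the accumulator and both operands, swapping the operands, and transposing the result back reproduces the original contraction. Shapes are arbitrary and chosen only for illustration.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))    # lhs, indexed (m, k)
B = rng.standard_normal((8, 5))    # rhs, indexed (k, n)
acc = rng.standard_normal((4, 5))  # acc, indexed (m, n)

# Original contraction: C[m, n] = acc[m, n] + sum_k A[m, k] * B[k, n].
C = np.einsum("mk,kn->mn", A, B) + acc

# Rewritten form: transpose acc (swap M and N), transpose lhs (swap K and M)
# and rhs (swap K and N), swap the operands, then transpose the result back.
acc_t = acc.T                      # (n, m)
A_t = A.T                          # (k, m)
B_t = B.T                          # (n, k)
C_swapped = np.einsum("nk,km->nm", B_t, A_t) + acc_t
assert np.allclose(C_swapped.T, C)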

/// For a matmul_transpose_b, this transformation boils down to an operand
/// swap and result transpose:
/// If one of the operands is transposed, while the other isn't, the
/// transformation boils down to an operand swap and result transpose. This
/// happens because transposing and swapping both operands preserves the
/// structure of the contraction. For example:
///
/// def matmul_transpose_b(A, B):
/// B.T = transpose(B)
@@ -124,7 +160,7 @@ struct AMDGPUPrepareForChainedMatmulPass
/// matmul_transpose_b(B, A) = matmul_transpose_b_swapped(B, A).T
///
/// For the sake of completeness, we also show that this does not hold
/// for normal matmul:
/// when no operands are transposed, or both operands are transposed:
///
/// def matmul(A, B):
/// C = A @ B
@@ -135,18 +171,15 @@ struct AMDGPUPrepareForChainedMatmulPass
/// B.T = transpose(B)
/// C.T = B.T @ A.T
/// C = transpose(C.T)
///
/// TODO: This check applies more generally when one of the operands in the
/// function is transposed compared to what "@" expects.
bool isOperandSwapInvariant(vector::ContractionOp contractOp) const {
AffineExpr m, n, k;
bindDims(contractOp.getContext(), m, n, k);
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [&](MapList m) {
return AffineMap::inferFromExprList(m, contractOp.getContext());
};
SmallVector<AffineMap> newIndexingMaps = infer({{m, k}, {n, k}, {m, n}});
return newIndexingMaps == contractOp.getIndexingMapsArray();
// The swap is invariant when exactly one operand has its innermost m/n and k
// dims transposed relative to the canonical (m, k) x (k, n) layout,
// e.g. lhs: (m, k), rhs: (n, k).
VectorContractOpInfo opInfo(contractOp);
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
bool isLhsTransposed = lhsM > lhsK;
bool isRhsTransposed = rhsN < rhsK;
return isLhsTransposed != isRhsTransposed;
}
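
A small NumPy check of the property this function tests for: when exactly one operand is transposed (matmul_transpose_b), swapping the operands alone already yields the transposed result, so no operand transposes are needed; a plain matmul needs both operands transposed as well.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8))   # (m, k)
B = rng.standard_normal((5, 8))   # (n, k), the "transpose_b" operand

def matmul_transpose_b(X, Y):     # X: (m, k), Y: (n, k) -> (m, n)
    return X @ Y.T

# Swap invariance: matmul_transpose_b(B, A) is exactly the transpose of
# matmul_transpose_b(A, B).
assert np.allclose(matmul_transpose_b(B, A), matmul_transpose_b(A, B).T)

# Not so for a plain matmul: swapping A @ Bp requires transposing both
# operands to recover the (transposed) result.
Bp = rng.standard_normal((8, 5))  # (k, n)
assert np.allclose(Bp.T @ A.T, (A @ Bp).T)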

/// Returns a vector.contract operation that this value was transitively
@@ -162,7 +195,9 @@ struct AMDGPUPrepareForChainedMatmulPass
FailureOr<vector::ContractionOp>
getTransitiveMatmulParent(vector::ContractionOp contractOp) const {
SetVector<Operation *> backwardSlice;
getBackwardSlice(contractOp.getLhs(), &backwardSlice);
BackwardSliceOptions options;
options.inclusive = true;
getBackwardSlice(contractOp.getLhs(), &backwardSlice, options);
vector::ContractionOp result;
for (Operation *sliceOp : backwardSlice) {
auto chainParent = dyn_cast<vector::ContractionOp>(sliceOp);
@@ -173,6 +208,9 @@ struct AMDGPUPrepareForChainedMatmulPass
// For now, we only support transpose invariant matmuls. This is because
// transposing the inputs may have a non-trivial cost which we need
// to think about.
// TODO: We should probably enable it always. Currently, this is
// only useful in Flash Attention, where the first matmul is generally
// a transpose.
if (!isOperandSwapInvariant(chainParent)) {
continue;
}
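
getTransitiveMatmulParent above walks the backward slice of the contraction's LHS looking for another vector.contract that (transitively) produced it. A generic, hypothetical sketch of that kind of backward walk over a def-use graph; it does not use the MLIR API, and the data layout is invented for illustration.

from collections import deque

def transitive_producer(value, producers, predicate):
    # producers maps a value to (producing_op, operand_values).
    seen = set()
    work = deque([value])
    while work:
        v = work.popleft()
        if v in seen or v not in producers:
            continue
        seen.add(v)
        op, operands = producers[v]
        if predicate(op):
            return op
        work.extend(operands)
    return None

# Toy chain: %c = contract(%a, %b); %t = transpose(%c); %r = contract(%t, %d).
producers = {
    "%c": ("contract0", ["%a", "%b"]),
    "%t": ("transpose0", ["%c"]),
    "%r": ("contract1", ["%t", "%d"]),
}
assert transitive_producer("%t", producers,
                           lambda op: op.startswith("contract")) == "contract0"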
@@ -12,9 +12,9 @@
}

builtin.module {
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
func.func @chained_matmul(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>,
// CHECK: func.func @chained_matmul(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
// CHECK-SAME: %[[RHS2:.*]]: vector<8x16xf16>, %[[ACC2:.*]]: vector<32x8xf16>
@@ -115,10 +115,10 @@ builtin.module {

builtin.module {
func.func @chained_matmul_mmt_mm(%lhs : vector<32x8xf16>, %rhs : vector<16x8xf16>, %acc : vector<32x16xf16>,
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: func.func @chained_matmul_mmt_mm(%[[LHS:.*]]: vector<32x8xf16>, %[[RHS:.*]]: vector<16x8xf16>, %[[ACC:.*]]: vector<32x16xf16>
// CHECK-SAME: %[[RHS2:.*]]: vector<16x8xf16>, %[[ACC2:.*]]: vector<32x8xf16>
%rhs2 : vector<16x8xf16>, %acc2 : vector<32x8xf16>) -> vector<32x8xf16> {
@@ -141,3 +141,49 @@ builtin.module {
func.return %result2 : vector<32x8xf16>
}
}

// -----

#accesses0 = [
affine_map<(b, m1, m2, n, k) -> (b, m2, m1, k)>,
affine_map<(b, m1, m2, n, k) -> (b, n, k)>,
affine_map<(b, m1, m2, n, k) -> (b, m2, m1, n)>
]

#trait0 = {
indexing_maps = #accesses0,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
}

builtin.module {
// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d1, d4)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d2)>
func.func @chained_matmul(%lhs : vector<17x64x32x8xf16>,
%rhs : vector<17x16x8xf16>,
%acc : vector<17x64x32x16xf16>,
%rhs2 : vector<17x8x16xf16>,
%acc2 : vector<17x64x32x8xf16>) -> vector<17x64x32x8xf16> {

// CHECK: vector.transpose
// CHECK-NOT: vector.transpose
// CHECK: vector.contract
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
%result = vector.contract #trait0 %lhs, %rhs, %acc
: vector<17x64x32x8xf16>, vector<17x16x8xf16> into vector<17x64x32x16xf16>

// The transpose of the result will fold with the transpose of the acc of
// the next contract.

// CHECK: vector.transpose
// CHECK: vector.transpose
// CHECK-NOT: vector.transpose
// CHECK: vector.contract
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
%result2 = vector.contract #trait0 %result, %rhs2, %acc2
: vector<17x64x32x16xf16>, vector<17x8x16xf16> into vector<17x64x32x8xf16>
// CHECK: vector.transpose

func.return %result2 : vector<17x64x32x8xf16>
}
}
