[GPU] Add Implicit GEMM pipeline for LLVMGPU. (#16788)
Co-authored-by: MaheshRavishankar <mahesh@nod-labs.com>
Co-authored-by: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com>
3 people authored and antiagainst committed Mar 19, 2024
1 parent 88787cb commit d306196
Showing 16 changed files with 918 additions and 0 deletions.
@@ -50,6 +50,8 @@ def LLVMGPU_MatmulTensorCoreMmaSync
: I32EnumAttrCase<"LLVMGPUMatmulTensorCoreMmaSync", 109>;
def LLVMGPU_VectorDistribute
: I32EnumAttrCase<"LLVMGPUVectorDistribute", 110>;
def LLVMGPU_ImplicitGEMM
: I32EnumAttrCase<"LLVMGPUImplicitGEMM", 111>;

def SPIRV_BaseLowering
: I32EnumAttrCase<"SPIRVBaseLowering", 200>;
@@ -90,6 +92,7 @@ def DispatchLoweringPassPipelineEnum : I32EnumAttr<
LLVMGPU_Vectorize, LLVMGPU_MatmulSimt, LLVMGPU_MatmulTensorCore,
LLVMGPU_TransposeSharedMem, LLVMGPU_WarpReduction, LLVMGPU_PackUnPack,
LLVMGPU_MatmulTensorCoreMmaSync, LLVMGPU_VectorDistribute,
LLVMGPU_ImplicitGEMM,

// SPIR-V CodeGen pipelines
SPIRV_BaseLowering, SPIRV_BaseDistribute, SPIRV_BaseVectorize,
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -95,14 +95,18 @@ iree_compiler_cc_library(
"KernelConfig.cpp",
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUIm2ColPass.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUNormalizeContractMaps.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPadIGemm.cpp",
"LLVMGPUPrefetching.cpp",
"LLVMGPURewritePadInDestinationPassingStyle.cpp",
"LLVMGPUSelectLoweringStrategy.cpp",
"LLVMGPUTensorCoreVectorization.cpp",
"LLVMGPUTensorPad.cpp",
"LLVMGPUTileAndDistribute.cpp",
"LLVMGPUTileMatmulAndFuseImg2Col.cpp",
"LLVMGPUVectorDistribute.cpp",
"LLVMGPUVectorLowering.cpp",
"LLVMGPUVectorToGPU.cpp",
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -80,14 +80,18 @@ iree_cc_library(
"KernelConfig.cpp"
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUIm2ColPass.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUNormalizeContractMaps.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPadIGemm.cpp"
"LLVMGPUPrefetching.cpp"
"LLVMGPURewritePadInDestinationPassingStyle.cpp"
"LLVMGPUSelectLoweringStrategy.cpp"
"LLVMGPUTensorCoreVectorization.cpp"
"LLVMGPUTensorPad.cpp"
"LLVMGPUTileAndDistribute.cpp"
"LLVMGPUTileMatmulAndFuseImg2Col.cpp"
"LLVMGPUVectorDistribute.cpp"
"LLVMGPUVectorLowering.cpp"
"LLVMGPUVectorToGPU.cpp"
206 changes: 206 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -10,6 +10,8 @@
#include <numeric>
#include <optional>

#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Transforms/TransformMatchers.h"
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
@@ -45,6 +47,11 @@ llvm::cl::opt<bool> clGPUEnableTransformDialectJit(
llvm::cl::desc("enable the usage of the transform dialect JIT"),
llvm::cl::init(true));

llvm::cl::opt<bool> clGPUEnableImplicitGemm(
"iree-codegen-llvmgpu-enable-implicit-gemm",
llvm::cl::desc("activate the convolution implicit gemm strategy"),
llvm::cl::init(false));

/// Flag to force using WMMA tensorcore operations.
llvm::cl::opt<bool>
clGPUUseWMMA("iree-codegen-llvmgpu-use-wmma",
@@ -1602,6 +1609,199 @@ static bool distributeToSquare(const int64_t oh, const int64_t ow,
// Convolution Pipeline Configuration
//====---------------------------------------------------------------------===//

static LogicalResult
setConvolutionIGemmConfig(mlir::FunctionOpInterface entryPoint,
linalg::LinalgOp op, const TargetInfo &targetInfo) {
FailureOr<ArrayAttr> mmaKinds = getSupportedMmaTypes(entryPoint);
if (failed(mmaKinds)) {
return failure();
}

// This pipeline needs to know the subgroup size for distributing to virtual
// lane IDs.
if (targetInfo.supportedSubgroupSizes.empty()) {
return failure();
}
const int64_t subgroupSize = targetInfo.supportedSubgroupSizes.front();

Value input = op.getDpsInputOperand(0)->get();
Value filter = op.getDpsInputOperand(1)->get();
Value output = op.getDpsInitOperand(0)->get();
auto filterType = cast<ShapedType>(filter.getType());
auto outputType = cast<ShapedType>(output.getType());
ArrayRef<int64_t> filterShape = filterType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();

int64_t oh, ow, oc, fh, fw, ic, matmulM, matmulN, matmulK;
LogicalResult matchResult =
TypeSwitch<Operation *, LogicalResult>(op.getOperation())
.Case<linalg::Conv2DNhwcHwcfOp>([&](auto) {
oh = outputShape[1];
ow = outputShape[2];
oc = outputShape[3];
fh = filterShape[0];
fw = filterShape[1];
ic = filterShape[2];
matmulM = oh * ow;
matmulN = oc;
matmulK = fh * fw * ic;
return success();
})
.Case<linalg::DepthwiseConv2DNhwcHwcOp>([&](auto) {
oh = outputShape[1];
ow = outputShape[2];
fh = filterShape[0];
fw = filterShape[1];
matmulM = oh * ow;
matmulN = 1;
matmulK = fh * fw;
return success();
})
.Case<linalg::Conv2DNchwFchwOp>([&](auto) {
oc = outputShape[1];
oh = outputShape[2];
ow = outputShape[3];
ic = filterShape[1];
fh = filterShape[2];
fw = filterShape[3];
matmulM = oc;
matmulN = oh * ow;
matmulK = ic * fh * fw;
return success();
})
.Case<linalg::Conv2DNhwcFhwcOp>([&](auto) {
oh = outputShape[1];
ow = outputShape[2];
oc = outputShape[3];
fh = filterShape[1];
fw = filterShape[2];
ic = filterShape[3];
matmulM = oh * ow;
matmulN = oc;
matmulK = fh * fw * ic;
return success();
})
.Default([](auto) { return failure(); });
if (failed(matchResult)) {
return failure();
}

Type inputElemType = getElementTypeOrSelf(input);
Type filterElemType = getElementTypeOrSelf(filter);
Type outputElemType = getElementTypeOrSelf(output);

GPUMatmulShapeType problem{matmulM, matmulN, matmulK,
inputElemType, filterElemType, outputElemType};

auto mmaAttrs = llvm::to_vector(mmaKinds->getAsRange<IREE::GPU::MmaAttr>());
SmallVector<GPUMatmulShapeType> intrinsics;
intrinsics.reserve(mmaKinds->size());
for (auto mma : mmaAttrs) {
auto [mSize, nSize, kSize] = mma.getMNKShape();
auto [aType, bType, cType] = mma.getABCElementTypes();
intrinsics.emplace_back(mSize, nSize, kSize, aType, bType, cType);
}

int64_t mSize = matmulM;
int64_t nSize = matmulN;
GPUMMAHeuristicSeeds seeds;

// Note that the following heuristic seeds are just placeholder values.
// They need to be cleaned up and adjusted to different targets.
// See https://github.com/openxla/iree/issues/16341 for details.
if (mSize * nSize <= clGPUMatmulCThreshold) {
// For matmuls with small M*N size, we want to distribute M*N onto more
// workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
// and a larger bestKTileCountPerSubgroup.
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
/*bestMNTileCountPerSubgroup=*/4,
/*bestKTileCountPerSubgroup=*/8};
} else {
seeds = {/*bestSubgroupCountPerWorkgroup=*/4,
/*bestMNTileCountPerSubgroup=*/8,
/*bestKTileCountPerSubgroup=*/4};
}

// First try to find a schedule with an exactly matching intrinsic.
std::optional<GPUMMASchedule> schedule =
deduceMMASchedule(problem, intrinsics, seeds);
if (!schedule) {
// Then try again by allowing upcasting accumulator.
schedule =
deduceMMASchedule(problem, intrinsics, seeds, /*canUpcastAcc=*/true);
}
if (!schedule) {
return failure();
}

std::array<int64_t, 3> workgroupSize{schedule->nWarpCount * subgroupSize,
schedule->mWarpCount, 1};

// 1. Match a convolution and surrounding ops.
transform_ext::StructuredOpMatcher *fill;
transform_ext::StructuredOpMatcher *convolution;
transform_ext::StructuredOpMatcher *trailing;
transform_ext::MatchedConvolutionCaptures captures;
transform_ext::MatcherContext matcherContext;
makeConvolutionMatcher(matcherContext, convolution, fill, trailing, captures,
/*mustMatchEntireFunc=*/true);
if (!matchPattern(op, *convolution)) {
return failure();
}

// We are very particular about the dispatches we want to match for now:
// - f32 or f16 only.
// - Mandatory fill op.
// - Require minimum tile alignment due to img2col.
// - Otherwise, we take it.
if (!fill->getCaptured() || trailing->getCaptured()) {
return failure();
}

// Currently requires a typical 2d named convolution (conv_2d_nchw/nhwc).
if (captures.convolutionDims.outputChannel.size() != 1) {
return failure();
}
if (captures.convolutionDims.inputChannel.size() != 1) {
return failure();
}
if (captures.convolutionDims.outputImage.size() != 2) {
return failure();
}
if (captures.convolutionDims.filterLoop.size() != 2) {
return failure();
}
if (captures.convolutionDims.batch.size() != 1) {
return failure();
}

SmallVector<int64_t> distTileSizes = {
1, schedule->mWarpCount * schedule->mTileCount * schedule->mSize,
schedule->nWarpCount * schedule->nTileCount * schedule->nSize};
TileSizesListType tileSizesList = {distTileSizes};
SmallVector<int64_t> matmulRedTileSizes = {
0, 0, 0, schedule->kSize * schedule->kTileCount};
tileSizesList.push_back(matmulRedTileSizes);

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
MLIRContext *context = op.getContext();
auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get(
context, mmaAttrs[schedule->index], schedule->mWarpCount,
schedule->nWarpCount, schedule->mTileCount, schedule->nTileCount,
schedule->kTileCount);
SmallVector<NamedAttribute, 1> attrs;
attrs.emplace_back(
StringAttr::get(context, IREE::GPU::MMAScheduleAttr::getMnemonic()),
scheduleAttr);
auto configDict = DictionaryAttr::get(context, attrs);

return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, tileSizesList,
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUImplicitGEMM,
workgroupSize, subgroupSize, configDict);
}

static LogicalResult setConvolutionConfig(linalg::LinalgOp linalgOp,
const int64_t subgroupSize,
const int64_t bestTilingFactor) {
@@ -1720,6 +1920,12 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
if (succeeded(setWarpReductionConfig(entryPointFn, linalgOp, targetInfo))) {
return success();
}
if (clGPUEnableImplicitGemm) {
if (succeeded(
setConvolutionIGemmConfig(entryPointFn, linalgOp, targetInfo))) {
return success();
}
}
if (succeeded(setConvolutionConfig(
linalgOp, targetInfo.supportedSubgroupSizes.front(), 16))) {
return success();
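
The new strategy is opt-in behind --iree-codegen-llvmgpu-enable-implicit-gemm (default false). As a standalone sketch, not part of this patch, the dimension mapping and tile-size math in setConvolutionIGemmConfig can be spelled out for a conv_2d_nhwc_hwcf with hypothetical shapes and a hypothetical MMA schedule:

// Standalone sketch (illustrative only): mirrors the implicit-GEMM dimension
// mapping and the workgroup/tile-size formulas used by setConvolutionIGemmConfig
// for conv_2d_nhwc_hwcf. All concrete numbers below are made up.
#include <array>
#include <cstdint>
#include <cstdio>

struct GemmDims { int64_t m, n, k; };

// outputShape = {N, OH, OW, OC}, filterShape = {FH, FW, IC, OC}.
GemmDims igemmDimsNhwcHwcf(const std::array<int64_t, 4> &outputShape,
                           const std::array<int64_t, 4> &filterShape) {
  const int64_t oh = outputShape[1], ow = outputShape[2], oc = outputShape[3];
  const int64_t fh = filterShape[0], fw = filterShape[1], ic = filterShape[2];
  return {/*m=*/oh * ow, /*n=*/oc, /*k=*/fh * fw * ic};
}

int main() {
  // 1x28x28x64 output with a 3x3x32x64 filter -> M = 784, N = 64, K = 288.
  const GemmDims dims = igemmDimsNhwcHwcf({1, 28, 28, 64}, {3, 3, 32, 64});

  // Hypothetical schedule: 16x16x16 intrinsic, 2x2 subgroups,
  // tile counts m=2, n=2, k=4, subgroup size 64.
  const int64_t subgroupSize = 64, mWarpCount = 2, nWarpCount = 2;
  const int64_t mTileCount = 2, nTileCount = 2, kTileCount = 4;
  const int64_t mSize = 16, nSize = 16, kSize = 16;

  // Same formulas as workgroupSize / distTileSizes / matmulRedTileSizes above.
  const std::array<int64_t, 3> workgroupSize = {nWarpCount * subgroupSize,
                                                mWarpCount, 1};        // {128, 2, 1}
  const std::array<int64_t, 3> distTileSizes = {
      1, mWarpCount * mTileCount * mSize, nWarpCount * nTileCount * nSize};  // {1, 64, 64}
  const int64_t kRedTileSize = kSize * kTileCount;                     // 64

  std::printf("M=%lld N=%lld K=%lld wg={%lld,%lld,%lld} dist={%lld,%lld,%lld} kTile=%lld\n",
              (long long)dims.m, (long long)dims.n, (long long)dims.k,
              (long long)workgroupSize[0], (long long)workgroupSize[1],
              (long long)workgroupSize[2], (long long)distTileSizes[0],
              (long long)distTileSizes[1], (long long)distTileSizes[2],
              (long long)kRedTileSize);
  return 0;
}

In the actual configuration the schedule is not hand-picked; it comes from deduceMMASchedule against the target's supported MMA intrinsics, with the seeds shown above as heuristics.
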
81 changes: 81 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUIm2ColPass.cpp
@@ -0,0 +1,81 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

namespace {

class LLVMGPUIm2ColPass : public LLVMGPUIm2ColBase<LLVMGPUIm2ColPass> {

public:
void getDependentDialects(DialectRegistry &registry) const override {}

void runOnOperation() override;
};

} // namespace

template <typename OpTy>
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, OpTy op,
SmallVector<NamedAttribute> &additionalAttributes) {
additionalAttributes = linalg::getPrunedAttributeList(op);
return rewriteInIm2Col(rewriter, op);
}

void LLVMGPUIm2ColPass::runOnOperation() {
auto operation = getOperation();

IRRewriter rewriter(&getContext());
operation->walk([&](Operation *op) {
IRRewriter::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(op);
SmallVector<NamedAttribute> additionalAttributes;
auto maybeTransformed =
TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
op)
.Case<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
linalg::DepthwiseConv2DNhwcHwcOp, linalg::Conv2DNchwFchwOp>(
[&](auto op) {
return rewriteInIm2Col(rewriter, op, additionalAttributes);
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(op, "not supported");
});
if (failed(maybeTransformed)) {
return;
}
auto matmulOp = cast<tensor::ExpandShapeOp>(maybeTransformed->second)
.getSrc()
.getDefiningOp();
for (auto attr : additionalAttributes) {
matmulOp->setAttr(attr.getName(), attr.getValue());
}
});

// Fold the reshapes introduced by the im2col rewrite into the surrounding ops
// by collapsing.
RewritePatternSet patterns(&getContext());
linalg::populateFoldReshapeOpsByCollapsingPatterns(
patterns, [](OpOperand *) { return true; });
populateReshapeToInterfaceTensorPatterns(patterns);
linalg::FillOp::getCanonicalizationPatterns(patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUIm2ColPass() {
return std::make_unique<LLVMGPUIm2ColPass>();
}

} // namespace mlir::iree_compiler
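
For reference, a self-contained sketch, not part of this change, of what an im2col rewrite does conceptually for a stride-1, unpadded, batch-1 NHWC convolution. The pass itself relies on the tensor-level rewriteInIm2Col rewrite plus reshape folding rather than materializing a buffer like this:

// Conceptual im2col sketch (illustrative only): the convolution becomes a GEMM
// with M = OH*OW, N = OC, K = FH*FW*IC.
#include <cstdint>
#include <vector>

// input:  [OH+FH-1][OW+FW-1][IC], filter: [FH][FW][IC][OC], output: [OH][OW][OC].
void conv2dNhwcAsIgemm(const std::vector<float> &input, int64_t iw,
                       const std::vector<float> &filter, int64_t fh, int64_t fw,
                       int64_t ic, int64_t oc,
                       std::vector<float> &output, int64_t oh, int64_t ow) {
  const int64_t M = oh * ow, N = oc, K = fh * fw * ic;
  // im2col: row m holds the flattened receptive field of output pixel m.
  std::vector<float> col(M * K);
  for (int64_t y = 0; y < oh; ++y)
    for (int64_t x = 0; x < ow; ++x)
      for (int64_t dy = 0; dy < fh; ++dy)
        for (int64_t dx = 0; dx < fw; ++dx)
          for (int64_t c = 0; c < ic; ++c)
            col[(y * ow + x) * K + (dy * fw + dx) * ic + c] =
                input[((y + dy) * iw + (x + dx)) * ic + c];
  // The convolution is now a plain GEMM: output[M x N] = col[M x K] * filter[K x N].
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k)
        acc += col[m * K + k] * filter[k * N + n];
      output[m * N + n] = acc;
    }
}
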
compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -137,6 +137,10 @@ void LLVMGPULowerExecutableTargetPass::runOnOperation() {
pipeline, codegenSpec ? codegenSpec.getLeafReference() : StringRef(""));
break;
}
case IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUImplicitGEMM: {
addGPUImplicitGEMMPassPipeline(pipeline);
break;
}
// no pipeline specified, nothing to do.
case IREE::Codegen::DispatchLoweringPassPipeline::None:
return;
