diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 4af89f26a..41a3fd070 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -76,4 +76,35 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml "microkernel::MicrokernelDialect"]; } +def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> { + let summary = "Early dispatch microkernel during compile time"; + let description = [{ + Early dispatch microkernel during compile time. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + +def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> { + let summary = "Find and merge identical microkernel context operations in branches into one"; + let description = [{ + Find and merge identical microkernel context operations in branches into one. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect"]; +} + +def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> { + let summary = "Hoist invariant microkernel code to avoid redundant execution"; + let description = [{ + Hoist invariant microkernel code to avoid redundant execution. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + #endif // GC_DIALECT_MICROKERNELPASSES diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 20f736fc7..9dc940fcd 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op, return success(); } +static bool isTypeSupported(Type outType, Type operandAType, + Type operandBType) { + if (!outType.isF32() && !outType.isSignedInteger(32)) + return false; + + if (outType.isF32()) { + if (!(operandAType.isF32() && operandBType.isF32()) && + !(operandAType.isBF16() && operandBType.isBF16())) + return false; + } + if (outType.isSignedInteger(32)) { + if (!(operandAType.isSignedInteger(8) || + operandAType.isUnsignedInteger(8)) && + (operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8))) + return false; + } + return true; +} + +// TODO(haixin): could use compiler-wide VNNI utils? 
+static bool isInVnniLayout(ShapedType type) { + if (!type.getElementType().isBF16() && + !type.getElementType().isSignedInteger(8) && + !type.getElementType().isUnsignedInteger(8)) + return false; + + auto blockingFactor = 0; + if (type.getElementType().isBF16()) + blockingFactor = 2; + else if (type.getElementType().isSignedInteger(8) || + type.getElementType().isUnsignedInteger(8)) + blockingFactor = 4; + + return type.getShape().back() == blockingFactor; +} + ///////////////////////////////////////////////////// // Start of BrgemmOp @@ -308,9 +344,8 @@ static inline ArrayRef getShapedValueShape(Value val) { assert((llvm::isa(val.getType()) || llvm::isa(val.getType())) && "Expecting shaped value"); - if (auto tensorTy = dyn_cast_or_null(val.getType())) { + if (auto tensorTy = dyn_cast_or_null(val.getType())) return tensorTy.getShape(); - } auto memrefTy = dyn_cast_or_null(val.getType()); return memrefTy.getShape(); } @@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() { return op.emitOpError() << "expect inputs and its related info to be size 2\n"; + auto elemTypeA = getElementTypeOrSelf(ins[0]); + auto elemTypeB = getElementTypeOrSelf(ins[1]); + auto elemTypeC = getElementTypeOrSelf(out); + if (!isTypeSupported(elemTypeC, elemTypeA, elemTypeB)) + return op.emitOpError() << "unsupported input matrix types\n"; + ArrayRef dimA = getShapedValueShape(ins[0]); ArrayRef dimB = getShapedValueShape(ins[1]); ArrayRef dimC = getShapedValueShape(out); if (dimA.size() != 3) return op.emitOpError() << "expect input A to be 3D\n"; - if (dimB.size() != 3 && dimB.size() != 4) - return op.emitOpError() << "expect input B to be 3D or 4D\n"; - if (dimB.size() == 4 && (dimB[3] != 2 && dimB[3] != 4)) - return op.emitOpError() << "expect input B vnni step to be 2 or 4\n"; + if (!elemTypeB.isF32()) { + if (dimB.size() != 4 || + !isInVnniLayout(dyn_cast(ins[1].getType()))) + return op.emitOpError() + << "expect a 4d VNNI input B for non-F32 operand: " << ins[1]; + } else { + if (dimB.size() != 3) + return op.emitOpError() + << "expect a 3d input B for F32 operand: " << ins[1]; + } if (dimC.size() != 2) return op.emitOpError() << "expect input C to be 2D\n"; for (auto dim : batchDims) @@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() { ///////////////////////////////////////////////////// // Start of BrgemmExecuteOp -// TODO(haixin): could use compiler-wide VNNI utils? 
-static bool isInVnniLayout(MemRefType memref) { - if (!memref.getElementType().isBF16() && - !memref.getElementType().isSignedInteger(8) && - !memref.getElementType().isUnsignedInteger(8)) - return false; - - auto blockingFactor = 0; - if (memref.getElementType().isBF16()) - blockingFactor = 2; - else if (memref.getElementType().isSignedInteger(8) || - memref.getElementType().isUnsignedInteger(8)) - blockingFactor = 4; - - return memref.getShape().back() == blockingFactor; -} - -static bool isTypeSupported(Type outType, Type operandAType, - Type operandBType) { - if (!outType.isF32() && !outType.isSignedInteger(32)) - return false; - - if (outType.isF32()) { - if (!(operandAType.isF32() && operandBType.isF32()) && - !(operandAType.isBF16() && operandBType.isBF16())) - return false; - } - if (outType.isSignedInteger(32)) { - if (!(operandAType.isSignedInteger(8) || - operandAType.isUnsignedInteger(8)) && - (operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8))) - return false; - } - return true; -} - LogicalResult BrgemmExecuteOp::verify() { BrgemmExecuteOp &brgemmOp = *this; diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 243e17580..484c2ef4b 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Utils) gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRSupport + MLIRMicrokernelTransforms MLIRBufferizationToMemRef MLIRBufferizationPipelines) diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index 9064b70db..9dbac3e94 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -1,4 +1,4 @@ -gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRMicrokernel) include(onednn) @@ -6,6 +6,9 @@ gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp ExpandMicrokernel.cpp ConvertMicrokernelToDnnlFunc.cpp + EarlyDispatchMicrokernel.cpp + MicrokernelInvariantCodeMotion.cpp + MergeBranchMicrokernelContext.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 980fe8288..f89a31179 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -157,6 +157,23 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { else return failure(); + OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; + OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; + Type operandBElemType = getElementTypeOrSelf(operandB->get()); + if (operandBElemType.isF32()) { + if (kAffinePos.size() == 2) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input " + "B, should be non-VNNI\n"); + return failure(); + } + } else { + if (kAffinePos.size() == 1) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input " + "B, should be VNNI\n"); + return failure(); + } + } + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: " << "\n"); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mAffinePos @@ -169,9 +186,6 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: " << batchAffinePos << "\n"); - OpOperand *operandA = 
linalgOp.getDpsInputOperands()[0]; - OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - BrgemmDims brgemmDims; #define CHECK_GET_POS_IN_DOMAIN(dim, dimPos, operand) \ diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp new file mode 100644 index 000000000..2f66feee4 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -0,0 +1,214 @@ +//===-- EarlyDispatchMicrokernel.cpp - Dispatch before runtime --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Transforms/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_EARLYDISPATCHMICROKERNEL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "early-dispatch-microkernel" + +static FailureOr +createGlobalKernelHandleName(RewriterBase &rewriter, + microkernel::BrgemmDispatchOp op) { + // TODO(haixin): Add runtime backend type to global name + std::stringstream ss; + ss << "g_dispatched_microkernel_brgemm"; + + bool isInit = false; + bool isStrideMode = false; + auto flags = op.getFlagsAttr(); + for (auto flag : flags) { + auto brgemmFlag = dyn_cast_or_null(flag); + if (!brgemmFlag) + return failure(); + if (brgemmFlag.getValue() == BrgemmFlags::LIST) + // TODO(haixin): Currently not supported. 
Support list brgemm in the + // future + return failure(); + else if (brgemmFlag.getValue() == BrgemmFlags::STRIDE) + isStrideMode = true; + else if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + isInit = true; + } + if (isStrideMode) + ss << "_stride"; + else + ss << "_list"; + if (isInit) + ss << "_init"; + + // M, N, K, LDA, LDB, LDC, stride_a, stride_b + // they are in the same order with BrgemmDispatchOp inputs + // TODO(haixin): Add order enforcement machanism for BrgemmDispatchOp + ArrayRef inputs = op.getInputsAttr().asArrayRef(); + for (auto input : inputs) { + ss << "_" << input; + } + + // dtypeA, dtypeB + auto dtypes = op.getDataType(); + if (dtypes.size() != 2) + return failure(); + ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]); + ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[1]); + + return ss.str(); +} + +// get or create global kernel handle with initializer, identified by +// `kernelName` +static FailureOr +getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, + const std::string &kernelName, + microkernel::BrgemmDispatchOp op) { + // Create the global at the entry of the module + LLVM::GlobalOp global = module.lookupSymbol(kernelName); + if (global) + return global; + + auto global_type = op.getResults().getType(); + FlatSymbolRefAttr ctorName = + SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); + if (module.lookupSymbol(ctorName.getAttr())) + return failure(); + + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + global = rewriter.create( + module.getLoc(), global_type, /*isConstant=*/false, + LLVM::Linkage::Internal, kernelName, Attribute(), + /*alignment=*/0); + + // create ctor for this global, which needs to be LLVMFuncOp + LLVM::LLVMFuncOp ctorFunc = rewriter.create( + module.getLoc(), ctorName.getValue(), + LLVM::LLVMFunctionType::get(global_type, {}, false)); + + Location loc = ctorFunc.getLoc(); + Block *entryBlock = ctorFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToEnd(entryBlock); + + auto dispatch = op.clone(); + rewriter.insert(dispatch); + Value globalPtr = rewriter.create(loc, global); + rewriter.create(loc, dispatch.getResults(), globalPtr); + rewriter.create(loc, dispatch.getResults()); + + // initialize the gloabl with global_ctors, as the initializer of global + // does not allow side effect + rewriter.setInsertionPointToStart(module.getBody()); + LLVM::GlobalCtorsOp global_ctors = nullptr; + for (auto &op : module->getRegion(0).front()) { + auto ctorOp = dyn_cast(op); + if (ctorOp) { + global_ctors = ctorOp; + break; + } + } + + SmallVector ctorRefs; + SmallVector priorities; + if (global_ctors) { + auto ctorRefsAttr = global_ctors.getCtors(); + auto prioritiesAttr = global_ctors.getPriorities(); + for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) { + ctorRefs.push_back(ctor); + priorities.push_back(prior); + } + LLVM_DEBUG(llvm::dbgs() + << "After append ctors: " << ctorRefs.size() << "\n"); + } + ctorRefs.push_back(ctorName); + // Set new ctor's priority to lowest + priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); + if (global_ctors) { + LLVM_DEBUG(llvm::dbgs() << "Replace existing ctors\n"); + // If there's existing ctors + rewriter.replaceOpWithNewOp( + global_ctors, rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); + } else { + LLVM_DEBUG(llvm::dbgs() << "Create new ctor\n"); + rewriter.create(module.getLoc(), + rewriter.getArrayAttr(ctorRefs), + 
rewriter.getArrayAttr(priorities)); + } + return global; +} + +class EarlyDispatchBrgemmRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + func::FuncOp func = op->template getParentOfType(); + + auto globalKernelName = createGlobalKernelHandleName(rewriter, op); + if (failed(globalKernelName)) { + return rewriter.notifyMatchFailure( + op, "Failed to create global kernel handle name"); + } + + // Generate kernel handle global name + auto globalKernel = + getOrCreateGlobalKernelHandle(rewriter, module, *globalKernelName, op); + if (failed(globalKernel)) { + return rewriter.notifyMatchFailure( + op, "Failed to create global kernel handle"); + } + + // Inject global val loading into start of function + auto funcBlock = &func.getBody().front(); + rewriter.setInsertionPointToStart(funcBlock); + Value globalPtr = rewriter.create(loc, *globalKernel); + Value globalVal = rewriter.create( + loc, op.getResults().getType(), globalPtr); + rewriter.moveOpAfter(op, funcBlock, funcBlock->begin()); + rewriter.replaceOp(op, globalVal); + return success(); + } +}; + +class EarlyDispatchMicrokernel + : public impl::EarlyDispatchMicrokernelBase { +public: + using impl::EarlyDispatchMicrokernelBase< + EarlyDispatchMicrokernel>::EarlyDispatchMicrokernelBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + // Ignore newly created Ops + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp new file mode 100644 index 000000000..9865f5220 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -0,0 +1,305 @@ +//===-- MergeBranchMicrokernelContext.cpp - Merge same context --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Transforms/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_MERGEBRANCHMICROKERNELCONTEXT +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "merge-branch-microkernel-context" + +class BrgemmDispatchAnalysis { +private: + // A map for tile_config -> tile_dispatch + DenseMap brgemmDispatches; + + Operation *traceKernelDispatch(Operation *op); + Operation *traceDispatchInGlobalCtor(ModuleOp module, + llvm::StringRef global_name); + +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BrgemmDispatchAnalysis) + explicit BrgemmDispatchAnalysis(Operation *); + void setKernelDispatch(Operation *tilecfg, Operation *dispatch) { + LLVM_DEBUG(llvm::dbgs() << "* setKernelDispatch: " << tilecfg << "; " + << dispatch << "\n"); + brgemmDispatches[tilecfg] = dispatch; + }; + Operation *getKernelDispatch(Operation *tilecfg) const { + auto iter = brgemmDispatches.find(tilecfg); + if (iter == brgemmDispatches.end()) { + return nullptr; + } + return iter->second; + }; +}; + +BrgemmDispatchAnalysis::BrgemmDispatchAnalysis(Operation *root) { + LLVM_DEBUG(llvm::dbgs() << "* construct BrgemmDispatchAnalysis: " << *root + << "\n"); + ModuleOp module = dyn_cast_or_null(root); + if (!module) + return; + + module->walk([this](Operation *op) { + auto callOp = dyn_cast_or_null(op); + if (!callOp) + return; + StringAttr callee = callOp.getCalleeAttr().getAttr(); + if (callee != StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) + return; + + auto dispatch = traceKernelDispatch(callOp); + assert(dispatch && "No dispatch found for tilecfg Op"); + setKernelDispatch(callOp, dispatch); + }); +} + +Operation *BrgemmDispatchAnalysis::traceKernelDispatch(Operation *op) { + ModuleOp module = op->template getParentOfType(); + auto callOp = dyn_cast_or_null(op); + assert(callOp); + StringAttr callee = callOp.getCalleeAttr().getAttr(); + assert(callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)); + auto kernelProducer = callOp.getOperand(0).getDefiningOp(); + // Direct producer is supposed to be either `brgemm.dispatch` or LLVM::load + // global Any other cases are extremely rare (mostly invalid MLIR), so + // considered as not found + if (auto tryCallOp = dyn_cast_or_null(kernelProducer)) { + callee = tryCallOp.getCalleeAttr().getAttr(); + if (callee != StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) + return nullptr; + return tryCallOp; + } + if (auto tryLoadOp = dyn_cast_or_null(kernelProducer)) { + if (auto tryAddrOfOp = dyn_cast_or_null( + tryLoadOp.getOperand().getDefiningOp())) + return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName()); + } + return nullptr; +} + +Operation * +BrgemmDispatchAnalysis::traceDispatchInGlobalCtor(ModuleOp module, + llvm::StringRef global_name) { + std::string gctor_name = std::string(global_name) + "_ctor"; + FlatSymbolRefAttr ctorName = + SymbolRefAttr::get(module->getContext(), gctor_name); + auto ctor = module.lookupSymbol(ctorName); + if 
(!ctor) + return nullptr; + + // ctor should contain only one call for kernel dispatch + auto &body = ctor.getBody(); + for (auto &opRef : body.getOps()) { + auto *op = &opRef; + auto tryCallOp = dyn_cast_or_null(op); + if (!tryCallOp) + continue; + auto callee = tryCallOp.getCalleeAttr().getAttr(); + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) + return op; + } + return nullptr; +} + +// return a pair of extracted from given region, +// only check direct descendants +static std::pair +extractTileOpsFromRegion(Region ®ion) { + std::pair ret{nullptr, nullptr}; + + for (auto &opRef : region.getOps()) { + LLVM_DEBUG(llvm::dbgs() << ">>> " << opRef << "\n"); + auto *op = &opRef; + auto tryCallOp = dyn_cast_or_null(op); + if (!tryCallOp) + continue; + auto callee = tryCallOp.getCalleeAttr().getAttr(); + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) + ret.first = op; + else if (callee == + StringAttr::get(op->getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) + ret.second = op; + } + + return ret; +} + +static size_t DNNL_BRGEMM_DISPATCH_BETA_PARAM_INDEX = 8; +static bool dispatchHasSameContext(Operation *lhs, Operation *rhs) { + auto lhsDispatch = dyn_cast_or_null(lhs); + auto rhsDispatch = dyn_cast_or_null(rhs); + if (!lhsDispatch || !rhsDispatch) + return false; + auto dispatchNameAttr = + StringAttr::get(lhs->getContext(), DNNL_BRGEMM_DISPATCH_NAME); + if (lhsDispatch.getCalleeAttr().getAttr() != dispatchNameAttr || + rhsDispatch.getCalleeAttr().getAttr() != dispatchNameAttr) + return false; + + auto lhsOperands = lhsDispatch.getOperands(); + auto rhsOperands = rhsDispatch.getOperands(); + assert(lhsOperands.size() == rhsOperands.size() && + "Inconsistent operand size"); + for (size_t idx = 0; idx < lhsOperands.size(); idx++) { + if (idx == DNNL_BRGEMM_DISPATCH_BETA_PARAM_INDEX) { + // skip `beta` operand in index no.8 + // since per dnnl design, it does not affect BRGEMM blocking & palettes + continue; + } + auto lhsCstOp = + dyn_cast_or_null(lhsOperands[idx].getDefiningOp()); + auto rhsCstOp = + dyn_cast_or_null(rhsOperands[idx].getDefiningOp()); + if (!lhsCstOp || !rhsCstOp) + return false; + if (lhsCstOp.getValue() != rhsCstOp.getValue()) + return false; + } + return true; +} + +class ScfIfRewriter : public OpRewritePattern { +private: + BrgemmDispatchAnalysis &analysis; + +public: + using OpRewritePattern::OpRewritePattern; + + ScfIfRewriter(MLIRContext *context, BrgemmDispatchAnalysis &ana) + : OpRewritePattern(context), analysis{ana} {} + + LogicalResult matchAndRewrite(scf::IfOp op, + PatternRewriter &rewriter) const final { + auto &ifRegion = op.getThenRegion(); + auto &elseRegion = op.getElseRegion(); + if (!ifRegion.hasOneBlock() || !elseRegion.hasOneBlock()) + return rewriter.notifyMatchFailure(op, + "Cannot merge for non-full branch"); + auto ifTileOps = extractTileOpsFromRegion(ifRegion); + auto elseTileOps = extractTileOpsFromRegion(elseRegion); + if (!ifTileOps.first || !ifTileOps.second || !elseTileOps.first || + !elseTileOps.second) + return rewriter.notifyMatchFailure( + op, "Cannot merge for inconsistent branch"); + + auto ifTileDispatch = analysis.getKernelDispatch(ifTileOps.first); + auto elseTileDispatch = analysis.getKernelDispatch(elseTileOps.first); + if (!ifTileDispatch || !elseTileDispatch) + return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); + + if (!dispatchHasSameContext(ifTileDispatch, elseTileDispatch)) + return rewriter.notifyMatchFailure( + op, "Kernels in branch has different 
context"); + + // Avoid breaking dominance of dispatch + if (ifTileDispatch->getParentRegion() == &ifRegion) + return rewriter.notifyMatchFailure(op, + "Dispatch dominance prevents merging"); + + // Whole branch reuses internal context of kernel in `if` region + rewriter.eraseOp(elseTileOps.first); + rewriter.eraseOp(elseTileOps.second); + rewriter.moveOpBefore(ifTileOps.first, op); + rewriter.moveOpAfter(ifTileOps.second, op); + + return success(); + } +}; + +class ScfIndexSwitchRewriter : public OpRewritePattern { +private: + BrgemmDispatchAnalysis &analysis; + +public: + using OpRewritePattern::OpRewritePattern; + + ScfIndexSwitchRewriter(MLIRContext *context, BrgemmDispatchAnalysis &ana) + : OpRewritePattern(context), analysis{ana} {} + + LogicalResult matchAndRewrite(scf::IndexSwitchOp op, + PatternRewriter &rewriter) const final { + auto &defaultRegion = op.getDefaultRegion(); + auto caseRegions = op.getCaseRegions(); + + auto defaultTileOps = extractTileOpsFromRegion(defaultRegion); + if (!defaultTileOps.first || !defaultTileOps.second) + return rewriter.notifyMatchFailure( + op, "Cannot merge for inconsistent branch"); + SmallVector, 5> caseTilesOps; + for (auto &caseRegion : caseRegions) { + auto caseTileOps = extractTileOpsFromRegion(caseRegion); + if (!caseTileOps.first || !caseTileOps.second) + return rewriter.notifyMatchFailure( + op, "Cannot merge for inconsistent branch"); + caseTilesOps.push_back(caseTileOps); + } + + auto defaultTileDispatch = analysis.getKernelDispatch(defaultTileOps.first); + if (!defaultTileDispatch) + return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); + + for (size_t idx = 0; idx < caseRegions.size(); idx++) { + auto caseTileDispatch = + analysis.getKernelDispatch(caseTilesOps[idx].first); + if (!caseTileDispatch) + return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); + if (!dispatchHasSameContext(defaultTileDispatch, caseTileDispatch)) + return rewriter.notifyMatchFailure( + op, "Kernels in branch has different context"); + } + + // Avoid breaking dominance of dispatch + if (defaultTileDispatch->getParentRegion() == &defaultRegion) + return rewriter.notifyMatchFailure(op, + "Dispatch dominance prevents merging"); + + // Whole branch reuses internal context of kernel in `default` region + for (auto &ops : caseTilesOps) { + rewriter.eraseOp(ops.first); + rewriter.eraseOp(ops.second); + } + rewriter.moveOpBefore(defaultTileOps.first, op); + rewriter.moveOpAfter(defaultTileOps.second, op); + + return success(); + } +}; + +class MergeBranchMicrokernelContext + : public impl::MergeBranchMicrokernelContextBase< + MergeBranchMicrokernelContext> { +public: + using impl::MergeBranchMicrokernelContextBase< + MergeBranchMicrokernelContext>::MergeBranchMicrokernelContextBase; + void runOnOperation() final { + auto &dispatchAnalysis = getAnalysis(); + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), dispatchAnalysis); + patterns.add(&getContext(), dispatchAnalysis); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp new file mode 100644 index 000000000..ad8a0631f --- /dev/null +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -0,0 +1,431 @@ +//===-- MicrokernelInvariantCodeMotion.cpp 
- Hoist invariance ---*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +#include + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Transforms/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_MICROKERNELINVARIANTCODEMOTION +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "microkernel-invariant-code-motion" + +enum BrgemmCallType { INAPPLICABLE = -1, DISPATCH, TILECFG, TILERELEASE }; + +static bool isParallelLoop(Operation *op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa(op); +} + +static bool isConcernedCF(Operation *op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa(op); +} + +static BrgemmCallType getBrgemmCallType(Operation *op) { + if (!llvm::isa(op)) { + return BrgemmCallType::INAPPLICABLE; + } + auto callOp = dyn_cast(op); + auto calleeName = callOp.getCalleeAttr().getAttr().getValue(); + + if (calleeName == DNNL_BRGEMM_DISPATCH_NAME) + return BrgemmCallType::DISPATCH; + if (calleeName == DNNL_BRGEMM_TILECFG_NAME) + return BrgemmCallType::TILECFG; + if (calleeName == DNNL_BRGEMM_TILERELEASE_NAME) + return BrgemmCallType::TILERELEASE; + return BrgemmCallType::INAPPLICABLE; +} + +// Tree node of structure info tree, each node represents an Op +// This tree contains only concerned Ops +struct BrgemmContextStructInfo { + // Basic structure info retrieved by first walk + Operation *contextRoot; // Could be parallel loop or func + Operation *self, *parent; + DenseSet child; + SmallVector containBrgemmCallType; + // Rewrite-time info retrieved by analysing basic structure info + union { + Operation *maxInvariantScope; // Used by BrgemmCallOps for hoisting + bool hasTilereleased; // Used by other Ops as hoisting scopes to + // dedup tilerelease injection + }; + BrgemmContextStructInfo() { + contextRoot = nullptr; + self = nullptr; + parent = nullptr; + containBrgemmCallType = {false, false, false}; + maxInvariantScope = nullptr; + } +}; + +using OpStructInfoMap = DenseMap; + +class BrgemmTilecfgRewriter : public OpRewritePattern { +private: + OpStructInfoMap &structInfo; + +public: + using OpRewritePattern::OpRewritePattern; + + BrgemmTilecfgRewriter(MLIRContext *context, OpStructInfoMap &si) + : OpRewritePattern(context), structInfo{si} {} + + LogicalResult matchAndRewrite(func::CallOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + + StringAttr callee = op.getCalleeAttr().getAttr(); + if (!module.lookupSymbol(callee)) + return rewriter.notifyMatchFailure(op, + "Invalid CallOp to unknown callee"); + if (callee != + StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILECFG_NAME)) + return rewriter.notifyMatchFailure(op, "Not call to BRGEMM tilecfg"); + auto opInfoIter = structInfo.find(op); + if (opInfoIter == structInfo.end()) { + 
return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); + } + auto &opStructInfo = opInfoIter->second; + + // Don't hoist if max invariant scope is itself to reduce + // unnecessary movement + if (opStructInfo.maxInvariantScope == op) { + return rewriter.notifyMatchFailure(op, "No need to hoist"); + } + rewriter.moveOpBefore(op, opStructInfo.maxInvariantScope); + // Avoid being hoisted again + opStructInfo.maxInvariantScope = op; + return success(); + } +}; + +static void markScopeAsReleased(OpStructInfoMap &structInfo, Operation *op) { + auto iter = structInfo.find(op); + assert(iter != structInfo.end()); + // Don't mark BrgemmCallOps + if (getBrgemmCallType(op) != BrgemmCallType::INAPPLICABLE) + return; + iter->second.hasTilereleased = true; + + for (auto ch : iter->second.child) { + markScopeAsReleased(structInfo, ch); + } +} + +class BrgemmTilereleaseRewriter : public OpRewritePattern { +private: + OpStructInfoMap &structInfo; + +public: + using OpRewritePattern::OpRewritePattern; + + BrgemmTilereleaseRewriter(MLIRContext *context, OpStructInfoMap &si) + : OpRewritePattern(context), structInfo{si} {} + LogicalResult matchAndRewrite(func::CallOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + + StringAttr callee = op.getCalleeAttr().getAttr(); + if (!module.lookupSymbol(callee)) + return rewriter.notifyMatchFailure(op, + "Invalid CallOp to unknown callee"); + if (callee != + StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) + return rewriter.notifyMatchFailure(op, "Not call to BRGEMM tilerelease"); + + auto opInfoIter = structInfo.find(op); + if (opInfoIter == structInfo.end()) { + return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); + } + auto &opStructInfo = opInfoIter->second; + auto targetInfoIter = structInfo.find(opStructInfo.maxInvariantScope); + assert(opStructInfo.maxInvariantScope); + // Don't hoist if max invariant scope is itself to reduce + // unnecessary movement + if (opStructInfo.maxInvariantScope == op) { + return rewriter.notifyMatchFailure(op, "No need to hoist"); + } + assert(targetInfoIter != structInfo.end()); + // move last tilerelease to end of contextRoot, and remove all + // others + if (targetInfoIter->second.hasTilereleased) { + rewriter.eraseOp(op); + } else { + // rewriter.moveOpBefore(op, block, enditer); + rewriter.moveOpAfter(op, opStructInfo.maxInvariantScope); + // Mark all sub scope as released to avoid duplicate Tilerelease + markScopeAsReleased(structInfo, opStructInfo.maxInvariantScope); + // Avoid being hoisted again + opStructInfo.maxInvariantScope = op; + } + return success(); + } +}; + +class MicrokernelInvariantCodeMotion + : public impl::MicrokernelInvariantCodeMotionBase< + MicrokernelInvariantCodeMotion> { +private: + // This helper create structInfo tree node along the path from input Op(as + // leaf) to contextRoot(FuncOp or any parallel Op) on demand; + // This tree only contains concerned Ops, including BrgemmCall Ops, parallel + // ops and related SCF Ops etc.; + // Input Op should be a BrgemmCall Op or the op + // depent by BrgemmTilecfg/BrgemmRelease + BrgemmContextStructInfo getOrCreateBrgemmContext(OpStructInfoMap &structInfo, + Operation *op) { + auto resIter = structInfo.find(op); + if (resIter != structInfo.end()) { + return resIter->second; + } + + SmallVector createdInfo; + Operation *contextRootForCreatedInfo = nullptr; + + auto doCreateStructInfo = [&](Operation *child, Operation *op) { + BrgemmContextStructInfo 
info; + info.self = op; + if (child) { + auto iter = structInfo.find(child); + assert(iter != structInfo.end()); + iter->second.parent = op; + info.child.insert(child); + } + structInfo.insert(std::make_pair(op, std::move(info))); + auto iter = structInfo.find(op); + createdInfo.push_back(&iter->second); + return &iter->second; + }; + // Create info for input Op as leaf + auto brgemmInfo = doCreateStructInfo(nullptr, op); + auto callType = getBrgemmCallType(op); + if (callType != BrgemmCallType::INAPPLICABLE) { + brgemmInfo->containBrgemmCallType[callType] = true; + } + + auto last = op; + auto current = op->getParentOp(); + // Traverse up the IR tree, creating structInfo for each concerned Op + while (current) { + bool isParaLoop = isParallelLoop(current); + bool isCCF = isConcernedCF(current); + if (!llvm::isa(current) && !isParaLoop && !isCCF) { + // Only care about selected Ops + current = current->getParentOp(); + continue; + } + + auto iter = structInfo.find(current); + if (iter != structInfo.end()) { + // StructInfo exists for current Op, then we don't need to create info + // anymore as all ancestors have been created + // But we still need to propagate containBrgemmCallType if we are + // dealing with BrgemmCall Ops + if (last) { + auto lastIter = structInfo.find(last); + assert(lastIter != structInfo.end()); + lastIter->second.parent = current; + iter->second.child.insert(last); + // Invalidate last as we don't create new info anymore + last = nullptr; + } + if (callType != BrgemmCallType::INAPPLICABLE) { + // Propagate containCallType if needed + iter->second.containBrgemmCallType[callType] = true; + } else + break; + } else { + // StructInfo not exist, then create one for current Op and keep + // Traversing up + auto created = doCreateStructInfo(last, current); + if (callType != BrgemmCallType::INAPPLICABLE) { + created->containBrgemmCallType[callType] = true; + } + last = current; + } + if (llvm::isa(current) || isParaLoop) { + // Encounter `contextRoot`, then record and terminate traversing + contextRootForCreatedInfo = current; + break; + } + current = current->getParentOp(); + } + + // Assign `contextRoot` for newly created structInfo + if (contextRootForCreatedInfo) { + for (auto info : createdInfo) + info->contextRoot = contextRootForCreatedInfo; + } + + resIter = structInfo.find(op); + assert(resIter != structInfo.end()); + return resIter->second; + } + + // This helper expand invariant scope according to two function: + // 1. controlFlowAllow: Whether we can hoist the BrgemmCallOp out of the scope + // of current Op; For example, we won't move TILECFG out of an IfOp as it + // contains underministic control flow. + // 2. peerAllow: Whether we can hoist the BrgemmCallOp out of the scope of + // current Op without violating other peer BrgemmCallOp in the same level; For + // example, one scf.ForOp contains two TILECFG in the same level, then we + // cannot hoist any of them. 
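+  // As a rough illustration only (hypothetical IR, not taken from an actual
+  // test; %cfgA/%cfgB and the elided operands are placeholders), a loop whose
+  // body holds two different tile configurations blocks hoisting of either one:
+  //
+  //   scf.for %i = %c0 to %c4 step %c1 {
+  //     func.call @dnnl_brgemm_tileconfig(%cfgA) : (i64) -> ()
+  //     func.call @dnnl_brgemm_execute(%cfgA, ...) : (...) -> ()
+  //     func.call @dnnl_brgemm_tileconfig(%cfgB) : (i64) -> ()
+  //     func.call @dnnl_brgemm_execute(%cfgB, ...) : (...) -> ()
+  //     func.call @dnnl_brgemm_tilerelease() : () -> ()
+  //   }
+  //
+  // whereas a body containing a single configuration allows both the
+  // tileconfig and the tilerelease to be moved outside the scf.for.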
+ // NOLINTBEGIN(performance-unnecessary-value-param) + void expandInvariantScopeWithCond( + OpStructInfoMap &structInfo, Operation *op, + std::function controlFlowAllow, + std::function &)> + peerAllow) { + // NOLINTEND(performance-unnecessary-value-param) + auto opIter = structInfo.find(op); + assert(opIter != structInfo.end()); + auto contextRoot = opIter->second.contextRoot; + auto current = op; + auto currIter = opIter; + auto parent = opIter->second.parent; + while (parent != contextRoot) { + auto parentIter = structInfo.find(parent); + assert(parentIter != structInfo.end()); + // Verify whether we can expand the scope to direct parent + bool isControlFlowAllow = controlFlowAllow(parent); + bool isPeerAllow = + peerAllow(op, structInfo, current, parentIter->second.child); + if (!isControlFlowAllow || !isPeerAllow) { + break; + } + current = parent; + currIter = parentIter; + parent = parentIter->second.parent; + } + + opIter->second.maxInvariantScope = current; + } + + void expandInvariantScope(OpStructInfoMap &structInfo, Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + assert(brgemmCallType == BrgemmCallType::TILECFG || + brgemmCallType == BrgemmCallType::TILERELEASE); + + if (brgemmCallType == BrgemmCallType::TILECFG) { + expandInvariantScopeWithCond( + structInfo, op, + [](Operation *op) -> bool { + return !llvm::isa(op) && + !llvm::isa(op); + }, + [](Operation *self, const OpStructInfoMap &structInfo, + Operation *current, const DenseSet &peers) -> bool { + for (auto peer : peers) { + if (peer == current) + continue; + if (peer == self->getOperand(0).getDefiningOp()) { + // Don't break operand domination + return false; + } + const auto iter = structInfo.find(peer); + assert(iter != structInfo.end()); + const auto &containType = iter->second.containBrgemmCallType; + if (containType[BrgemmCallType::DISPATCH] || + containType[BrgemmCallType::TILECFG]) { + return false; + } + } + return true; + }); + } else { // brgemmCallType == BrgemmCallType::TILERELEASE + expandInvariantScopeWithCond( + structInfo, op, + [](Operation *op) -> bool { + return !llvm::isa(op); + }, + [](Operation *self, const OpStructInfoMap &structInfo, + Operation *current, + const DenseSet &peers) -> bool { return true; }); + } + } + + LogicalResult collectBrgemmContextStructInfo(OpStructInfoMap &structInfo) { + // First walk to collect basic structure + getOperation()->walk( + [this, &structInfo](Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + if (brgemmCallType == BrgemmCallType::INAPPLICABLE) { + return; + } + + // Construct the structInfo tree lazily upon encountering BrgemmCall + // Op + auto info = getOrCreateBrgemmContext(structInfo, op); + structInfo.insert(std::make_pair(op, std::move(info))); + if (brgemmCallType == BrgemmCallType::TILECFG) { + // Also contruct tree node for the input of BrgemmTilecfg for + // dependency check in `expandInvariantScope` + auto dependOp = op->getOperand(0).getDefiningOp(); + auto dependInfo = getOrCreateBrgemmContext(structInfo, dependOp); + structInfo.insert(std::make_pair(dependOp, std::move(dependInfo))); + } + }); + + // Second walk to analyse hoist related info + getOperation()->walk( + [this, &structInfo](Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + if (brgemmCallType != BrgemmCallType::TILECFG && + brgemmCallType != BrgemmCallType::TILERELEASE) { + return; + } + + // find the maximal invariant scope for hoisting + expandInvariantScope(structInfo, op); + }); + + return success(); 
+ } + +public: + using impl::MicrokernelInvariantCodeMotionBase< + MicrokernelInvariantCodeMotion>::MicrokernelInvariantCodeMotionBase; + void runOnOperation() final { + OpStructInfoMap structInfo; + + if (failed(collectBrgemmContextStructInfo(structInfo))) { + signalPassFailure(); + } + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), structInfo); + patterns.add(&getContext(), structInfo); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + // Ignore newly created Ops + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index fef58eae0..b63886fa0 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -30,6 +30,7 @@ #ifdef GC_HAS_ONEDNN_DIALECT #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" namespace mlir::gc { @@ -57,6 +58,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // Fine-grain fusion pass pm.addNestedPass(createIterativeTilingAndFusion()); // todo: fine-grain fusion pass + pm.addNestedPass( + mlir::microkernel::createConvertLinalgToMicrokernel()); // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. Currently we add this @@ -120,13 +123,12 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + memref + func/microkernel void populateMicroKernelPasses(mlir::OpPassManager &pm) { - // todo: ConvertLinalgToMicrokernel pass - // todo: CleanupInvalidMicrokernel pass - // todo: InvariantMicrokernelMotion pass - // todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call - // todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call - // todo: LowerMicrokernel pass - // todo: DispatchMicrokernel + pm.addNestedPass(mlir::microkernel::createExpandMicrokernel()); + pm.addPass(mlir::microkernel::createEarlyDispatchMicrokernel()); + pm.addPass(mlir::microkernel::createConvertMicrokernelToDnnlFunc()); + pm.addPass(mlir::microkernel::createMergeBranchMicrokernelContext()); + pm.addPass(mlir::microkernel::createMicrokernelInvariantCodeMotion()); + populateCleanUpPasses(pm); } void populateCPURuntimePasses(mlir::OpPassManager &pm) { diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 8750042ee..0f26bdd42 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -120,7 +120,7 @@ std::pair getPtrAndOffset(OpBuilder &builder, Value operand) { auto memrefType = dyn_cast(operand.getType()); assert(memrefType && "Expect a memref value"); - Location loc = operand.getDefiningOp()->getLoc(); + Location loc = operand.getLoc(); OpBuilder::InsertionGuard guard(builder); // Insert right after operand producer for better opt chances. 
builder.setInsertionPointAfterValue(operand); diff --git a/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir new file mode 100644 index 000000000..364b88dd1 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir @@ -0,0 +1,62 @@ +// RUN: gc-opt %s -early-dispatch-microkernel -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } +} + +// CHECK: llvm.mlir.global_ctors {ctors = [@[[G_CTOR_NAME:.+]]], priorities = [[[G_CTOR_PRIOR:.+]] : i32]} +// CHECK: llvm.mlir.global internal @[[G_NAME:.+]]() {addr_space = 0 : i32} : i64 + +// CHECK: llvm.func @[[G_CTOR_NAME]]() -> i64 { +// CHECK-DAG: %[[G_PTR:.+]] = llvm.mlir.addressof @[[G_NAME]] : !llvm.ptr +// CHECK-DAG: %[[KERNEL:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) +// CHECK: llvm.store %[[KERNEL]], %[[G_PTR]] : i64, !llvm.ptr + +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 + +// CHECK: 
%[[G_PTR_2:.+]] = llvm.mlir.addressof @[[G_NAME]] : !llvm.ptr +// CHECK-NEXT: %[[KERNEL2:.+]] = llvm.load %[[G_PTR_2]] : !llvm.ptr -> i64 + +// CHECK: %[[memrefC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + +// CHECK: %[[subviewA:.+]] = memref.subview %[[memrefA:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[subviewB:.+]] = memref.subview %[[memrefB:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + +// CHECK: microkernel.brgemm.execute(%[[KERNEL2]], %[[subviewA]], %[[subviewB]], %[[memrefC]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + +// ----- diff --git a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir new file mode 100644 index 000000000..7429ec970 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir @@ -0,0 +1,460 @@ +// RUN: gc-opt %s -merge-branch-microkernel-context -split-input-file | FileCheck %s + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @if_branch_context_merge() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 
16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + %3 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %4 = arith.index_cast %intptr : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_2 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_8 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %10 = arith.cmpi eq, %arg0, %c0 : index + scf.if %10 { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } else { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: if_branch_context_merge + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK-NEXT: scf.if +// CHECK: } else { +// CHECK: } +// CHECK-NEXT: func.call @dnnl_brgemm_tilerelease() : () -> () + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func 
@g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @if_only_branch_context_merge() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %8 = arith.cmpi eq, %arg0, %c0 : index + scf.if %8 { + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// 
CHECK-LABEL: if_only_branch_context_merge + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: scf.if +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c512_i64, %c512_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @if_branch_context_no_merge() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + %3 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %4 = arith.index_cast %intptr : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview 
%alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_2 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_8 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %10 = arith.cmpi eq, %arg0, %c0 : index + scf.if %10 { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } else { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: if_branch_context_no_merge + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: scf.if +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } else { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func 
@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @switch_branch_context_merge() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + %3 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %4 = arith.index_cast %intptr : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_2 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_8 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + scf.index_switch %arg0 + case 0 { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call 
@dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + case 1 { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + default { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: switch_branch_context_merge + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK-NEXT: scf.index_switch +// CHECK: case 0 { +// CHECK: case 1 { +// CHECK: default { +// CHECK: } +// CHECK-NEXT: func.call @dnnl_brgemm_tilerelease() : () -> () + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c512_i64, %c512_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, 
%c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @switch_branch_context_no_merge() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %3 = llvm.load %2 : !llvm.ptr -> i64 + %4 = llvm.load %1 : !llvm.ptr -> i64 + %5 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %6 = arith.index_cast %intptr : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_2 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %10 = arith.index_cast %intptr_8 : index to i64 + %11 = llvm.inttoptr %10 : i64 to !llvm.ptr + scf.index_switch %arg0 + case 0 { + func.call @dnnl_brgemm_tileconfig(%4) : (i64) -> () + func.call @dnnl_brgemm_execute(%4, %9, %offset, %11, %offset_5, %7, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + case 1 { + func.call @dnnl_brgemm_tileconfig(%5) : (i64) -> () + func.call @dnnl_brgemm_execute(%5, %9, %offset, %11, %offset_5, %7, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + 
func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + default { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %9, %offset, %11, %offset_5, %7, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: switch_branch_context_no_merge + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: scf.index_switch +// CHECK: case 0 { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: case 1 { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: default { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir new file mode 100644 index 000000000..6cd0173c8 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir @@ -0,0 +1,307 @@ +// RUN: gc-opt %s -microkernel-invariant-code-motion -split-input-file | FileCheck %s + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @parallel_no_hoist() { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to 
memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } + return + } +} + +// CHECK-LABEL: parallel_no_hoist + +// CHECK: scf.forall (%arg0, %arg1) in (4, 8) +// CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () +// CHECK-NEXT: call @dnnl_brgemm_tilerelease() : () -> () + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @multi_level_conflict() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2_i64 = arith.constant 2 : i64 + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : 
memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + %subview_9 = memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer_10, %offset_11, %sizes_12:3, %strides_13:3 = memref.extract_strided_metadata %subview_9 : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_14 = memref.extract_aligned_pointer_as_index %subview_9 : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_14 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %subview_15 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_16, %offset_17, %sizes_18:4, %strides_19:4 = memref.extract_strided_metadata %subview_15 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_20 = memref.extract_aligned_pointer_as_index %subview_15 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %10 = arith.index_cast %intptr_20 : index to i64 + %11 = llvm.inttoptr %10 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %9, %offset_11, %11, %offset_17, %3, %c0, %c2_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// 
CHECK-LABEL: multi_level_conflict + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () +// CHECK-NOT: call @dnnl_brgemm_tilerelease() : () -> () + +// CHECK: call @dnnl_brgemm_tileconfig(%[[D:.+]]) : (i64) -> () +// CHECK: call @dnnl_brgemm_execute([[E:.+]]) : ([[F:.+]]) -> () +// CHECK-NOT: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK-NEXT: return + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @multi_level_partial_hoist() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2_i64 = arith.constant 2 : i64 + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to 
memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } + scf.for %arg1 = %c0 to %c4 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg1, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c2_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + } + return + } +} + +// CHECK-LABEL: multi_level_partial_hoist + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK: scf.for %arg1 = %c0 to %c8 step %c1 +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () + +// CHECK: } +// CHECK-NEXT: call @dnnl_brgemm_tileconfig(%[[D:.+]]) : (i64) -> () +// CHECK-NEXT: scf.for %arg1 = %c0 to %c4 step %c1 +// CHECK: call @dnnl_brgemm_execute([[E:.+]]) : ([[F:.+]]) -> () + +// CHECK: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK-NEXT: return + +// ----- + +module { + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() 
{addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @multi_level_full_hoist() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: multi_level_full_hoist + 
+// CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK-NEXT: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () + +// CHECK: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK-NEXT: return + +// ----- diff --git a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir new file mode 100644 index 000000000..759f22776 --- /dev/null +++ b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir @@ -0,0 +1,148 @@ +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c32_index = arith.constant 32 : index + %c16_i64 = arith.constant 16 : i64 + %cst0f = arith.constant 0.000000e+00 : f32 + %cstn64f = arith.constant -64.000000e+00 : f32 + %cst1f = arith.constant 1.000000e+00 : bf16 + %cst2f = arith.constant 2.000000e+00 : bf16 + %cst3f = arith.constant 3.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + linalg.fill ins(%cst1f : bf16) outs(%alloc : memref<4x16x32x32xbf16>) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + linalg.fill ins(%cst2f : bf16) outs(%alloc_0 : memref<8x16x16x32x2xbf16>) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst3f : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst0f : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %arg0i = arith.index_castui %arg0 : index to i64 + %arg1i = arith.index_castui %arg1 : index to i64 + %argmulti = arith.muli %arg0i, %arg1i : i64 + %v = arith.uitofp %argmulti : i64 to f32 + %v1 = arith.mulf %v, %cstn64f : f32 + linalg.fill ins(%v1 : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : 
(i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst0f : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + %subview_c = memref.subview %alloc_2[%arg0, 0, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + %subview_7 = memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_8 = memref.subview %alloc_0[%arg0, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) + microkernel.brgemm.prologue(%10) : (i64) -> () + microkernel.brgemm.execute(%10, %subview_7, %subview_8, %subview_c, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64, i64) -> () + microkernel.brgemm.epilogue(%10) : (i64) -> () + } + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + scf.for %arg2 = %c0_index to %c32_index step %c1_index { + scf.for %arg3 = %c0_index to %c32_index step %c1_index { + %elem = memref.load %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<4x8x32x32xf32> + cpuruntime.printf "%f, " %elem : f32 + } + cpuruntime.printf "\n" + } + cpuruntime.printf "==================================\n" + } + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + cpuruntime.printf "BRGEMM DONE\n" + return + } + + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 
963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, + // CHECK: ================================== + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 
771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, + // CHECK: ================================== + // CHECK: 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, + // CHECK: ================================== + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 
643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: BRGEMM DONE +} diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index cd927f175..38cf0dbf0 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -1,22 +1,36 @@ -// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { func.func @simple_brgemm() { 
     %c0_i64 = arith.constant 0 : i64
     %c16_i64 = arith.constant 16 : i64
-    %cst = arith.constant 0.000000e+00 : f32
+    %c0_index = arith.constant 0 : index
+    %c1_index = arith.constant 1 : index
+    %c4_index = arith.constant 4 : index
+    %c8_index = arith.constant 8 : index
+    %c32_index = arith.constant 32 : index
+    %cst0f = arith.constant 0.000000e+00 : f32
+    %cstn64f = arith.constant -64.000000e+00 : f32
+    %cst1f = arith.constant 1.000000e+00 : f32
+    %cst2f = arith.constant 2.000000e+00 : f32
+    %cst3f = arith.constant 3.000000e+00 : f32
     %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32>
-    linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>)
+    linalg.fill ins(%cst1f : f32) outs(%alloc : memref<4x16x32x32xf32>)
     %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32>
-    linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>)
+    linalg.fill ins(%cst2f : f32) outs(%alloc_0 : memref<8x16x32x32xf32>)
     %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
-    linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>)
+    linalg.fill ins(%cst3f : f32) outs(%alloc_1 : memref<4x8x32x32xf32>)
     %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
-    linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>)
+    linalg.fill ins(%cst0f : f32) outs(%alloc_2 : memref<4x8x32x32xf32>)
     scf.forall (%arg0, %arg1) in (4, 8) {
       %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
-      linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>)
+      %arg0i = arith.index_castui %arg0 : index to i64
+      %arg1i = arith.index_castui %arg1 : index to i64
+      %argmulti = arith.muli %arg0i, %arg1i : i64
+      %v = arith.uitofp %argmulti : i64 to f32
+      %v1 = arith.mulf %v, %cstn64f : f32
+      linalg.fill ins(%v1 : f32) outs(%alloc_3 : memref<32x32xf32>)
       %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>
       %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>
       %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32)
@@ -32,16 +46,90 @@
       %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
       linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
       ^bb0(%in: f32, %out: f32):
-        %1 = arith.maximumf %in, %cst : f32
+        %1 = arith.maximumf %in, %cst0f : f32
         linalg.yield %1 : f32
       }
       memref.dealloc %alloc_3 : memref<32x32xf32>
     }
+    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
+      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
+        scf.for %arg2 = %c0_index to %c32_index step %c1_index {
+          scf.for %arg3 = %c0_index to %c32_index step %c1_index {
+            %elem = memref.load %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<4x8x32x32xf32>
+            cpuruntime.printf "%f, " %elem : f32
+          }
+          cpuruntime.printf "\n"
+        }
+        cpuruntime.printf "==================================\n"
+      }
+    }
     return
   }

   func.func @main() {
     call @simple_brgemm() : ()->()
+    // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000,
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, + // CHECK: ================================== + // CHECK: 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== cpuruntime.printf "BRGEMM DONE\n" return } diff --git a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir new file mode 100644 index 000000000..b7f3ca6a7 --- /dev/null +++ b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir @@ -0,0 +1,140 @@ +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops 
--convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @simple_brgemm() {
+    %c0_index = arith.constant 0 : index
+    %c1_index = arith.constant 1 : index
+    %c4_index = arith.constant 4 : index
+    %c8_index = arith.constant 8 : index
+    %c32_index = arith.constant 32 : index
+    %c0_i64 = arith.constant 0 : i64
+    %c16_i64 = arith.constant 16 : i64
+    %cst0f = arith.constant 0.000000e+00 : f32
+    %cstn64f = arith.constant -64.000000e+00 : f32
+    %cst1f = arith.constant 1.000000e+00 : f32
+    %cst2f = arith.constant 2.000000e+00 : f32
+    %cst3f = arith.constant 3.000000e+00 : f32
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32>
+    linalg.fill ins(%cst1f : f32) outs(%alloc : memref<4x16x32x32xf32>)
+    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32>
+    linalg.fill ins(%cst2f : f32) outs(%alloc_0 : memref<8x16x32x32xf32>)
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
+    linalg.fill ins(%cst3f : f32) outs(%alloc_1 : memref<4x8x32x32xf32>)
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32>
+    linalg.fill ins(%cst0f : f32) outs(%alloc_2 : memref<4x8x32x32xf32>)
+    scf.for %arg0 = %c0_index to %c4_index step %c1_index {
+      scf.for %arg1 = %c0_index to %c8_index step %c1_index {
+        %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
+        %arg0i = arith.index_castui %arg0 : index to i64
+        %arg1i = arith.index_castui %arg1 : index to i64
+        %argmulti = arith.muli %arg0i, %arg1i : i64
+        %v = arith.uitofp %argmulti : i64 to f32
+        %v1 = arith.mulf %v, %cstn64f : f32
+        linalg.fill ins(%v1 : f32) outs(%alloc_3 : memref<32x32xf32>)
+        %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+        %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+        %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32)
+        microkernel.brgemm.prologue(%0) : (i64) -> ()
+        microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> ()
+        microkernel.brgemm.epilogue(%0) : (i64) -> ()
+        %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) {
+        ^bb0(%in: f32, %in_7: f32, %out: f32):
+          %1 = arith.addf %in, %in_7 : f32
+          linalg.yield %1 : f32
+        }
+        %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+        linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 :
memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst0f : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + scf.for %arg2 = %c0_index to %c32_index step %c1_index { + scf.for %arg3 = %c0_index to %c32_index step %c1_index { + %elem = memref.load %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<4x8x32x32xf32> + cpuruntime.printf "%f, " %elem : f32 + } + cpuruntime.printf "\n" + } + cpuruntime.printf "==================================\n" + } + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + cpuruntime.printf "BRGEMM DONE\n" + return + } + + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 
643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, + // CHECK: ================================== + // CHECK: 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 
259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: BRGEMM DONE +}
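+
+// Expected output: tile (i, j) of the result holds max(0, 1027 - 64 * i * j)
+// for the two outer scf.for indices, the same values as in the
+// brgemm-parallel.mlir test above.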