From 8213a9aeec79226f81f29ad5178ea092a93d8959 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 29 May 2024 20:42:58 -0700 Subject: [PATCH 01/93] add microkernel dialect --- include/gc/CMakeLists.txt | 2 +- include/gc/Dialect/CMakeLists.txt | 2 +- include/gc/Dialect/Microkernel/CMakeLists.txt | 5 + .../Dialect/Microkernel/MicrokernelDialect.h | 2 + .../Dialect/Microkernel/MicrokernelDialect.td | 13 +- .../gc/Dialect/Microkernel/MicrokernelEnum.h | 18 + .../gc/Dialect/Microkernel/MicrokernelEnum.td | 26 ++ .../gc/Dialect/Microkernel/MicrokernelOps.h | 9 + .../gc/Dialect/Microkernel/MicrokernelOps.td | 97 +++- include/gc/Utils/StructuredOpMatcher.h | 429 ++++++++++++++++++ include/gc/Utils/ValueUtils.h | 45 ++ lib/gc/CMakeLists.txt | 3 +- lib/gc/Dialect/Microkernel/CMakeLists.txt | 4 +- .../Microkernel/MicrokernelDialect.cpp | 7 + .../Dialect/Microkernel/MicrokernelEnum.cpp | 15 + lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 364 ++++++++++++++- lib/gc/Utils/CMakeLists.txt | 7 + lib/gc/Utils/StructuredOpMatcher.cpp | 306 +++++++++++++ lib/gc/Utils/ValueUtils.cpp | 153 +++++++ src/gc-opt/gc-opt.cpp | 2 + 20 files changed, 1499 insertions(+), 10 deletions(-) create mode 100644 include/gc/Dialect/Microkernel/MicrokernelEnum.h create mode 100644 include/gc/Dialect/Microkernel/MicrokernelEnum.td create mode 100644 include/gc/Utils/StructuredOpMatcher.h create mode 100644 include/gc/Utils/ValueUtils.h create mode 100644 lib/gc/Dialect/Microkernel/MicrokernelEnum.cpp create mode 100644 lib/gc/Utils/CMakeLists.txt create mode 100644 lib/gc/Utils/StructuredOpMatcher.cpp create mode 100644 lib/gc/Utils/ValueUtils.cpp diff --git a/include/gc/CMakeLists.txt b/include/gc/CMakeLists.txt index db67942f3..557daa847 100644 --- a/include/gc/CMakeLists.txt +++ b/include/gc/CMakeLists.txt @@ -1,2 +1,2 @@ add_subdirectory(Dialect) -add_subdirectory(Transforms) \ No newline at end of file +add_subdirectory(Transforms) diff --git a/include/gc/Dialect/CMakeLists.txt 
b/include/gc/Dialect/CMakeLists.txt index 2867b7972..db17a6f99 100644 --- a/include/gc/Dialect/CMakeLists.txt +++ b/include/gc/Dialect/CMakeLists.txt @@ -1,4 +1,4 @@ add_subdirectory(CPURuntime) add_subdirectory(OneDNNGraph) add_subdirectory(Microkernel) -add_subdirectory(Linalgx) \ No newline at end of file +add_subdirectory(Linalgx) diff --git a/include/gc/Dialect/Microkernel/CMakeLists.txt b/include/gc/Dialect/Microkernel/CMakeLists.txt index 4d8f855e0..d9046919c 100644 --- a/include/gc/Dialect/Microkernel/CMakeLists.txt +++ b/include/gc/Dialect/Microkernel/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS MicrokernelEnum.td) +mlir_tablegen(MicrokernelEnum.h.inc -gen-enum-decls) +mlir_tablegen(MicrokernelEnum.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRMicrokernelAttrDefIncGen) + add_mlir_dialect(MicrokernelOps microkernel) add_mlir_doc(MicrokernelOps MicrokernelOps gc/Dialect/Microkernel/ -gen-op-doc) add_mlir_doc(MicrokernelDialect MicrokernelDialect gc/Dialect/Microkernel/ -gen-dialect-doc) diff --git a/include/gc/Dialect/Microkernel/MicrokernelDialect.h b/include/gc/Dialect/Microkernel/MicrokernelDialect.h index 71b368655..4cca70cf9 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelDialect.h +++ b/include/gc/Dialect/Microkernel/MicrokernelDialect.h @@ -9,7 +9,9 @@ #ifndef GC_DIALECTS_MICROKERNELDIALECT_H #define GC_DIALECTS_MICROKERNELDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "gc/Dialect/Microkernel/MicrokernelOpsDialect.h.inc" diff --git a/include/gc/Dialect/Microkernel/MicrokernelDialect.td b/include/gc/Dialect/Microkernel/MicrokernelDialect.td index 93c59d2bb..ef8e544df 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelDialect.td +++ b/include/gc/Dialect/Microkernel/MicrokernelDialect.td @@ -15,15 +15,20 @@ include "mlir/IR/OpBase.td" // Microkernel dialect definition. 
//===----------------------------------------------------------------------===// -def MicrokernelDialect : Dialect { +def Microkernel_Dialect : Dialect { let name = "microkernel"; let summary = "A dialect for microkernel abstraction."; let description = [{ - The dialect wraps the BRGEMM API to set up the HW context etc. + This dialect contains wrappers for microkernel primitives like BRGEMM. }]; let cppNamespace = "::mlir::microkernel"; - - let useDefaultTypePrinterParser = 1; } +//===----------------------------------------------------------------------===// +// Base microkernel operation definition. +//===----------------------------------------------------------------------===// + +class Microkernel_Op traits = []> : + Op; + #endif // MICROKERNEL_DIALECT diff --git a/include/gc/Dialect/Microkernel/MicrokernelEnum.h b/include/gc/Dialect/Microkernel/MicrokernelEnum.h new file mode 100644 index 000000000..ba6505b67 --- /dev/null +++ b/include/gc/Dialect/Microkernel/MicrokernelEnum.h @@ -0,0 +1,18 @@ +//===- MicrokernelEnum.h - microkernel dialect enums ------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_DIALECTS_MICROKERNELENUM_H +#define GC_DIALECTS_MICROKERNELENUM_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/DialectImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "gc/Dialect/Microkernel/MicrokernelEnum.h.inc" + +#endif // GC_DIALECTS_MICROKERNELENUM_H diff --git a/include/gc/Dialect/Microkernel/MicrokernelEnum.td b/include/gc/Dialect/Microkernel/MicrokernelEnum.td new file mode 100644 index 000000000..3a4e4bad0 --- /dev/null +++ b/include/gc/Dialect/Microkernel/MicrokernelEnum.td @@ -0,0 +1,26 @@ +//===- MicrokernelEnum.td - microkernel dialect enum -------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MICROKERNEL_ENUM +#define MICROKERNEL_ENUM + +include "mlir/IR/EnumAttr.td" +include "gc/Dialect/Microkernel/MicrokernelDialect.td" + +def Microkernel_BrgemmFlags : I64EnumAttr< + "BrgemmFlags", "see: microkernel_brgemm_flags", + [ + I64EnumAttrCase<"NONE", 0, "none">, + I64EnumAttrCase<"BETA_0", 1, "beta_0">, + I64EnumAttrCase<"STRIDE", 2, "stride">, + I64EnumAttrCase<"LIST", 4, "list"> + ]> { + let cppNamespace = "::mlir::microkernel"; +} + +#endif // MICROKERNEL_ENUM diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.h b/include/gc/Dialect/Microkernel/MicrokernelOps.h index ca36f4c02..a478c1dee 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.h +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.h @@ -9,7 +9,16 @@ #ifndef GC_DIALECTS_MICROKERNELOPS_H #define GC_DIALECTS_MICROKERNELOPS_H +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include 
"mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "gc/Dialect/Microkernel/MicrokernelDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelEnum.h" #define GET_OP_CLASSES #include "gc/Dialect/Microkernel/MicrokernelOps.h.inc" diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.td b/include/gc/Dialect/Microkernel/MicrokernelOps.td index 7be8b8aed..d0d50d04e 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.td +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.td @@ -10,5 +10,100 @@ #define MICROKERNEL_OPS include "MicrokernelDialect.td" +include "gc/Dialect/Microkernel/MicrokernelEnum.td" +include "mlir/Interfaces/SideEffectInterfaces.td" -#endif // MICROKERNEL_OPS \ No newline at end of file +class StaticMemRefRankOf allowedTypes, list ranks> : + Type.predicate, + HasAnyRankOfPred, HasStaticShapePred]>, + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " # + MemRefOf.summary, "::mlir::MemRefType">; + +def Microkernel_BrgemmDispatchOp : Microkernel_Op<"brgemm.dispatch", [Pure]> { + let summary = "JIT the brgemm microkernel given the parameters"; + let description = [{ + The operation has the following arguments: 1) m, n, k, lda, ldb, ldc, stride_a and stride_b. + Inputs is a dense attribute of I64 elements. 2) flags carry information on + the different flags that can be used for brgemm like whether beta == 0 or strided batch. For + more details, see: `Microkernel_BrgemmFlags`. 3) data_types of operand A & B. + Outpus is the id of JITed kernel. 
+ }]; + + let arguments = (ins + ConfinedAttr]>:$inputs, + TypedArrayAttrBase:$flags, + TypedArrayAttrBase:$data_type); + + let results = (outs I64:$results); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def Microkernel_BrgemmPrologueOp : Microkernel_Op<"brgemm.prologue"> { + let summary = "Prologue before executing the JITed brgemm " + "microkernel, and the context is considered core-level"; + let description = [{ + The operation has the following arguments: Input is the id of JITed kernel. + There is no output. + }]; + + let arguments = (ins I64:$inputs); + + let assemblyFormat = [{ + `(` $inputs `)` + attr-dict `:` functional-type($inputs, results) + }]; +} + +def Microkernel_BrgemmEpilogueOp : Microkernel_Op<"brgemm.epilogue"> { + let summary = "Epilogue after executing the JITed brgemm microkernel"; + let description = [{ + The operation has the following arguments: Input is the id of JITed kernel. + There is no output. + }]; + + let arguments = (ins I64:$inputs); + + let assemblyFormat = [{ + `(` $inputs `)` + attr-dict `:` functional-type($inputs, results) + }]; +} + +def BrgemmMemRef : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>; + +def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> { + let summary = "execute the JITed brgemm kernel."; + let description = [{ + The operation has the following arguments: + 1) For stride mode, id of JITed kernel, MemRef of operand A/B/C, and the batch size; + 2) For addr mode, plus the length of addr list at the end. + There is no output. 
+ }]; + + let arguments = (ins Variadic:$inputs); + + let assemblyFormat = [{ + `(` $inputs `)` + attr-dict `:` functional-type($inputs, results) + }]; + + let extraClassDeclaration = [{ + Value getDispatch() { return getInputs()[0]; } + + Value getOperandA() { return getInputs()[1]; } + + Value getOperandB() { return getInputs()[2]; } + + Value getOutput() { return getInputs()[3]; } + + Value getBatch() { return getInputs()[4]; } + + Value getAddrLen() { return getInputs()[5]; } + }]; + + let hasVerifier = 1; +} + +#endif // MICROKERNEL_OPS diff --git a/include/gc/Utils/StructuredOpMatcher.h b/include/gc/Utils/StructuredOpMatcher.h new file mode 100644 index 000000000..e0aa9acc7 --- /dev/null +++ b/include/gc/Utils/StructuredOpMatcher.h @@ -0,0 +1,429 @@ +//===- StructuredOpMatcher.h - Utils for structuterd Op ----------*-C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +/* + * This code is borrowed from tpp-mlir: + * https://github.com/plaidml/tpp-mlir/tree/main/include/TPP/IR/StructuredOpMatcher.h + */ + +#ifndef GC_UTILS_STRUCTUREDOPMATCHER_H +#define GC_UTILS_STRUCTUREDOPMATCHER_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include +#include + +namespace mlir { +class Operation; +namespace gcext { +namespace utils { +namespace structured_match { + +// Base class for the matcher predicates selection tag. 
+struct MatchSelector { + MatchSelector() = delete; + size_t getLowerBound() const { return lowerBound; }; + size_t getUpperBound() const { return upperBound; }; + +protected: + explicit MatchSelector(size_t lowerBound, size_t upperBound) + : lowerBound(lowerBound), upperBound(upperBound) { + assert(upperBound > lowerBound); + } + +private: + const size_t lowerBound; + const size_t upperBound; +}; + +// Selector which specifies that predicate should apply on all values. +struct MatchAll : public MatchSelector { + MatchAll() : MatchSelector(0, std::numeric_limits::max()) {} +}; + +// Selector which specifies that predicate should apply only on one value at +// the position `idx`. +struct MatchOne : public MatchSelector { + MatchOne() = delete; + MatchOne(size_t idx) : MatchSelector(idx, idx + 1) {} +}; + +// Selector which specifies that predicate should apply only on range of values +// at positions from `lowerBound` up to - but not including - `upperBound`. +struct MatchRange : public MatchSelector { + MatchRange() = delete; + MatchRange(size_t lowerBound, size_t upperBound) + : MatchSelector(lowerBound, upperBound) {} +}; + +// Callable object to check if the number of loops in `op` satisfies `fun`. +struct NumOfLoops { + NumOfLoops() = delete; + explicit NumOfLoops(std::function fun) : fun(std::move(fun)){}; + + bool operator()(Operation *op) const { + if (auto linalgOp = dyn_cast_or_null(op)) { + auto numberOfLoops = linalgOp.getNumLoops(); + return fun(numberOfLoops); + } + return false; + } + std::function fun; +}; + +// Callable object to check if the `operand` of `op` has a map that satisfies +// `fun`. 
+struct HasMap { + HasMap() = delete; + explicit HasMap(std::function fun) : fun(std::move(fun)){}; + explicit HasMap(std::function fun, AffineMap *ptrMap) + : fun(std::move(fun)), ptrMap(ptrMap){}; + + bool operator()(OpOperand *operand, Operation *op) const { + if (auto linalgOp = dyn_cast_or_null(op)) { + auto map = linalgOp.getMatchingIndexingMap(operand); + assert(fun && "must be a callable target"); + if (!fun(map)) + return false; + if (ptrMap) + *ptrMap = std::move(map); + return true; + } + return false; + } + std::function fun; + AffineMap *ptrMap = nullptr; +}; + +// Callble object to verify if `map` is a projected permutation map. +// We require the dimensions to be in sorted order this avoid filtering +// projected permutation without broadcasting semantics, for example +// affine_map<(d0, d1) -> (d1, d0)> is rejected. +struct ProjectedPermutation { + ProjectedPermutation() = default; + + bool operator()(AffineMap map) const { + if (map.getNumSymbols() > 0 || map.getNumResults() > map.getNumInputs()) + return false; + + SmallVector seen(map.getNumInputs(), false); + SmallVector pos; + for (auto expr : map.getResults()) { + if (auto dim = dyn_cast(expr)) { + if (seen[dim.getPosition()]) + return false; + seen[dim.getPosition()] = true; + pos.push_back(dim.getPosition()); + } else if (auto constExpr = dyn_cast(expr)) { + if (constExpr.getValue() != 0) + return false; + } else + return false; + } + return llvm::is_sorted(pos); + } +}; + +// Callable object to verify if `map` is an identity map. +struct Identity { + Identity() = default; + + bool operator()(AffineMap map) const { return map.isIdentity(); } +}; + +// Callable object to capture any map. +struct Any { + Any() = default; + + bool operator()(AffineMap map) const { return true; } +}; + +// Callable object to verify if `operand` has static shape. 
+struct HasStaticShape { + HasStaticShape() = default; + HasStaticShape(SmallVectorImpl *shape) : shape(shape){}; + + bool operator()(OpOperand *operand, Operation *op) const { + auto operandType = operand->get().getType(); + if (auto shapedType = dyn_cast_or_null(operandType)) { + if (!shapedType.hasStaticShape()) + return false; + if (shape) { + for (int64_t shapeOnDim : shapedType.getShape()) + shape->push_back(shapeOnDim); + } + } + return true; + } + SmallVectorImpl *shape = nullptr; +}; + +// Callable object to verify if `operand` has static strides. +// If `operand` is a tensor type or a scalar, return true. +struct HasStaticStrides { + HasStaticStrides() = default; + HasStaticStrides(SmallVector *strides) : strides(strides){}; + + bool operator()(OpOperand *operand, Operation *op) const { + auto operandType = operand->get().getType(); + SmallVector strides; + if (auto memRefType = dyn_cast_or_null(operandType)) { + int64_t offset; + if (failed(getStridesAndOffset(memRefType, strides, offset))) + return false; + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return false; + } + if (this->strides) + this->strides->append(strides.begin(), strides.end()); + } + return true; + } + SmallVectorImpl *strides = nullptr; +}; + +// Callable object to verify `operand` to have a rank in `ranks`. +struct HasRank { + HasRank() = delete; + explicit HasRank(std::initializer_list ranks) : ranks(ranks){}; + + bool operator()(OpOperand *operand, Operation *op) const { + auto operandType = operand->get().getType(); + if (!isa(operandType)) + return llvm::is_contained(ranks, HasRank::SCALAR); + int64_t rank = cast(operandType).getRank(); + return llvm::any_of( + ranks, [=](int64_t expectedRank) { return expectedRank == rank; }); + } + + // There are multiple way to represent a scalar: f32, tensor. + // SCALAR means f32. 
+ static constexpr int64_t SCALAR = -1; + std::vector ranks; +}; + +// Callable object to verify `operand` to have an element type `T`. +template struct HasElementType { + bool operator()(OpOperand *operand, Operation *op) const { + auto operandType = getElementTypeOrSelf(operand->get().getType()); + return isa(operandType); + } +}; + +// Callable object to check if the input is equal to specified `value`. +template struct EqualsTo { + EqualsTo() = delete; + explicit EqualsTo(T value) : value(value){}; + + const T value; + + bool operator()(T value) const { return value == this->value; } +}; +template EqualsTo(T) -> EqualsTo; + +// Callable object to check if the input is less than or equal to specified +// `value`. +struct LessThanOrEqualTo { + LessThanOrEqualTo() = delete; + explicit LessThanOrEqualTo(size_t value) : value(value){}; + const size_t value; + + bool operator()(size_t value) const { return value <= this->value; } +}; + +// Callable object to check if the input is greater than or equal to specified +// `value`. +struct GreaterThanOrEqualTo { + GreaterThanOrEqualTo() = delete; + explicit GreaterThanOrEqualTo(size_t value) : value(value){}; + const size_t value; + + bool operator()(size_t value) const { return value >= this->value; } +}; + +// Callable object to validate number of init operands for `op`. +struct NumDpsInits { + NumDpsInits() = delete; + explicit NumDpsInits(std::function fun) : fun(std::move(fun)){}; + + bool operator()(Operation *op) const { + if (auto linalgOp = dyn_cast_or_null(op)) + return fun(linalgOp.getNumDpsInits()); + return false; + } + + std::function fun; +}; + +// Callable object to check the number of affine map for `op`. 
+struct NumAffineMaps { + NumAffineMaps() = delete; + explicit NumAffineMaps(std::function fun) + : fun(std::move(fun)){}; + + bool operator()(Operation *op) const { + if (auto linalgOp = dyn_cast_or_null(op)) + return fun(linalgOp.getIndexingMapsArray().size()); + return false; + } + + std::function fun; +}; + +// Callable object to validate number of input operands for `op`. +struct NumDpsInputs { + NumDpsInputs() = delete; + explicit NumDpsInputs(std::function fun) + : fun(std::move(fun)){}; + + bool operator()(Operation *op) { + if (auto linalgOp = dyn_cast_or_null(op)) + return fun(linalgOp.getNumDpsInputs()); + return false; + } + + std::function fun; +}; + +// Callable object to validate number of regions for `op`. +struct NumRegions { + NumRegions() = delete; + explicit NumRegions(std::function fun) : fun(std::move(fun)){}; + + bool operator()(Operation *op) const { + if (auto linalgOp = dyn_cast_or_null(op)) + return fun(linalgOp->getNumRegions()); + return false; + } + + std::function fun; +}; + +// Logical OR between two predicates. +struct _OR { + _OR() = delete; + _OR(std::function lhs, std::function rhs) + : lhs(std::move(lhs)), rhs(std::move(rhs)) {} + + bool operator()(size_t num) { return (lhs(num) || rhs(num)); } + + std::function lhs; + std::function rhs; +}; + +// Callable object to check if `op` adheres to a given property passed +// as an std::function object. +struct VerifyOpProperty { + VerifyOpProperty() = delete; + explicit VerifyOpProperty(std::function fun) + : fun(std::move(fun)){}; + + bool operator()(Operation *op) { + if (succeeded(fun(op))) + return true; + return false; + } + + std::function fun; +}; + +// Work-around for template specialization. +struct WithSingleOpImpl { + WithSingleOpImpl() = default; + + bool withSingleOpImpl(StringRef, Region *, Operation *, + SmallVectorImpl *); +}; + +// Callable object to check the `op` region for a single scalar operation OpTy. 
+template struct WithSingleOp { + WithSingleOp() : WithSingleOp(nullptr){}; + WithSingleOp(SmallVectorImpl *captures) : captures(captures){}; + + bool operator()(Region *region, Operation *op) { + return WithSingleOpImpl().withSingleOpImpl(OpTy::getOperationName(), region, + op, captures); + } + +private: + SmallVectorImpl *captures; +}; + +// Implemenation to allow definition in cpp file +using TypeCheckFunc = std::function; +bool withOpChainImpl(Region *region, Operation *op, SmallVectorImpl *, + SmallVectorImpl &); + +// Callable object to check the region for a chain of operations. +template struct WithOpChain { + WithOpChain() : WithOpChain(nullptr){}; + WithOpChain(SmallVectorImpl *captures) : captures(captures) { + (typeChecks.push_back([](Operation *op) { return isa(op); }), ...); + }; + + bool operator()(Region *region, Operation *op) { + return withOpChainImpl(region, op, captures, typeChecks); + } + +private: + SmallVectorImpl *captures; + SmallVector typeChecks; +}; + +class StructuredOpMatcher { + using PredicateFn = std::function; + +public: + StructuredOpMatcher() = default; + + StructuredOpMatcher(PredicateFn &&firstPredicate) { + predicates.push_back(std::move(firstPredicate)); + } + + template static StructuredOpMatcher make() { + return StructuredOpMatcher( + [](linalg::LinalgOp op) { return isa(op.getOperation()); }); + } + + // Match given `op` using stored predicates. + bool match(Operation *op); + + // Predicates on operation. + StructuredOpMatcher &operation(std::function); + + // Predicate on OpOperands. + StructuredOpMatcher &input(MatchSelector range, + std::function); + + // Predicates on OpOperands. + StructuredOpMatcher &output(MatchSelector range, + std::function); + + // Predicates on Iterators. + StructuredOpMatcher &dim(MatchSelector range, + SmallVector kinds); + StructuredOpMatcher &dim(MatchSelector range, mlir::utils::IteratorType kind); + + // Predicates on region. 
+ StructuredOpMatcher ®ion(MatchSelector range, + std::function); + +private: + llvm::SmallVector predicates; +}; + +} // namespace structured_match +} // namespace utils +} // namespace gcext +} // namespace mlir + +#endif // GC_UTILS_STRUCTUREDOPMATCHER_H diff --git a/include/gc/Utils/ValueUtils.h b/include/gc/Utils/ValueUtils.h new file mode 100644 index 000000000..5d940b000 --- /dev/null +++ b/include/gc/Utils/ValueUtils.h @@ -0,0 +1,45 @@ +//===- ValueUtils.h - Utils for handling mlir::Value -------------*-C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +/* + * This code is borrowed from tpp-mlir: + * https://github.com/plaidml/tpp-mlir/tree/main/include/TPP/Transforms/Utils/ValueUtils.h + */ + +#ifndef GC_UTILS_VALUEUTILS_H +#define GC_UTILS_VALUEUTILS_H + +namespace mlir { +class Value; +class OpBuilder; +namespace gcext { +namespace utils { + +using namespace mlir; + +// Returns true if the value is a constant float or integer. +bool isValConstZero(Value val); + +// Returns true if the op defining `val` represents a zero filled tensor. +bool isZeroTensor(Value val); + +// Returns the strides of `val`. The method returns something usefull +// only if the `val` type is a strided memref and the strides are statically +// known. +FailureOr> getStaticStrides(Value val); + +// Return the offset and ptr for `val`. Assert if `val` +// is not a memref. 
+std::pair getPtrAndOffset(OpBuilder &builder, Value val, + Location loc); + +} // namespace utils +} // namespace gcext +} // namespace mlir + +#endif diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt index 03f7023b8..0e5f94288 100644 --- a/lib/gc/CMakeLists.txt +++ b/lib/gc/CMakeLists.txt @@ -7,4 +7,5 @@ include(functions) add_subdirectory(CAPI) add_subdirectory(Dialect) add_subdirectory(Transforms) -add_subdirectory(ExecutionEngine) \ No newline at end of file +add_subdirectory(ExecutionEngine) +add_subdirectory(Utils) diff --git a/lib/gc/Dialect/Microkernel/CMakeLists.txt b/lib/gc/Dialect/Microkernel/CMakeLists.txt index a3eaa8d3d..f67309940 100644 --- a/lib/gc/Dialect/Microkernel/CMakeLists.txt +++ b/lib/gc/Dialect/Microkernel/CMakeLists.txt @@ -1,6 +1,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) add_mlir_dialect_library(MLIRMicrokernel + MicrokernelEnum.cpp MicrokernelDialect.cpp MicrokernelOps.cpp @@ -12,5 +13,6 @@ add_mlir_dialect_library(MLIRMicrokernel LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} + GCMLIRUtils ) -set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRMicrokernel) \ No newline at end of file +set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRMicrokernel) diff --git a/lib/gc/Dialect/Microkernel/MicrokernelDialect.cpp b/lib/gc/Dialect/Microkernel/MicrokernelDialect.cpp index d4857e87a..0b80abc18 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelDialect.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelDialect.cpp @@ -7,11 +7,18 @@ //===----------------------------------------------------------------------===// #include "gc/Dialect/Microkernel/MicrokernelDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelEnum.h" #include "gc/Dialect/Microkernel/MicrokernelOps.h" using namespace mlir; using namespace mlir::microkernel; +#include "gc/Dialect/Microkernel/MicrokernelOpsDialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// Microkernel dialect. 
+//===----------------------------------------------------------------------===// + void MicrokernelDialect::initialize() { addOperations< #define GET_OP_LIST diff --git a/lib/gc/Dialect/Microkernel/MicrokernelEnum.cpp b/lib/gc/Dialect/Microkernel/MicrokernelEnum.cpp new file mode 100644 index 000000000..11b9723ad --- /dev/null +++ b/lib/gc/Dialect/Microkernel/MicrokernelEnum.cpp @@ -0,0 +1,15 @@ +//===-- MicrokernelEnum.cpp - microkernel dialect enum ----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/Microkernel/MicrokernelEnum.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::microkernel; + +#include "gc/Dialect/Microkernel/MicrokernelEnum.cpp.inc" diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 45bc5719f..bd730a22c 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -8,7 +8,369 @@ #include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Dialect/Microkernel/MicrokernelDialect.h" -#include "mlir/IR/OpImplementation.h" +#include #define GET_OP_CLASSES #include "gc/Dialect/Microkernel/MicrokernelOps.cpp.inc" + +#include + +namespace mlir { + +namespace microkernel { + +constexpr std::string_view INPUTS = "inputs"; +constexpr std::string_view DATA_TYPE = "data_type"; +constexpr std::string_view FLAGS_NAME = "flags"; + +template +static void printInputImpl(OpAsmPrinter &printer, OpTy op) { + printer << " [" << op.getInputs() << ']'; +} + +template +static void printFlagsImpl(OpAsmPrinter &printer, + const std::function &fn, + const std::string_view &flagsName) { + printer << " " << flagsName << " = ("; + 
llvm::interleaveComma(fn(), printer, [&](auto &flag) { + printer << stringifyEnum(cast(flag).getValue()); + }); + printer << ") "; +} + +template +static void printDataTypeImpl(OpAsmPrinter &printer, OpTy op) { + printer << DATA_TYPE << " = ("; + auto dataTypes = op.getDataType(); + for (size_t idx = 0; idx < dataTypes.size(); idx++) { + printer.printAttribute(dataTypes[idx]); + if (idx != dataTypes.size() - 1) { + printer << ", "; + } + } + printer << ") "; +} + +template +static ParseResult parseEnum(EnumClass &value, OpAsmParser &parser) { + StringRef flag; + auto loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&flag)) + return failure(); + auto flagAttr = symbolizeEnum(flag); + if (!flagAttr) + return parser.emitError(loc, "invalid enum ") << flag; + value = *flagAttr; + return success(); +} + +static ParseResult parseOperandImpl(OpAsmParser &parser, + OperationState &result) { + DenseI64ArrayAttr kindAttr; + if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, INPUTS, + result.attributes)) { + return failure(); + } + auto &builder = parser.getBuilder(); + result.addTypes(builder.getIntegerType(64)); + return success(); +} + +static ParseResult parseDataTypeImpl(OpAsmParser &parser, + OperationState &result) { + auto &builder = parser.getBuilder(); + if (parser.parseKeyword(DATA_TYPE) || parser.parseEqual() || + parser.parseLParen()) + return failure(); + SmallVector dataTypes; + auto parseTypeAttr = [&]() -> ParseResult { + Attribute dataType; + if (parser.parseAttribute(dataType)) + return failure(); + if (!isa(dataType)) + return failure(); + dataTypes.push_back(dataType); + return success(); + }; + if (parser.parseCommaSeparatedList(parseTypeAttr) || parser.parseRParen()) + return failure(); + + result.addAttribute(DATA_TYPE, builder.getArrayAttr(dataTypes)); + return success(); +} + +template +static ParseResult parseFlagsImpl(OpAsmParser &parser, OperationState &result, + const std::string_view &flagsName) { + auto &builder = 
parser.getBuilder(); + if (parser.parseKeyword(flagsName) || parser.parseEqual() || + parser.parseLParen()) + return failure(); + + SmallVector flags; + auto parseFlags = [&]() -> ParseResult { + FLAGS flag; + if (parseEnum(flag, parser)) + return failure(); + flags.push_back(builder.getI64IntegerAttr(static_cast(flag))); + return success(); + }; + if (parser.parseCommaSeparatedList(parseFlags) || parser.parseRParen()) + return failure(); + result.addAttribute(flagsName, builder.getArrayAttr(flags)); + return success(); +} + +template +static LogicalResult +verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op, + const std::string_view &flagsName) { + SmallVector flagsAsInt; + for (auto flag : flags) + flagsAsInt.push_back(cast(flag).getInt()); + + // check uniqueness + std::sort(flagsAsInt.begin(), flagsAsInt.end()); + auto *it = std::unique(flagsAsInt.begin(), flagsAsInt.end()); + if (it != flagsAsInt.end()) + return op->emitOpError() << "expected " << flagsName << " to be unique"; + // none flag conflicts with all the others + if (llvm::is_contained(flagsAsInt, static_cast(FLAGS::NONE)) && + flagsAsInt.size() != 1) { + return op->emitOpError() + << "'none' " << flagsName << " conflicts with others"; + } + return success(); +} + +///////////////////////////////////////////////////// +// Start of BrgemmDispatchOp + +void BrgemmDispatchOp::print(OpAsmPrinter &printer) { + printInputImpl(printer, *this); + auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; + printFlagsImpl(printer, getOpFlags, FLAGS_NAME); + printDataTypeImpl(printer, *this); +} + +ParseResult BrgemmDispatchOp::parse(OpAsmParser &parser, + OperationState &result) { + if (failed(parseOperandImpl(parser, result)) || + failed(parseFlagsImpl(parser, result, FLAGS_NAME))) + return failure(); + if (failed(parseDataTypeImpl(parser, result))) + return failure(); + return parser.parseOptionalAttrDict(result.attributes); +} + +static LogicalResult verifyBrgemmDataTypes(ArrayAttr 
dtypes, + BrgemmDispatchOp op) { + if (dtypes.size() != 2) { + return op->emitOpError() << "data types size should be 2"; + } + + auto context = op.getContext(); + +#define FTAttr(t) TypeAttr::get(FloatType::get##t(context)) +#define ITAttr(s, w) TypeAttr::get(IntegerType::get(context, w, IntegerType::s)) + SmallVector> validDataTypes = { + {FTAttr(F32), FTAttr(F32)}, + {FTAttr(BF16), FTAttr(BF16)}, + {ITAttr(Unsigned, 8), ITAttr(Signed, 8)}, + {ITAttr(Signed, 8), ITAttr(Unsigned, 8)}, + {ITAttr(Unsigned, 8), ITAttr(Unsigned, 8)}, + {ITAttr(Signed, 8), ITAttr(Signed, 8)}}; +#undef FTAttr +#undef ITAttr + if (!llvm::any_of(validDataTypes, + [=](std::pair type_pair) { + return type_pair.first == dtypes[0] || + type_pair.second == dtypes[1]; + })) { + return op->emitOpError() << "invalid data type pair"; + } + + return success(); +} + +static LogicalResult verifyBrgemmFlags(ArrayAttr flags, BrgemmDispatchOp op, + const std::string_view &flagsName) { + // Verify flags. + if (failed(verifyUniquenessAndConsistency(flags, op, flagsName))) + return failure(); + + bool strideSet = false; + bool listSet = false; + for (auto flag : flags) { + if (cast(flag).getValue() == BrgemmFlags::STRIDE) { + strideSet = true; + } + if (cast(flag).getValue() == BrgemmFlags::LIST) { + listSet = true; + } + } + // VNNI flags must be specified only for bf16 type + if (strideSet && listSet) { + return op->emitOpError() + << "stride and addr flags conflict with each other"; + } + + return success(); +} + +LogicalResult BrgemmDispatchOp::verify() { + BrgemmDispatchOp &op = *this; + // 'inputs' = [m, n, k, lda, ldb, ldc, stride_a, stride_b] for BRGEMM. + size_t expected = 8; + size_t numInputs = op.getInputs().size(); + if (numInputs != expected) { + return op.emitOpError() + << "expect " << expected << " args but got: " << numInputs; + } + // Verify data types + if (failed(verifyBrgemmDataTypes(op.getDataType(), op))) { + return failure(); + } + + // Verify leading dims. 
+ ArrayRef inputs = op.getInputs(); + int64_t n = inputs[1]; + int64_t k = inputs[2]; + int64_t lda = inputs[3]; + int64_t ldb = inputs[4]; + int64_t ldc = inputs[5]; + if (lda < k) + return op.emitOpError() << "expect lda to be >= of dimension k\n"; + if (ldb < n) + return op.emitOpError() << "expect ldb to be >= of dimension n\n"; + if (ldc < n) + return op.emitOpError() << "expect ldc to be >= of dimension n\n"; + + // Verify dispatch flags. + return verifyBrgemmFlags(op.getFlags(), op, FLAGS_NAME); +} + +///////////////////////////////////////////////////// +// Start of BrgemmOp + +// TODO(haixin): could use compiler-wide VNNI utils? +static bool isInVnniLayout(MemRefType memref) { + if (!memref.getElementType().isBF16() && + !memref.getElementType().isSignedInteger(8) && + !memref.getElementType().isUnsignedInteger(8)) { + return false; + } + + auto blockingFactor = 0; + if (memref.getElementType().isBF16()) { + blockingFactor = 2; + } else if (memref.getElementType().isSignedInteger(8) || + memref.getElementType().isUnsignedInteger(8)) { + blockingFactor = 4; + } + return memref.getShape().back() == blockingFactor; +} + +static bool isTypeCompatible(Type outType, Type operandAType, + Type operandBType) { + if (!outType.isF32() && !outType.isSignedInteger(32)) { + return false; + } + if (outType.isF32()) { + if (!(operandAType.isF32() && operandBType.isF32()) && + !(operandAType.isBF16() && operandBType.isBF16())) { + return false; + } + } + if (outType.isSignedInteger(32)) { + if (!(operandAType.isSignedInteger(8) || + operandAType.isUnsignedInteger(8)) && + (operandBType.isSignedInteger(8) || + operandBType.isUnsignedInteger(8))) { + return false; + } + } + return true; +} + +LogicalResult BrgemmOp::verify() { + BrgemmOp &brgemmOp = *this; + + SmallVector inputs = brgemmOp.getInputs(); + // inputs for BRGEMM: kernel id, A memref, B memref, C memref, batch_size, + // addr_len + if (inputs.size() != 6) { + return brgemmOp.emitOpError() << "expect 6" + << " 
inputs but got " << inputs.size(); + } + // Verify the dispatch to be an i64. + Value dispatch = brgemmOp.getDispatch(); + if (!dispatch.getType().isInteger(64)) { + return brgemmOp.emitOpError() + << "expect an i64 but got " << dispatch.getType() + << " for operand 0 (dispatch)"; + } + + // Verify the compatibility of memref types + SmallVector memrefOperands = { + brgemmOp.getOperandA(), brgemmOp.getOperandB(), brgemmOp.getOutput()}; + SmallVector typeOperands = { + getElementTypeOrSelf(memrefOperands[0].getType()), + getElementTypeOrSelf(memrefOperands[1].getType()), + getElementTypeOrSelf(memrefOperands[2].getType())}; + if (!isTypeCompatible(typeOperands[2], typeOperands[0], typeOperands[1])) { + return brgemmOp.emitOpError() + << "operands types: " << typeOperands[0] << " X " << typeOperands[1] + << " -> " << typeOperands[2] << " are imcompatible"; + } + + // Verify the rank of the shaped operands. + for (size_t idx = 0; idx < memrefOperands.size(); idx++) { + size_t actualIdx = idx + 1 /*skip dispatch*/; + auto memref = dyn_cast(memrefOperands[idx].getType()); + // Output memref. Must be of rank 2. + if (idx == 2 && memref.getRank() != 2) { + return brgemmOp.emitOpError() + << "expect a 2d layout for operand: " << actualIdx; + } + // Input A memref. Must be of rank 3. + if (idx == 0 && memref.getRank() != 3) { + return brgemmOp.emitOpError() + << "expect a 3d memref for operand: " << actualIdx; + } + // Input B memref. Must be in VNNI layout with rank 4 for non-F32. + if (idx == 1) { + auto dtype_B = typeOperands[idx]; + if (!dtype_B.isF32()) { + if (memref.getRank() != 4 && !isInVnniLayout(memref)) { + return brgemmOp.emitOpError() + << "expect a 4d VNNI memref for non-F32 operand: " + << actualIdx; + } + } else { + if (memref.getRank() != 3) { + return brgemmOp.emitOpError() + << "expect a 3d memref for F32 operand: " << actualIdx; + } + } + } + } + + // Verify the batch and addrLen to be i64. 
+ Value batch = brgemmOp.getBatch(); + if (!batch.getType().isInteger(64)) { + return brgemmOp.emitOpError() << "expect an i64 but got " << batch.getType() + << " for operand 4 (batch)"; + } + Value addrLen = brgemmOp.getAddrLen(); + if (!addrLen.getType().isInteger(64)) { + return brgemmOp.emitOpError() + << "expect an i64 but got " << addrLen.getType() + << " for operand 5 (addrLen)"; + } + return success(); +} + +} // namespace microkernel +} // namespace mlir diff --git a/lib/gc/Utils/CMakeLists.txt b/lib/gc/Utils/CMakeLists.txt new file mode 100644 index 000000000..007d91710 --- /dev/null +++ b/lib/gc/Utils/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_library(GCMLIRUtils + StructuredOpMatcher.cpp + ValueUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/gc/Utils +) diff --git a/lib/gc/Utils/StructuredOpMatcher.cpp b/lib/gc/Utils/StructuredOpMatcher.cpp new file mode 100644 index 000000000..e5ba53141 --- /dev/null +++ b/lib/gc/Utils/StructuredOpMatcher.cpp @@ -0,0 +1,306 @@ +//===-- StructuredOpMatcher.cpp ---------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +/* + * This code is borrowed from tpp-mlir: + * https://github.com/plaidml/tpp-mlir/tree/main/lib/TPP/IR/StructuredOpMatcher.cpp + */ + +#include "gc/Utils/StructuredOpMatcher.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "structured-matchers" + +namespace mlir { +namespace gcext { +namespace utils { + +// Entry point. 
+bool structured_match::StructuredOpMatcher::match(Operation *op) { + auto linalgOp = dyn_cast_or_null(op); + if (!linalgOp) + return false; + LLVM_DEBUG(llvm::dbgs() << "Running matcher on: " << *op << "\n"); + + for (auto [idx, predicate] : llvm::enumerate(predicates)) { + if (!predicate(linalgOp)) { + LLVM_DEBUG(llvm::dbgs() << "Exit on predicate: " << idx << "\n"); + return false; + } + } + return true; +} + +//===---------------------------------------------------------------------===// +// Operation predicates. +//===---------------------------------------------------------------------===// + +structured_match::StructuredOpMatcher & +structured_match::StructuredOpMatcher::operation( + std::function fun) { + predicates.push_back( + [=](linalg::LinalgOp linalgOp) -> bool { return fun(linalgOp); }); + return *this; +} + +//===---------------------------------------------------------------------===// +// Operand predicates - input. +//===---------------------------------------------------------------------===// + +structured_match::StructuredOpMatcher & +structured_match::StructuredOpMatcher::input( + MatchSelector range, + std::function fun) { + predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { + auto operands = linalgOp.getDpsInputOperands(); + size_t upperBound = range.getUpperBound(); + size_t lowerBound = range.getLowerBound(); + if (upperBound == std::numeric_limits::max()) + upperBound = operands.size(); + + for (auto idx : + llvm::to_vector(llvm::seq(lowerBound, upperBound))) { + if (!fun(operands[idx], linalgOp.getOperation())) + return false; + } + return true; + }); + return *this; +} + +//===---------------------------------------------------------------------===// +// Operand predicates - output. 
+//===---------------------------------------------------------------------===// + +structured_match::StructuredOpMatcher & +structured_match::StructuredOpMatcher::output( + MatchSelector range, + std::function fun) { + predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { + auto operands = linalgOp.getDpsInitsMutable(); + size_t upperBound = range.getUpperBound(); + size_t lowerBound = range.getLowerBound(); + if (upperBound == std::numeric_limits::max()) + upperBound = operands.size(); + + for (auto idx : + llvm::to_vector(llvm::seq(lowerBound, upperBound))) { + if (!fun(&operands[idx], linalgOp.getOperation())) + return false; + } + return true; + }); + return *this; +} + +//===---------------------------------------------------------------------===// +// Dim predicates. +//===---------------------------------------------------------------------===// + +structured_match::StructuredOpMatcher & +structured_match::StructuredOpMatcher::dim( + MatchSelector range, SmallVector kinds) { + predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { + size_t upperBound = range.getUpperBound(); + size_t lowerBound = range.getLowerBound(); + if (upperBound == std::numeric_limits::max()) + upperBound = kinds.size(); + size_t sizeRange = upperBound - lowerBound; + + auto iteratorTypes = linalgOp.getIteratorTypesArray(); + if (iteratorTypes.size() != sizeRange) + return false; + + // Reverse iterators to have the innermost one at index 0. 
+ std::reverse(iteratorTypes.begin(), iteratorTypes.end()); + for (auto [idx, rangeIdx] : + llvm::enumerate(llvm::seq(lowerBound, upperBound))) { + if (iteratorTypes[rangeIdx] != kinds[idx]) + return false; + } + return true; + }); + return *this; +} + +structured_match::StructuredOpMatcher & +structured_match::StructuredOpMatcher::dim(MatchSelector range, + mlir::utils::IteratorType kind) { + predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { + auto iteratorTypes = linalgOp.getIteratorTypesArray(); + size_t upperBound = range.getUpperBound(); + size_t lowerBound = range.getLowerBound(); + if (upperBound == std::numeric_limits::max()) + upperBound = iteratorTypes.size(); + + for (auto rangeIdx = lowerBound; rangeIdx < upperBound; rangeIdx++) { + if (iteratorTypes[rangeIdx] != kind) + return false; + } + return true; + }); + return *this; +} + +//===---------------------------------------------------------------------===// +// Region predicates. +//===---------------------------------------------------------------------===// + +bool structured_match::WithSingleOpImpl::withSingleOpImpl( + StringRef operationName, Region *region, Operation *op, + SmallVectorImpl *capturedOperands) { + if (!isa(op)) + return false; + auto linalgOp = cast(op); + + if (!region->hasOneBlock()) + return false; + unsigned numberOfOpsInRegion = + (operationName.compare(linalg::YieldOp::getOperationName()) == 0) ? 1 : 2; + if (std::distance(region->front().begin(), region->front().end()) != + numberOfOpsInRegion) + return false; + if (linalgOp.getNumDpsInits() != 1) + return false; + + // Require only a single yield operand defined by innerOp. + Operation *yieldOp = linalgOp.getBlock()->getTerminator(); + if (yieldOp->getNumOperands() != 1) + return false; + // Only linalg.yield, exit true. + if (numberOfOpsInRegion == 1) { + if (capturedOperands) { + auto arg0 = dyn_cast(yieldOp->getOperand(0)); + // linalg.yield operand might be coming from a different region. 
+ if (arg0 && arg0.getParentBlock() == linalgOp.getBlock()) + capturedOperands->push_back(linalgOp.getMatchingOpOperand(arg0)->get()); + capturedOperands->push_back(linalgOp.getDpsInitOperand(0)->get()); + } + return true; + } + + // Check on the only inner operation. + Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); + if (innerOp->getName().getStringRef() != operationName) + return false; + if (yieldOp->getOperand(0).getDefiningOp() != innerOp) + return false; + // The operand of the innerOp must comes from the region + // args of the generic. + auto arg0 = dyn_cast(innerOp->getOperand(0)); + auto arg1 = dyn_cast(innerOp->getOperand(1)); + if (!arg0 || !arg1) + return false; + if (arg0.getParentBlock() != linalgOp.getBlock() || + arg1.getParentBlock() != linalgOp.getBlock()) + return false; + if (capturedOperands) { + capturedOperands->push_back(linalgOp.getMatchingOpOperand(arg0)->get()); + capturedOperands->push_back(linalgOp.getMatchingOpOperand(arg1)->get()); + capturedOperands->push_back(linalgOp.getDpsInitOperand(0)->get()); + } + return true; +} + +// FIXME: This is a generalization of the method above and will eventually +// replace the matcher for both no-op (yield) and one op (add, max). +bool structured_match::withOpChainImpl( + Region *region, Operation *op, SmallVectorImpl *capturedOperands, + SmallVectorImpl &typeChecks) { + + // Number of ops includes yield + ptrdiff_t numOps = typeChecks.size() + 1; + + // Basic checks + if (!isa(op)) + return false; + auto linalgOp = cast(op); + if (!region->hasOneBlock()) + return false; + auto &block = region->front(); + if (std::distance(block.begin(), block.end()) != numOps) + return false; + if (linalgOp.getNumDpsInits() != 1) + return false; + + // Add generic arguments to the list of chained values + llvm::SmallSetVector chainedValues; + for (auto arg : block.getArguments()) { + chainedValues.insert(arg); + } + + // Check on the inner chain of operations in the right order. 
+ // Make sure all operands are used and chained + for (auto [check, innerOp] : + llvm::zip_first(typeChecks, block.getOperations())) { + // Must be right op in right order + if (!check(&innerOp)) + return false; + + // At least one operand must come from args or a previous op + bool consumesValueFromChain = false; + for (auto operand : innerOp.getOperands()) { + if (chainedValues.contains(operand)) { + // First add to the captured + auto ba = dyn_cast(operand); + if (capturedOperands && ba && + ba.getParentBlock() == linalgOp.getBlock()) { + capturedOperands->push_back(linalgOp.getMatchingOpOperand(ba)->get()); + } + // Then erase from the set + chainedValues.remove(operand); + consumesValueFromChain = true; + } + } + + // Operation isn't in the chain + if (!consumesValueFromChain) + return false; + + // Add return value to the list of chained values + for (auto ret : innerOp.getResults()) { + chainedValues.insert(ret); + } + } + + // Last op must be a chained yield. + Operation *yieldOp = linalgOp.getBlock()->getTerminator(); + assert(isa(yieldOp) && "Wrong terminator"); + for (auto op : yieldOp->getOperands()) { + if (!chainedValues.contains(op)) + return false; + } + + return true; +} + +structured_match::StructuredOpMatcher & +structured_match::StructuredOpMatcher::region( + MatchSelector range, + std::function fun) { + predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { + auto regions = linalgOp->getRegions(); + assert(regions.size() != 0); + size_t upperBound = range.getUpperBound(); + size_t lowerBound = range.getLowerBound(); + if (upperBound == std::numeric_limits::max()) + upperBound = regions.size(); + + for (auto idx : + llvm::to_vector(llvm::seq(lowerBound, upperBound))) { + if (!fun(®ions[idx], linalgOp.getOperation())) + return false; + } + return true; + }); + return *this; +} + +} // namespace utils +} // namespace gcext +} // namespace mlir diff --git a/lib/gc/Utils/ValueUtils.cpp b/lib/gc/Utils/ValueUtils.cpp new file mode 100644 index 
000000000..71ffbfa3d --- /dev/null +++ b/lib/gc/Utils/ValueUtils.cpp @@ -0,0 +1,153 @@ +//===-- ValueUtils.cpp ---------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-------------------------------------------------------------------===// + +/* + * This code is borrowed from tpp-mlir: + * https://github.com/plaidml/tpp-mlir/tree/main/lib/TPP/Transforms/Utils/ValueUtils.cpp + */ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "gc/Utils/ValueUtils.h" + +namespace mlir { +namespace gcext { +namespace utils { + +// Returns true if the value is a constant float or integer. +bool isValConstZero(Value val) { + return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()); +} + +// Returns true if the attribute represent "all zeros" +static bool isZeroAttr(Attribute attribute) { + return TypeSwitch(attribute) + .Case([](auto attr) { return attr.getValueAsDouble() == 0.0; }) + .Case([](auto attr) { return attr.getInt() == 0; }) + .Case([](auto attr) { + if (!attr.getElementType().isIntOrFloat()) + return false; + if (!attr.isSplat()) + return false; + auto splat = attr.template getSplatValue(); + return isZeroAttr(splat); + }) + .Default([](auto attr) { return false; }); +} + +// Prototypes +static bool isZeroOp(Operation *); + +// Returns true if the value represents a zero filled tensor. 
+// Recurse into isZeroOp for defining ops if not immediately obvious +// Looks past linalg generic's argument (which don't have defining ops) +bool isZeroTensor(Value val) { + if (!val) + return false; + if (isValConstZero(val)) + return true; + + Operation *defOp = nullptr; + + // Block arguments don't have a defining op, but they do have an op arg + if (auto arg = dyn_cast(val)) { + // We need to find the argument to the linalg on the same order as this one + auto *linalgOp = arg.getParentRegion()->getParentOp(); + if (!isa(linalgOp)) + return false; + auto index = arg.getArgNumber(); + auto linalgArg = linalgOp->getOperand(index); + defOp = linalgArg.getDefiningOp(); + } else { + defOp = val.getDefiningOp(); + } + return isZeroOp(defOp); +} + +// Returns true if the operation represents a zero filled tensor +// Recurses into isZeroTensor for operands and isZeroAttr for attributes +static bool isZeroOp(Operation *defOp) { + if (!defOp) + return false; + + return TypeSwitch(defOp) + .Case([&](auto op) { + // Dense attributes don't match APFloat.isZero() + auto attr = op.getValue(); + return isZeroAttr(attr); + }) + .Case([&](auto op) { + if (op.getInputs().size() != 1) + return false; + return isZeroTensor(op.getInputs()[0]); + }) + .Case( + [&](auto op) { return isZeroTensor(op.getSource()); }) + .Case([&](auto op) { + auto name = op.getName(); + auto module = defOp->getParentOfType(); + auto global = module.lookupSymbol(name); + auto attr = global.getInitialValueAttr(); + return isZeroAttr(attr); + }) + .Default([&](Operation *op) { return false; }); +} + +FailureOr> getStaticStrides(Value value) { + auto valueType = value.getType(); + if (!isa(valueType)) + return failure(); + auto memrefType = cast(valueType); + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) { + return failure(); + } + if (llvm::any_of(strides, [](int64_t stride) { + return stride == ShapedType::kDynamic; + })) { + return failure(); + } 
+ return strides; +} + +std::pair getPtrAndOffset(OpBuilder &builder, Value operand, + Location loc) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); + Type basePtrType = builder.getIndexType(); + Type offsetType = builder.getIndexType(); + SmallVector sizesTypes(memrefType.getRank(), offsetType); + SmallVector stridesTypes(memrefType.getRank(), offsetType); + auto meta = builder.create( + loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); + Value alignedPointerAsIndex = + builder.create(loc, basePtrType, + operand); + Value alignedPointerAsI64 = builder.create( + loc, builder.getIntegerType(64), alignedPointerAsIndex); + // TODO: non-POD will require an LLVMTypeConverter. + Value alignedPointer = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), + alignedPointerAsI64); + Value offset = meta.getOffset(); + return std::make_pair(alignedPointer, offset); +} + +} // namespace utils +} // namespace gcext +} // namespace mlir diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index acde29010..5aa0bfc1d 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -19,6 +19,7 @@ #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h" #include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelDialect.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" @@ -52,6 +53,7 @@ int main(int argc, char *argv[]) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); mlir::registerAllDialects(registry); #ifdef GC_USE_GPU registry.insert<::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect>(); From 1c431825924386c2e31544146ebb581d074f11b7 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Sun, 7 Jul 2024 20:15:33 -0700 Subject: [PATCH 02/93] fix licenses --- 
include/gc/Utils/StructuredOpMatcher.h | 2 +- include/gc/Utils/ValueUtils.h | 2 +- lib/gc/Utils/ValueUtils.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/gc/Utils/StructuredOpMatcher.h b/include/gc/Utils/StructuredOpMatcher.h index e0aa9acc7..392b358a7 100644 --- a/include/gc/Utils/StructuredOpMatcher.h +++ b/include/gc/Utils/StructuredOpMatcher.h @@ -1,4 +1,4 @@ -//===- StructuredOpMatcher.h - Utils for structuterd Op ----------*-C++ -*-===// +//===-- StructuredOpMatcher.h - Utils for structuterd Op --------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/include/gc/Utils/ValueUtils.h b/include/gc/Utils/ValueUtils.h index 5d940b000..adf26ac56 100644 --- a/include/gc/Utils/ValueUtils.h +++ b/include/gc/Utils/ValueUtils.h @@ -1,4 +1,4 @@ -//===- ValueUtils.h - Utils for handling mlir::Value -------------*-C++ -*-===// +//===- ValueUtils.h - Utils for handling mlir::Value ------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/gc/Utils/ValueUtils.cpp b/lib/gc/Utils/ValueUtils.cpp index 71ffbfa3d..3cb7d7399 100644 --- a/lib/gc/Utils/ValueUtils.cpp +++ b/lib/gc/Utils/ValueUtils.cpp @@ -1,4 +1,4 @@ -//===-- ValueUtils.cpp ---------------------------------------*- C++ -*-===// +//===-- ValueUtils.cpp - Utils for handling Value ------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
From e5226188ab42720374d47e9088d82ebcaea92be0 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Sun, 7 Jul 2024 21:59:59 -0700 Subject: [PATCH 03/93] fix license check --- lib/gc/Utils/StructuredOpMatcher.cpp | 2 +- lib/gc/Utils/ValueUtils.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/Utils/StructuredOpMatcher.cpp b/lib/gc/Utils/StructuredOpMatcher.cpp index e5ba53141..5ec88ea8a 100644 --- a/lib/gc/Utils/StructuredOpMatcher.cpp +++ b/lib/gc/Utils/StructuredOpMatcher.cpp @@ -1,4 +1,4 @@ -//===-- StructuredOpMatcher.cpp ---------------------------------*- C++ -*-===// +//===-- StructuredOpMatcher.cpp - Utils for structured Op -------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/gc/Utils/ValueUtils.cpp b/lib/gc/Utils/ValueUtils.cpp index 3cb7d7399..8a72cf323 100644 --- a/lib/gc/Utils/ValueUtils.cpp +++ b/lib/gc/Utils/ValueUtils.cpp @@ -1,10 +1,10 @@ -//===-- ValueUtils.cpp - Utils for handling Value ------------*- C++ -*-===// +//===-- ValueUtils.cpp - Utils for handling Value ---------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===-------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// /* * This code is borrowed from tpp-mlir: From 88f645a5119da08423ee42787bb6fd83688ed9ce Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Sun, 7 Jul 2024 23:03:35 -0700 Subject: [PATCH 04/93] fix tidy --- include/gc/Utils/StructuredOpMatcher.h | 10 +++++----- include/gc/Utils/ValueUtils.h | 2 +- lib/gc/Utils/StructuredOpMatcher.cpp | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/gc/Utils/StructuredOpMatcher.h b/include/gc/Utils/StructuredOpMatcher.h index 392b358a7..ad93ffa6e 100644 --- a/include/gc/Utils/StructuredOpMatcher.h +++ b/include/gc/Utils/StructuredOpMatcher.h @@ -398,24 +398,24 @@ class StructuredOpMatcher { bool match(Operation *op); // Predicates on operation. - StructuredOpMatcher &operation(std::function); + StructuredOpMatcher &operation(std::function &); // Predicate on OpOperands. StructuredOpMatcher &input(MatchSelector range, - std::function); + std::function &); // Predicates on OpOperands. StructuredOpMatcher &output(MatchSelector range, - std::function); + std::function &); // Predicates on Iterators. StructuredOpMatcher &dim(MatchSelector range, - SmallVector kinds); + SmallVector &kinds); StructuredOpMatcher &dim(MatchSelector range, mlir::utils::IteratorType kind); // Predicates on region. StructuredOpMatcher ®ion(MatchSelector range, - std::function); + std::function &); private: llvm::SmallVector predicates; diff --git a/include/gc/Utils/ValueUtils.h b/include/gc/Utils/ValueUtils.h index adf26ac56..c656cc07c 100644 --- a/include/gc/Utils/ValueUtils.h +++ b/include/gc/Utils/ValueUtils.h @@ -35,7 +35,7 @@ FailureOr> getStaticStrides(Value val); // Return the offset and ptr for `val`. Assert if `val` // is not a memref. 
-std::pair getPtrAndOffset(OpBuilder &builder, Value val, +std::pair getPtrAndOffset(OpBuilder &builder, Value operand, Location loc); } // namespace utils diff --git a/lib/gc/Utils/StructuredOpMatcher.cpp b/lib/gc/Utils/StructuredOpMatcher.cpp index 5ec88ea8a..3e7369b5d 100644 --- a/lib/gc/Utils/StructuredOpMatcher.cpp +++ b/lib/gc/Utils/StructuredOpMatcher.cpp @@ -42,7 +42,7 @@ bool structured_match::StructuredOpMatcher::match(Operation *op) { structured_match::StructuredOpMatcher & structured_match::StructuredOpMatcher::operation( - std::function fun) { + std::function &fun) { predicates.push_back( [=](linalg::LinalgOp linalgOp) -> bool { return fun(linalgOp); }); return *this; @@ -55,7 +55,7 @@ structured_match::StructuredOpMatcher::operation( structured_match::StructuredOpMatcher & structured_match::StructuredOpMatcher::input( MatchSelector range, - std::function fun) { + std::function &fun) { predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { auto operands = linalgOp.getDpsInputOperands(); size_t upperBound = range.getUpperBound(); @@ -80,7 +80,7 @@ structured_match::StructuredOpMatcher::input( structured_match::StructuredOpMatcher & structured_match::StructuredOpMatcher::output( MatchSelector range, - std::function fun) { + std::function &fun) { predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { auto operands = linalgOp.getDpsInitsMutable(); size_t upperBound = range.getUpperBound(); @@ -104,7 +104,7 @@ structured_match::StructuredOpMatcher::output( structured_match::StructuredOpMatcher & structured_match::StructuredOpMatcher::dim( - MatchSelector range, SmallVector kinds) { + MatchSelector range, SmallVector &kinds) { predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { size_t upperBound = range.getUpperBound(); size_t lowerBound = range.getLowerBound(); @@ -282,10 +282,10 @@ bool structured_match::withOpChainImpl( structured_match::StructuredOpMatcher & structured_match::StructuredOpMatcher::region( MatchSelector range, - 
std::function fun) { + std::function &fun) { predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { auto regions = linalgOp->getRegions(); - assert(regions.size() != 0); + assert(!regions.empty()); size_t upperBound = range.getUpperBound(); size_t lowerBound = range.getLowerBound(); if (upperBound == std::numeric_limits::max()) From 8a7ec98c8d04acad9be49716481c1a3a9d22c307 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Sun, 7 Jul 2024 23:26:10 -0700 Subject: [PATCH 05/93] fix lint --- lib/gc/Utils/StructuredOpMatcher.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Utils/StructuredOpMatcher.cpp b/lib/gc/Utils/StructuredOpMatcher.cpp index 3e7369b5d..9179c0b5b 100644 --- a/lib/gc/Utils/StructuredOpMatcher.cpp +++ b/lib/gc/Utils/StructuredOpMatcher.cpp @@ -27,7 +27,7 @@ bool structured_match::StructuredOpMatcher::match(Operation *op) { return false; LLVM_DEBUG(llvm::dbgs() << "Running matcher on: " << *op << "\n"); - for (auto [idx, predicate] : llvm::enumerate(predicates)) { + for (auto [idx, predicate] : llvm::enumerate(predicates)) { // NOLINT if (!predicate(linalgOp)) { LLVM_DEBUG(llvm::dbgs() << "Exit on predicate: " << idx << "\n"); return false; @@ -271,7 +271,7 @@ bool structured_match::withOpChainImpl( // Last op must be a chained yield. 
Operation *yieldOp = linalgOp.getBlock()->getTerminator(); assert(isa(yieldOp) && "Wrong terminator"); - for (auto op : yieldOp->getOperands()) { + for (auto op : yieldOp->getOperands()) { // NOLINT if (!chainedValues.contains(op)) return false; } From 738ba0cbff8f840bb375b50942cc4d33e033e834 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 00:14:19 -0700 Subject: [PATCH 06/93] remove Utils borrowed from TPP --- include/gc/Utils/StructuredOpMatcher.h | 429 ---------------------- include/gc/Utils/ValueUtils.h | 45 --- lib/gc/Dialect/Microkernel/CMakeLists.txt | 2 +- lib/gc/Utils/CMakeLists.txt | 7 - lib/gc/Utils/StructuredOpMatcher.cpp | 306 --------------- lib/gc/Utils/ValueUtils.cpp | 153 -------- 6 files changed, 1 insertion(+), 941 deletions(-) delete mode 100644 include/gc/Utils/StructuredOpMatcher.h delete mode 100644 include/gc/Utils/ValueUtils.h delete mode 100644 lib/gc/Utils/CMakeLists.txt delete mode 100644 lib/gc/Utils/StructuredOpMatcher.cpp delete mode 100644 lib/gc/Utils/ValueUtils.cpp diff --git a/include/gc/Utils/StructuredOpMatcher.h b/include/gc/Utils/StructuredOpMatcher.h deleted file mode 100644 index ad93ffa6e..000000000 --- a/include/gc/Utils/StructuredOpMatcher.h +++ /dev/null @@ -1,429 +0,0 @@ -//===-- StructuredOpMatcher.h - Utils for structuterd Op --------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -/* - * This code is borrowed from tpp-mlir: - * https://github.com/plaidml/tpp-mlir/tree/main/include/TPP/IR/StructuredOpMatcher.h - */ - -#ifndef GC_UTILS_STRUCTUREDOPMATCHER_H -#define GC_UTILS_STRUCTUREDOPMATCHER_H - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include -#include - -namespace mlir { -class Operation; -namespace gcext { -namespace utils { -namespace structured_match { - -// Base class for the matcher predicates selection tag. -struct MatchSelector { - MatchSelector() = delete; - size_t getLowerBound() const { return lowerBound; }; - size_t getUpperBound() const { return upperBound; }; - -protected: - explicit MatchSelector(size_t lowerBound, size_t upperBound) - : lowerBound(lowerBound), upperBound(upperBound) { - assert(upperBound > lowerBound); - } - -private: - const size_t lowerBound; - const size_t upperBound; -}; - -// Selector which specifies that predicate should apply on all values. -struct MatchAll : public MatchSelector { - MatchAll() : MatchSelector(0, std::numeric_limits::max()) {} -}; - -// Selector which specifies that predicate should apply only on one value at -// the position `idx`. -struct MatchOne : public MatchSelector { - MatchOne() = delete; - MatchOne(size_t idx) : MatchSelector(idx, idx + 1) {} -}; - -// Selector which specifies that predicate should apply only on range of values -// at positions from `lowerBound` up to - but not including - `upperBound`. -struct MatchRange : public MatchSelector { - MatchRange() = delete; - MatchRange(size_t lowerBound, size_t upperBound) - : MatchSelector(lowerBound, upperBound) {} -}; - -// Callable object to check if the number of loops in `op` satisfies `fun`. 
-struct NumOfLoops { - NumOfLoops() = delete; - explicit NumOfLoops(std::function fun) : fun(std::move(fun)){}; - - bool operator()(Operation *op) const { - if (auto linalgOp = dyn_cast_or_null(op)) { - auto numberOfLoops = linalgOp.getNumLoops(); - return fun(numberOfLoops); - } - return false; - } - std::function fun; -}; - -// Callable object to check if the `operand` of `op` has a map that satisfies -// `fun`. -struct HasMap { - HasMap() = delete; - explicit HasMap(std::function fun) : fun(std::move(fun)){}; - explicit HasMap(std::function fun, AffineMap *ptrMap) - : fun(std::move(fun)), ptrMap(ptrMap){}; - - bool operator()(OpOperand *operand, Operation *op) const { - if (auto linalgOp = dyn_cast_or_null(op)) { - auto map = linalgOp.getMatchingIndexingMap(operand); - assert(fun && "must be a callable target"); - if (!fun(map)) - return false; - if (ptrMap) - *ptrMap = std::move(map); - return true; - } - return false; - } - std::function fun; - AffineMap *ptrMap = nullptr; -}; - -// Callble object to verify if `map` is a projected permutation map. -// We require the dimensions to be in sorted order this avoid filtering -// projected permutation without broadcasting semantics, for example -// affine_map<(d0, d1) -> (d1, d0)> is rejected. -struct ProjectedPermutation { - ProjectedPermutation() = default; - - bool operator()(AffineMap map) const { - if (map.getNumSymbols() > 0 || map.getNumResults() > map.getNumInputs()) - return false; - - SmallVector seen(map.getNumInputs(), false); - SmallVector pos; - for (auto expr : map.getResults()) { - if (auto dim = dyn_cast(expr)) { - if (seen[dim.getPosition()]) - return false; - seen[dim.getPosition()] = true; - pos.push_back(dim.getPosition()); - } else if (auto constExpr = dyn_cast(expr)) { - if (constExpr.getValue() != 0) - return false; - } else - return false; - } - return llvm::is_sorted(pos); - } -}; - -// Callable object to verify if `map` is an identity map. 
-struct Identity { - Identity() = default; - - bool operator()(AffineMap map) const { return map.isIdentity(); } -}; - -// Callable object to capture any map. -struct Any { - Any() = default; - - bool operator()(AffineMap map) const { return true; } -}; - -// Callable object to verify if `operand` has static shape. -struct HasStaticShape { - HasStaticShape() = default; - HasStaticShape(SmallVectorImpl *shape) : shape(shape){}; - - bool operator()(OpOperand *operand, Operation *op) const { - auto operandType = operand->get().getType(); - if (auto shapedType = dyn_cast_or_null(operandType)) { - if (!shapedType.hasStaticShape()) - return false; - if (shape) { - for (int64_t shapeOnDim : shapedType.getShape()) - shape->push_back(shapeOnDim); - } - } - return true; - } - SmallVectorImpl *shape = nullptr; -}; - -// Callable object to verify if `operand` has static strides. -// If `operand` is a tensor type or a scalar, return true. -struct HasStaticStrides { - HasStaticStrides() = default; - HasStaticStrides(SmallVector *strides) : strides(strides){}; - - bool operator()(OpOperand *operand, Operation *op) const { - auto operandType = operand->get().getType(); - SmallVector strides; - if (auto memRefType = dyn_cast_or_null(operandType)) { - int64_t offset; - if (failed(getStridesAndOffset(memRefType, strides, offset))) - return false; - if (llvm::any_of(strides, [](int64_t stride) { - return stride == ShapedType::kDynamic; - })) { - return false; - } - if (this->strides) - this->strides->append(strides.begin(), strides.end()); - } - return true; - } - SmallVectorImpl *strides = nullptr; -}; - -// Callable object to verify `operand` to have a rank in `ranks`. 
-struct HasRank { - HasRank() = delete; - explicit HasRank(std::initializer_list ranks) : ranks(ranks){}; - - bool operator()(OpOperand *operand, Operation *op) const { - auto operandType = operand->get().getType(); - if (!isa(operandType)) - return llvm::is_contained(ranks, HasRank::SCALAR); - int64_t rank = cast(operandType).getRank(); - return llvm::any_of( - ranks, [=](int64_t expectedRank) { return expectedRank == rank; }); - } - - // There are multiple way to represent a scalar: f32, tensor. - // SCALAR means f32. - static constexpr int64_t SCALAR = -1; - std::vector ranks; -}; - -// Callable object to verify `operand` to have an element type `T`. -template struct HasElementType { - bool operator()(OpOperand *operand, Operation *op) const { - auto operandType = getElementTypeOrSelf(operand->get().getType()); - return isa(operandType); - } -}; - -// Callable object to check if the input is equal to specified `value`. -template struct EqualsTo { - EqualsTo() = delete; - explicit EqualsTo(T value) : value(value){}; - - const T value; - - bool operator()(T value) const { return value == this->value; } -}; -template EqualsTo(T) -> EqualsTo; - -// Callable object to check if the input is less than or equal to specified -// `value`. -struct LessThanOrEqualTo { - LessThanOrEqualTo() = delete; - explicit LessThanOrEqualTo(size_t value) : value(value){}; - const size_t value; - - bool operator()(size_t value) const { return value <= this->value; } -}; - -// Callable object to check if the input is greater than or equal to specified -// `value`. -struct GreaterThanOrEqualTo { - GreaterThanOrEqualTo() = delete; - explicit GreaterThanOrEqualTo(size_t value) : value(value){}; - const size_t value; - - bool operator()(size_t value) const { return value >= this->value; } -}; - -// Callable object to validate number of init operands for `op`. 
-struct NumDpsInits { - NumDpsInits() = delete; - explicit NumDpsInits(std::function fun) : fun(std::move(fun)){}; - - bool operator()(Operation *op) const { - if (auto linalgOp = dyn_cast_or_null(op)) - return fun(linalgOp.getNumDpsInits()); - return false; - } - - std::function fun; -}; - -// Callable object to check the number of affine map for `op`. -struct NumAffineMaps { - NumAffineMaps() = delete; - explicit NumAffineMaps(std::function fun) - : fun(std::move(fun)){}; - - bool operator()(Operation *op) const { - if (auto linalgOp = dyn_cast_or_null(op)) - return fun(linalgOp.getIndexingMapsArray().size()); - return false; - } - - std::function fun; -}; - -// Callable object to validate number of input operands for `op`. -struct NumDpsInputs { - NumDpsInputs() = delete; - explicit NumDpsInputs(std::function fun) - : fun(std::move(fun)){}; - - bool operator()(Operation *op) { - if (auto linalgOp = dyn_cast_or_null(op)) - return fun(linalgOp.getNumDpsInputs()); - return false; - } - - std::function fun; -}; - -// Callable object to validate number of regions for `op`. -struct NumRegions { - NumRegions() = delete; - explicit NumRegions(std::function fun) : fun(std::move(fun)){}; - - bool operator()(Operation *op) const { - if (auto linalgOp = dyn_cast_or_null(op)) - return fun(linalgOp->getNumRegions()); - return false; - } - - std::function fun; -}; - -// Logical OR between two predicates. -struct _OR { - _OR() = delete; - _OR(std::function lhs, std::function rhs) - : lhs(std::move(lhs)), rhs(std::move(rhs)) {} - - bool operator()(size_t num) { return (lhs(num) || rhs(num)); } - - std::function lhs; - std::function rhs; -}; - -// Callable object to check if `op` adheres to a given property passed -// as an std::function object. 
-struct VerifyOpProperty { - VerifyOpProperty() = delete; - explicit VerifyOpProperty(std::function fun) - : fun(std::move(fun)){}; - - bool operator()(Operation *op) { - if (succeeded(fun(op))) - return true; - return false; - } - - std::function fun; -}; - -// Work-around for template specialization. -struct WithSingleOpImpl { - WithSingleOpImpl() = default; - - bool withSingleOpImpl(StringRef, Region *, Operation *, - SmallVectorImpl *); -}; - -// Callable object to check the `op` region for a single scalar operation OpTy. -template struct WithSingleOp { - WithSingleOp() : WithSingleOp(nullptr){}; - WithSingleOp(SmallVectorImpl *captures) : captures(captures){}; - - bool operator()(Region *region, Operation *op) { - return WithSingleOpImpl().withSingleOpImpl(OpTy::getOperationName(), region, - op, captures); - } - -private: - SmallVectorImpl *captures; -}; - -// Implemenation to allow definition in cpp file -using TypeCheckFunc = std::function; -bool withOpChainImpl(Region *region, Operation *op, SmallVectorImpl *, - SmallVectorImpl &); - -// Callable object to check the region for a chain of operations. -template struct WithOpChain { - WithOpChain() : WithOpChain(nullptr){}; - WithOpChain(SmallVectorImpl *captures) : captures(captures) { - (typeChecks.push_back([](Operation *op) { return isa(op); }), ...); - }; - - bool operator()(Region *region, Operation *op) { - return withOpChainImpl(region, op, captures, typeChecks); - } - -private: - SmallVectorImpl *captures; - SmallVector typeChecks; -}; - -class StructuredOpMatcher { - using PredicateFn = std::function; - -public: - StructuredOpMatcher() = default; - - StructuredOpMatcher(PredicateFn &&firstPredicate) { - predicates.push_back(std::move(firstPredicate)); - } - - template static StructuredOpMatcher make() { - return StructuredOpMatcher( - [](linalg::LinalgOp op) { return isa(op.getOperation()); }); - } - - // Match given `op` using stored predicates. 
- bool match(Operation *op); - - // Predicates on operation. - StructuredOpMatcher &operation(std::function &); - - // Predicate on OpOperands. - StructuredOpMatcher &input(MatchSelector range, - std::function &); - - // Predicates on OpOperands. - StructuredOpMatcher &output(MatchSelector range, - std::function &); - - // Predicates on Iterators. - StructuredOpMatcher &dim(MatchSelector range, - SmallVector &kinds); - StructuredOpMatcher &dim(MatchSelector range, mlir::utils::IteratorType kind); - - // Predicates on region. - StructuredOpMatcher ®ion(MatchSelector range, - std::function &); - -private: - llvm::SmallVector predicates; -}; - -} // namespace structured_match -} // namespace utils -} // namespace gcext -} // namespace mlir - -#endif // GC_UTILS_STRUCTUREDOPMATCHER_H diff --git a/include/gc/Utils/ValueUtils.h b/include/gc/Utils/ValueUtils.h deleted file mode 100644 index c656cc07c..000000000 --- a/include/gc/Utils/ValueUtils.h +++ /dev/null @@ -1,45 +0,0 @@ -//===- ValueUtils.h - Utils for handling mlir::Value ------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -/* - * This code is borrowed from tpp-mlir: - * https://github.com/plaidml/tpp-mlir/tree/main/include/TPP/Transforms/Utils/ValueUtils.h - */ - -#ifndef GC_UTILS_VALUEUTILS_H -#define GC_UTILS_VALUEUTILS_H - -namespace mlir { -class Value; -class OpBuilder; -namespace gcext { -namespace utils { - -using namespace mlir; - -// Returns true if the value is a constant float or integer. -bool isValConstZero(Value val); - -// Returns true if the op defining `val` represents a zero filled tensor. -bool isZeroTensor(Value val); - -// Returns the strides of `val`. 
The method returns something usefull -// only if the `val` type is a strided memref and the strides are statically -// known. -FailureOr> getStaticStrides(Value val); - -// Return the offset and ptr for `val`. Assert if `val` -// is not a memref. -std::pair getPtrAndOffset(OpBuilder &builder, Value operand, - Location loc); - -} // namespace utils -} // namespace gcext -} // namespace mlir - -#endif diff --git a/lib/gc/Dialect/Microkernel/CMakeLists.txt b/lib/gc/Dialect/Microkernel/CMakeLists.txt index f67309940..029f00cce 100644 --- a/lib/gc/Dialect/Microkernel/CMakeLists.txt +++ b/lib/gc/Dialect/Microkernel/CMakeLists.txt @@ -13,6 +13,6 @@ add_mlir_dialect_library(MLIRMicrokernel LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} - GCMLIRUtils + GCUtilsIR ) set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRMicrokernel) diff --git a/lib/gc/Utils/CMakeLists.txt b/lib/gc/Utils/CMakeLists.txt deleted file mode 100644 index 007d91710..000000000 --- a/lib/gc/Utils/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -add_mlir_library(GCMLIRUtils - StructuredOpMatcher.cpp - ValueUtils.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/gc/Utils -) diff --git a/lib/gc/Utils/StructuredOpMatcher.cpp b/lib/gc/Utils/StructuredOpMatcher.cpp deleted file mode 100644 index 9179c0b5b..000000000 --- a/lib/gc/Utils/StructuredOpMatcher.cpp +++ /dev/null @@ -1,306 +0,0 @@ -//===-- StructuredOpMatcher.cpp - Utils for structured Op -------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -/* - * This code is borrowed from tpp-mlir: - * https://github.com/plaidml/tpp-mlir/tree/main/lib/TPP/IR/StructuredOpMatcher.cpp - */ - -#include "gc/Utils/StructuredOpMatcher.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "structured-matchers" - -namespace mlir { -namespace gcext { -namespace utils { - -// Entry point. -bool structured_match::StructuredOpMatcher::match(Operation *op) { - auto linalgOp = dyn_cast_or_null(op); - if (!linalgOp) - return false; - LLVM_DEBUG(llvm::dbgs() << "Running matcher on: " << *op << "\n"); - - for (auto [idx, predicate] : llvm::enumerate(predicates)) { // NOLINT - if (!predicate(linalgOp)) { - LLVM_DEBUG(llvm::dbgs() << "Exit on predicate: " << idx << "\n"); - return false; - } - } - return true; -} - -//===---------------------------------------------------------------------===// -// Operation predicates. -//===---------------------------------------------------------------------===// - -structured_match::StructuredOpMatcher & -structured_match::StructuredOpMatcher::operation( - std::function &fun) { - predicates.push_back( - [=](linalg::LinalgOp linalgOp) -> bool { return fun(linalgOp); }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Operand predicates - input. 
-//===---------------------------------------------------------------------===// - -structured_match::StructuredOpMatcher & -structured_match::StructuredOpMatcher::input( - MatchSelector range, - std::function &fun) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - auto operands = linalgOp.getDpsInputOperands(); - size_t upperBound = range.getUpperBound(); - size_t lowerBound = range.getLowerBound(); - if (upperBound == std::numeric_limits::max()) - upperBound = operands.size(); - - for (auto idx : - llvm::to_vector(llvm::seq(lowerBound, upperBound))) { - if (!fun(operands[idx], linalgOp.getOperation())) - return false; - } - return true; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Operand predicates - output. -//===---------------------------------------------------------------------===// - -structured_match::StructuredOpMatcher & -structured_match::StructuredOpMatcher::output( - MatchSelector range, - std::function &fun) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - auto operands = linalgOp.getDpsInitsMutable(); - size_t upperBound = range.getUpperBound(); - size_t lowerBound = range.getLowerBound(); - if (upperBound == std::numeric_limits::max()) - upperBound = operands.size(); - - for (auto idx : - llvm::to_vector(llvm::seq(lowerBound, upperBound))) { - if (!fun(&operands[idx], linalgOp.getOperation())) - return false; - } - return true; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Dim predicates. 
-//===---------------------------------------------------------------------===// - -structured_match::StructuredOpMatcher & -structured_match::StructuredOpMatcher::dim( - MatchSelector range, SmallVector &kinds) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - size_t upperBound = range.getUpperBound(); - size_t lowerBound = range.getLowerBound(); - if (upperBound == std::numeric_limits::max()) - upperBound = kinds.size(); - size_t sizeRange = upperBound - lowerBound; - - auto iteratorTypes = linalgOp.getIteratorTypesArray(); - if (iteratorTypes.size() != sizeRange) - return false; - - // Reverse iterators to have the innermost one at index 0. - std::reverse(iteratorTypes.begin(), iteratorTypes.end()); - for (auto [idx, rangeIdx] : - llvm::enumerate(llvm::seq(lowerBound, upperBound))) { - if (iteratorTypes[rangeIdx] != kinds[idx]) - return false; - } - return true; - }); - return *this; -} - -structured_match::StructuredOpMatcher & -structured_match::StructuredOpMatcher::dim(MatchSelector range, - mlir::utils::IteratorType kind) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - auto iteratorTypes = linalgOp.getIteratorTypesArray(); - size_t upperBound = range.getUpperBound(); - size_t lowerBound = range.getLowerBound(); - if (upperBound == std::numeric_limits::max()) - upperBound = iteratorTypes.size(); - - for (auto rangeIdx = lowerBound; rangeIdx < upperBound; rangeIdx++) { - if (iteratorTypes[rangeIdx] != kind) - return false; - } - return true; - }); - return *this; -} - -//===---------------------------------------------------------------------===// -// Region predicates. 
-//===---------------------------------------------------------------------===// - -bool structured_match::WithSingleOpImpl::withSingleOpImpl( - StringRef operationName, Region *region, Operation *op, - SmallVectorImpl *capturedOperands) { - if (!isa(op)) - return false; - auto linalgOp = cast(op); - - if (!region->hasOneBlock()) - return false; - unsigned numberOfOpsInRegion = - (operationName.compare(linalg::YieldOp::getOperationName()) == 0) ? 1 : 2; - if (std::distance(region->front().begin(), region->front().end()) != - numberOfOpsInRegion) - return false; - if (linalgOp.getNumDpsInits() != 1) - return false; - - // Require only a single yield operand defined by innerOp. - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - if (yieldOp->getNumOperands() != 1) - return false; - // Only linalg.yield, exit true. - if (numberOfOpsInRegion == 1) { - if (capturedOperands) { - auto arg0 = dyn_cast(yieldOp->getOperand(0)); - // linalg.yield operand might be coming from a different region. - if (arg0 && arg0.getParentBlock() == linalgOp.getBlock()) - capturedOperands->push_back(linalgOp.getMatchingOpOperand(arg0)->get()); - capturedOperands->push_back(linalgOp.getDpsInitOperand(0)->get()); - } - return true; - } - - // Check on the only inner operation. - Operation *innerOp = &(*linalgOp.getBlock()->getOperations().begin()); - if (innerOp->getName().getStringRef() != operationName) - return false; - if (yieldOp->getOperand(0).getDefiningOp() != innerOp) - return false; - // The operand of the innerOp must comes from the region - // args of the generic. 
- auto arg0 = dyn_cast(innerOp->getOperand(0)); - auto arg1 = dyn_cast(innerOp->getOperand(1)); - if (!arg0 || !arg1) - return false; - if (arg0.getParentBlock() != linalgOp.getBlock() || - arg1.getParentBlock() != linalgOp.getBlock()) - return false; - if (capturedOperands) { - capturedOperands->push_back(linalgOp.getMatchingOpOperand(arg0)->get()); - capturedOperands->push_back(linalgOp.getMatchingOpOperand(arg1)->get()); - capturedOperands->push_back(linalgOp.getDpsInitOperand(0)->get()); - } - return true; -} - -// FIXME: This is a generalization of the method above and will eventually -// replace the matcher for both no-op (yield) and one op (add, max). -bool structured_match::withOpChainImpl( - Region *region, Operation *op, SmallVectorImpl *capturedOperands, - SmallVectorImpl &typeChecks) { - - // Number of ops includes yield - ptrdiff_t numOps = typeChecks.size() + 1; - - // Basic checks - if (!isa(op)) - return false; - auto linalgOp = cast(op); - if (!region->hasOneBlock()) - return false; - auto &block = region->front(); - if (std::distance(block.begin(), block.end()) != numOps) - return false; - if (linalgOp.getNumDpsInits() != 1) - return false; - - // Add generic arguments to the list of chained values - llvm::SmallSetVector chainedValues; - for (auto arg : block.getArguments()) { - chainedValues.insert(arg); - } - - // Check on the inner chain of operations in the right order. 
- // Make sure all operands are used and chained - for (auto [check, innerOp] : - llvm::zip_first(typeChecks, block.getOperations())) { - // Must be right op in right order - if (!check(&innerOp)) - return false; - - // At least one operand must come from args or a previous op - bool consumesValueFromChain = false; - for (auto operand : innerOp.getOperands()) { - if (chainedValues.contains(operand)) { - // First add to the captured - auto ba = dyn_cast(operand); - if (capturedOperands && ba && - ba.getParentBlock() == linalgOp.getBlock()) { - capturedOperands->push_back(linalgOp.getMatchingOpOperand(ba)->get()); - } - // Then erase from the set - chainedValues.remove(operand); - consumesValueFromChain = true; - } - } - - // Operation isn't in the chain - if (!consumesValueFromChain) - return false; - - // Add return value to the list of chained values - for (auto ret : innerOp.getResults()) { - chainedValues.insert(ret); - } - } - - // Last op must be a chained yield. - Operation *yieldOp = linalgOp.getBlock()->getTerminator(); - assert(isa(yieldOp) && "Wrong terminator"); - for (auto op : yieldOp->getOperands()) { // NOLINT - if (!chainedValues.contains(op)) - return false; - } - - return true; -} - -structured_match::StructuredOpMatcher & -structured_match::StructuredOpMatcher::region( - MatchSelector range, - std::function &fun) { - predicates.push_back([=](linalg::LinalgOp linalgOp) -> bool { - auto regions = linalgOp->getRegions(); - assert(!regions.empty()); - size_t upperBound = range.getUpperBound(); - size_t lowerBound = range.getLowerBound(); - if (upperBound == std::numeric_limits::max()) - upperBound = regions.size(); - - for (auto idx : - llvm::to_vector(llvm::seq(lowerBound, upperBound))) { - if (!fun(®ions[idx], linalgOp.getOperation())) - return false; - } - return true; - }); - return *this; -} - -} // namespace utils -} // namespace gcext -} // namespace mlir diff --git a/lib/gc/Utils/ValueUtils.cpp b/lib/gc/Utils/ValueUtils.cpp deleted file mode 
100644 index 8a72cf323..000000000 --- a/lib/gc/Utils/ValueUtils.cpp +++ /dev/null @@ -1,153 +0,0 @@ -//===-- ValueUtils.cpp - Utils for handling Value ---------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -/* - * This code is borrowed from tpp-mlir: - * https://github.com/plaidml/tpp-mlir/tree/main/lib/TPP/Transforms/Utils/ValueUtils.cpp - */ - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Value.h" -#include "llvm/ADT/TypeSwitch.h" - -#include "gc/Utils/ValueUtils.h" - -namespace mlir { -namespace gcext { -namespace utils { - -// Returns true if the value is a constant float or integer. -bool isValConstZero(Value val) { - return matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()); -} - -// Returns true if the attribute represent "all zeros" -static bool isZeroAttr(Attribute attribute) { - return TypeSwitch(attribute) - .Case([](auto attr) { return attr.getValueAsDouble() == 0.0; }) - .Case([](auto attr) { return attr.getInt() == 0; }) - .Case([](auto attr) { - if (!attr.getElementType().isIntOrFloat()) - return false; - if (!attr.isSplat()) - return false; - auto splat = attr.template getSplatValue(); - return isZeroAttr(splat); - }) - .Default([](auto attr) { return false; }); -} - -// Prototypes -static bool isZeroOp(Operation *); - -// Returns true if the value represents a zero filled tensor. 
-// Recurse into isZeroOp for defining ops if not immediately obvious -// Looks past linalg generic's argument (which don't have defining ops) -bool isZeroTensor(Value val) { - if (!val) - return false; - if (isValConstZero(val)) - return true; - - Operation *defOp = nullptr; - - // Block arguments don't have a defining op, but they do have an op arg - if (auto arg = dyn_cast(val)) { - // We need to find the argument to the linalg on the same order as this one - auto *linalgOp = arg.getParentRegion()->getParentOp(); - if (!isa(linalgOp)) - return false; - auto index = arg.getArgNumber(); - auto linalgArg = linalgOp->getOperand(index); - defOp = linalgArg.getDefiningOp(); - } else { - defOp = val.getDefiningOp(); - } - return isZeroOp(defOp); -} - -// Returns true if the operation represents a zero filled tensor -// Recurses into isZeroTensor for operands and isZeroAttr for attributes -static bool isZeroOp(Operation *defOp) { - if (!defOp) - return false; - - return TypeSwitch(defOp) - .Case([&](auto op) { - // Dense attributes don't match APFloat.isZero() - auto attr = op.getValue(); - return isZeroAttr(attr); - }) - .Case([&](auto op) { - if (op.getInputs().size() != 1) - return false; - return isZeroTensor(op.getInputs()[0]); - }) - .Case( - [&](auto op) { return isZeroTensor(op.getSource()); }) - .Case([&](auto op) { - auto name = op.getName(); - auto module = defOp->getParentOfType(); - auto global = module.lookupSymbol(name); - auto attr = global.getInitialValueAttr(); - return isZeroAttr(attr); - }) - .Default([&](Operation *op) { return false; }); -} - -FailureOr> getStaticStrides(Value value) { - auto valueType = value.getType(); - if (!isa(valueType)) - return failure(); - auto memrefType = cast(valueType); - SmallVector strides; - int64_t offset; - if (failed(getStridesAndOffset(memrefType, strides, offset))) { - return failure(); - } - if (llvm::any_of(strides, [](int64_t stride) { - return stride == ShapedType::kDynamic; - })) { - return failure(); - } 
- return strides; -} - -std::pair getPtrAndOffset(OpBuilder &builder, Value operand, - Location loc) { - auto memrefType = dyn_cast(operand.getType()); - assert(memrefType && "Expect a memref value"); - MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); - Type basePtrType = builder.getIndexType(); - Type offsetType = builder.getIndexType(); - SmallVector sizesTypes(memrefType.getRank(), offsetType); - SmallVector stridesTypes(memrefType.getRank(), offsetType); - auto meta = builder.create( - loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); - Value alignedPointerAsIndex = - builder.create(loc, basePtrType, - operand); - Value alignedPointerAsI64 = builder.create( - loc, builder.getIntegerType(64), alignedPointerAsIndex); - // TODO: non-POD will require an LLVMTypeConverter. - Value alignedPointer = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), - alignedPointerAsI64); - Value offset = meta.getOffset(); - return std::make_pair(alignedPointer, offset); -} - -} // namespace utils -} // namespace gcext -} // namespace mlir From 3f57403a8eb332e69d49d1a74fc9b2a20ae5ab73 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 00:18:00 -0700 Subject: [PATCH 07/93] fix CMake --- lib/gc/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/gc/CMakeLists.txt b/lib/gc/CMakeLists.txt index 0e5f94288..781e46a63 100644 --- a/lib/gc/CMakeLists.txt +++ b/lib/gc/CMakeLists.txt @@ -8,4 +8,3 @@ add_subdirectory(CAPI) add_subdirectory(Dialect) add_subdirectory(Transforms) add_subdirectory(ExecutionEngine) -add_subdirectory(Utils) From e39ba7e59bf2473af830ee3555d3df001fef9b2b Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 25 Jul 2024 20:31:11 -0700 Subject: [PATCH 08/93] fix per comments --- .../Dialect/Microkernel/MicrokernelDialect.h | 1 - .../gc/Dialect/Microkernel/MicrokernelEnum.td | 2 +- .../gc/Dialect/Microkernel/MicrokernelOps.td | 11 +- 
lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 123 +++++++----------- 4 files changed, 59 insertions(+), 78 deletions(-) diff --git a/include/gc/Dialect/Microkernel/MicrokernelDialect.h b/include/gc/Dialect/Microkernel/MicrokernelDialect.h index 4cca70cf9..35390ab6f 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelDialect.h +++ b/include/gc/Dialect/Microkernel/MicrokernelDialect.h @@ -9,7 +9,6 @@ #ifndef GC_DIALECTS_MICROKERNELDIALECT_H #define GC_DIALECTS_MICROKERNELDIALECT_H -#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" diff --git a/include/gc/Dialect/Microkernel/MicrokernelEnum.td b/include/gc/Dialect/Microkernel/MicrokernelEnum.td index 3a4e4bad0..fc51dfce9 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelEnum.td +++ b/include/gc/Dialect/Microkernel/MicrokernelEnum.td @@ -13,7 +13,7 @@ include "mlir/IR/EnumAttr.td" include "gc/Dialect/Microkernel/MicrokernelDialect.td" def Microkernel_BrgemmFlags : I64EnumAttr< - "BrgemmFlags", "see: microkernel_brgemm_flags", + "BrgemmFlags", "Flags for indicating optional behaviours of Brgemm", [ I64EnumAttrCase<"NONE", 0, "none">, I64EnumAttrCase<"BETA_0", 1, "beta_0">, diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.td b/include/gc/Dialect/Microkernel/MicrokernelOps.td index d0d50d04e..76e5424c6 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.td +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.td @@ -71,7 +71,14 @@ def Microkernel_BrgemmEpilogueOp : Microkernel_Op<"brgemm.epilogue"> { }]; } -def BrgemmMemRef : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>; +/* A generic input type of Microkernel_BrgemmOp, allowing for `BrgemmMemRef` and I64. 
+ * The `BrgemmMemRef` should be a static MemRef, and for each operand its shape should be: + * Operand A: StaticMemRefRankOf<[F32, BF16, SI8, UI8], [3]>; + * Operand B (none-VNNI): StaticMemRefRankOf<[F32], [3]>; + * Operand B (VNNI): StaticMemRefRankOf<[BF16, SI8, UI8], [4]>; + * Operand C: StaticMemRefRankOf<[F32, SI32], [2]>; + */ +def BrgemmMemRefOrI64 : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>; def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> { let summary = "execute the JITed brgemm kernel."; @@ -82,7 +89,7 @@ def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> { There is no output. }]; - let arguments = (ins Variadic:$inputs); + let arguments = (ins Variadic:$inputs); let assemblyFormat = [{ `(` $inputs `)` diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index bd730a22c..1cf8c0d05 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -45,9 +45,8 @@ static void printDataTypeImpl(OpAsmPrinter &printer, OpTy op) { auto dataTypes = op.getDataType(); for (size_t idx = 0; idx < dataTypes.size(); idx++) { printer.printAttribute(dataTypes[idx]); - if (idx != dataTypes.size() - 1) { + if (idx != dataTypes.size() - 1) printer << ", "; - } } printer << ") "; } @@ -137,10 +136,9 @@ verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op, return op->emitOpError() << "expected " << flagsName << " to be unique"; // none flag conflicts with all the others if (llvm::is_contained(flagsAsInt, static_cast(FLAGS::NONE)) && - flagsAsInt.size() != 1) { + flagsAsInt.size() != 1) return op->emitOpError() << "'none' " << flagsName << " conflicts with others"; - } return success(); } @@ -166,9 +164,8 @@ ParseResult BrgemmDispatchOp::parse(OpAsmParser &parser, static LogicalResult verifyBrgemmDataTypes(ArrayAttr dtypes, BrgemmDispatchOp op) { - if (dtypes.size() != 2) { + if (dtypes.size() != 2) return op->emitOpError() << 
"data types size should be 2"; - } auto context = op.getContext(); @@ -185,11 +182,10 @@ static LogicalResult verifyBrgemmDataTypes(ArrayAttr dtypes, #undef ITAttr if (!llvm::any_of(validDataTypes, [=](std::pair type_pair) { - return type_pair.first == dtypes[0] || + return type_pair.first == dtypes[0] && type_pair.second == dtypes[1]; - })) { + })) return op->emitOpError() << "invalid data type pair"; - } return success(); } @@ -203,18 +199,15 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, BrgemmDispatchOp op, bool strideSet = false; bool listSet = false; for (auto flag : flags) { - if (cast(flag).getValue() == BrgemmFlags::STRIDE) { + if (cast(flag).getValue() == BrgemmFlags::STRIDE) strideSet = true; - } - if (cast(flag).getValue() == BrgemmFlags::LIST) { + if (cast(flag).getValue() == BrgemmFlags::LIST) listSet = true; - } } // VNNI flags must be specified only for bf16 type - if (strideSet && listSet) { + if (strideSet && listSet) return op->emitOpError() << "stride and addr flags conflict with each other"; - } return success(); } @@ -224,14 +217,12 @@ LogicalResult BrgemmDispatchOp::verify() { // 'inputs' = [m, n, k, lda, ldb, ldc, stride_a, stride_b] for BRGEMM. size_t expected = 8; size_t numInputs = op.getInputs().size(); - if (numInputs != expected) { + if (numInputs != expected) return op.emitOpError() << "expect " << expected << " args but got: " << numInputs; - } // Verify data types - if (failed(verifyBrgemmDataTypes(op.getDataType(), op))) { + if (failed(verifyBrgemmDataTypes(op.getDataType(), op))) return failure(); - } // Verify leading dims. 
ArrayRef inputs = op.getInputs(); @@ -258,38 +249,34 @@ LogicalResult BrgemmDispatchOp::verify() { static bool isInVnniLayout(MemRefType memref) { if (!memref.getElementType().isBF16() && !memref.getElementType().isSignedInteger(8) && - !memref.getElementType().isUnsignedInteger(8)) { + !memref.getElementType().isUnsignedInteger(8)) return false; - } auto blockingFactor = 0; - if (memref.getElementType().isBF16()) { + if (memref.getElementType().isBF16()) blockingFactor = 2; - } else if (memref.getElementType().isSignedInteger(8) || - memref.getElementType().isUnsignedInteger(8)) { + else if (memref.getElementType().isSignedInteger(8) || + memref.getElementType().isUnsignedInteger(8)) blockingFactor = 4; - } + return memref.getShape().back() == blockingFactor; } -static bool isTypeCompatible(Type outType, Type operandAType, - Type operandBType) { - if (!outType.isF32() && !outType.isSignedInteger(32)) { +static bool isTypeSupported(Type outType, Type operandAType, + Type operandBType) { + if (!outType.isF32() && !outType.isSignedInteger(32)) return false; - } + if (outType.isF32()) { if (!(operandAType.isF32() && operandBType.isF32()) && - !(operandAType.isBF16() && operandBType.isBF16())) { + !(operandAType.isBF16() && operandBType.isBF16())) return false; - } } if (outType.isSignedInteger(32)) { if (!(operandAType.isSignedInteger(8) || operandAType.isUnsignedInteger(8)) && - (operandBType.isSignedInteger(8) || - operandBType.isUnsignedInteger(8))) { + (operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8))) return false; - } } return true; } @@ -300,75 +287,63 @@ LogicalResult BrgemmOp::verify() { SmallVector inputs = brgemmOp.getInputs(); // inputs for BRGEMM: kernel id, A memref, B memref, C memref, batch_size, // addr_len - if (inputs.size() != 6) { + if (inputs.size() != 6) return brgemmOp.emitOpError() << "expect 6" << " inputs but got " << inputs.size(); - } // Verify the dispatch to be an i64. 
Value dispatch = brgemmOp.getDispatch(); - if (!dispatch.getType().isInteger(64)) { + if (!dispatch.getType().isInteger(64)) return brgemmOp.emitOpError() << "expect an i64 but got " << dispatch.getType() << " for operand 0 (dispatch)"; - } - // Verify the compatibility of memref types + // Verify whether memref types are supported SmallVector memrefOperands = { brgemmOp.getOperandA(), brgemmOp.getOperandB(), brgemmOp.getOutput()}; SmallVector typeOperands = { getElementTypeOrSelf(memrefOperands[0].getType()), getElementTypeOrSelf(memrefOperands[1].getType()), getElementTypeOrSelf(memrefOperands[2].getType())}; - if (!isTypeCompatible(typeOperands[2], typeOperands[0], typeOperands[1])) { + if (!isTypeSupported(typeOperands[2], typeOperands[0], typeOperands[1])) return brgemmOp.emitOpError() << "operands types: " << typeOperands[0] << " X " << typeOperands[1] - << " -> " << typeOperands[2] << " are imcompatible"; - } + << " -> " << typeOperands[2] << " are unsupported"; + + // Verify the rank of the shaped operand A. + auto memrefTypeA = dyn_cast(memrefOperands[0].getType()); + if (memrefTypeA.getRank() != 3) + return brgemmOp.emitOpError() + << "expect a 3d memref for operand A: " << memrefTypeA; - // Verify the rank of the shaped operands. - for (size_t idx = 0; idx < memrefOperands.size(); idx++) { - size_t actualIdx = idx + 1 /*skip dispatch*/; - auto memref = dyn_cast(memrefOperands[idx].getType()); - // Output memref. Must be of rank 2. - if (idx == 2 && memref.getRank() != 2) { + // Verify the rank of the shaped operand B. + auto memrefTypeB = dyn_cast(memrefOperands[1].getType()); + auto dtypeB = typeOperands[1]; + if (!dtypeB.isF32()) { + if (memrefTypeB.getRank() != 4 || !isInVnniLayout(memrefTypeB)) return brgemmOp.emitOpError() - << "expect a 2d layout for operand: " << actualIdx; - } - // Input A memref. Must be of rank 3. 
- if (idx == 0 && memref.getRank() != 3) { + << "expect a 4d VNNI memref for non-F32 operand: " << memrefTypeB; + } else { + if (memrefTypeB.getRank() != 3) return brgemmOp.emitOpError() - << "expect a 3d memref for operand: " << actualIdx; - } - // Input B memref. Must be in VNNI layout with rank 4 for non-F32. - if (idx == 1) { - auto dtype_B = typeOperands[idx]; - if (!dtype_B.isF32()) { - if (memref.getRank() != 4 && !isInVnniLayout(memref)) { - return brgemmOp.emitOpError() - << "expect a 4d VNNI memref for non-F32 operand: " - << actualIdx; - } - } else { - if (memref.getRank() != 3) { - return brgemmOp.emitOpError() - << "expect a 3d memref for F32 operand: " << actualIdx; - } - } - } + << "expect a 3d memref for F32 operand: " << memrefTypeB; } + // Verify the rank of the shaped operand C. + auto memrefTypeC = dyn_cast(memrefOperands[2].getType()); + if (memrefTypeC.getRank() != 2) + return brgemmOp.emitOpError() + << "expect a 2d memref for operand C: " << memrefTypeC; + // Verify the batch and addrLen to be i64. 
Value batch = brgemmOp.getBatch(); - if (!batch.getType().isInteger(64)) { + if (!batch.getType().isInteger(64)) return brgemmOp.emitOpError() << "expect an i64 but got " << batch.getType() << " for operand 4 (batch)"; - } Value addrLen = brgemmOp.getAddrLen(); - if (!addrLen.getType().isInteger(64)) { + if (!addrLen.getType().isInteger(64)) return brgemmOp.emitOpError() << "expect an i64 but got " << addrLen.getType() << " for operand 5 (addrLen)"; - } return success(); } From 4acf417d87d1f40a9143ee9e4f1e2bb54cf2c3db Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 29 May 2024 23:19:50 -0700 Subject: [PATCH 09/93] add dialect lowering pass --- cmake/onednn.cmake | 21 +- cmake/onednn_lite_config.cmake | 353 ++++++++++++++ .../Microkernel/BrgemmRuntimeUtils.h | 49 ++ .../gc/Transforms/Microkernel/CMakeLists.txt | 6 + .../Microkernel/MicrokernelPasses.h | 27 ++ .../Microkernel/MicrokernelPasses.td | 75 +++ .../ExecutionEngine/CPURuntime/CMakeLists.txt | 27 +- .../CPURuntime/Microkernel/BrgemmNaive.cpp | 227 +++++++++ .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 136 ++++++ lib/gc/Transforms/CMakeLists.txt | 2 + lib/gc/Transforms/Microkernel/CMakeLists.txt | 26 ++ .../ConvertLinalgToMicrokernel.cpp | 262 +++++++++++ .../ConvertMicrokernelToDnnlFunc.cpp | 222 +++++++++ .../Microkernel/EarlyDispatchMicrokernel.cpp | 194 ++++++++ .../MicrokernelInvariantCodeMotion.cpp | 437 ++++++++++++++++++ src/gc-opt/gc-opt.cpp | 3 + .../Microkernel/linalg-to-microkernel.mlir | 44 ++ .../Microkernel/microkernel-to-dnnl-func.mlir | 70 +++ .../test/gc/cpu-runner/brgemm-parallel.mlir | 50 ++ 19 files changed, 2221 insertions(+), 10 deletions(-) create mode 100644 cmake/onednn_lite_config.cmake create mode 100644 include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h create mode 100644 include/gc/Transforms/Microkernel/CMakeLists.txt create mode 100644 include/gc/Transforms/Microkernel/MicrokernelPasses.h create mode 100644 include/gc/Transforms/Microkernel/MicrokernelPasses.td 
create mode 100644 lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp create mode 100644 lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp create mode 100644 lib/gc/Transforms/Microkernel/CMakeLists.txt create mode 100644 lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp create mode 100644 lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp create mode 100644 lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp create mode 100644 lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp create mode 100644 test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir create mode 100644 test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir create mode 100644 test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir diff --git a/cmake/onednn.cmake b/cmake/onednn.cmake index 673c4b97e..3d56b9584 100644 --- a/cmake/onednn.cmake +++ b/cmake/onednn.cmake @@ -17,16 +17,19 @@ if (NOT DEFINED DNNL_INCLUDES) ${dnnl_SOURCE_DIR}/src ) set_property(GLOBAL PROPERTY DNNL_INCLUDES ${DNNL_INCLUDES}) + set_property(GLOBAL PROPERTY DNNL_SOURCE_DIR ${dnnl_SOURCE_DIR}) # This allows to generate headers from *.in without adding the library to the build. # If the build is required, remove this and the SKIP_ADD option above. 
- if (DEFINED CMAKE_GENERATOR) - set(GENERATOR_FLAG "-G ${CMAKE_GENERATOR}") - endif () - execute_process(COMMAND ${CMAKE_COMMAND} ${GENERATOR_FLAG} - -Wno-dev - -S ${dnnl_SOURCE_DIR} - -B ${dnnl_BINARY_DIR} - ${GC_DNNL_CMAKE_ARGS} - ) + # if (DEFINED CMAKE_GENERATOR) + # set(GENERATOR_FLAG "-G ${CMAKE_GENERATOR}") + # endif () + # execute_process(COMMAND ${CMAKE_COMMAND} ${GENERATOR_FLAG} + # -Wno-dev + # -S ${dnnl_SOURCE_DIR} + # -B ${dnnl_BINARY_DIR} + # ${GC_DNNL_CMAKE_ARGS} + # ) + + include(onednn_lite_config) endif () diff --git a/cmake/onednn_lite_config.cmake b/cmake/onednn_lite_config.cmake new file mode 100644 index 000000000..d67e7ca58 --- /dev/null +++ b/cmake/onednn_lite_config.cmake @@ -0,0 +1,353 @@ +include_guard() + +get_property(DNNL_INCLUDES GLOBAL PROPERTY DNNL_INCLUDES) +get_property(DNNL_PATH GLOBAL PROPERTY DNNL_SOURCE_DIR) +if (NOT DEFINED DNNL_INCLUDES) + return() +endif () + +########## This cmake build lite version of onednn, containing only microkernel related codes + +set(APP_NAME "dnnl_brgemm") + +# Build onednn +set(DNNL_BUILD_TESTS OFF) +set(DNNL_BUILD_EXAMPLES OFF) +set(DNNL_ENABLE_JIT_PROFILING OFF) +set(DNNL_BLAS_VENDOR NONE) +set(DNNL_LIBRARY_TYPE STATIC) + +set(DNNL_GPU_RUNTIME "NONE") +if(NOT DEFINED DNNL_CPU_RUNTIME) + set(DNNL_CPU_RUNTIME "OMP") + set(DNNL_CPU_THREADING_RUNTIME "OMP") +endif() + +if(${DNNL_CPU_RUNTIME} STREQUAL "OMP") + find_package(OpenMP REQUIRED) +endif() + +if(${DNNL_CPU_RUNTIME} STREQUAL "TBB") + include("${DNNL_PATH}/cmake/TBB.cmake") +endif() + +########## copied from main cmake file of DNNL +# Set the target architecture. 
+if(NOT DNNL_TARGET_ARCH) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*|ARM64.*)") + set(DNNL_TARGET_ARCH "AARCH64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64.*|PPC64.*|powerpc64.*)") + set(DNNL_TARGET_ARCH "PPC64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x.*|S390X.*)") + set(DNNL_TARGET_ARCH "S390X") + else() + set(DNNL_TARGET_ARCH "X64") + endif() +endif() + +if(UNIX OR MINGW) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") +endif() + +########## from cmake/options.cmake +option(DNNL_ENABLE_MAX_CPU_ISA + "enables control of CPU ISA detected by oneDNN via DNNL_MAX_CPU_ISA + environment variable and dnnl_set_max_cpu_isa() function" ON) + +include("${DNNL_PATH}/cmake/Threading.cmake") + +########### copied from cmake/SDL.cmake, for -fstack-protector-strong +if(UNIX) + set(CMAKE_CCXX_FLAGS "-fPIC -Wformat -Wformat-security -ffunction-sections -fdata-sections") + append(CMAKE_CXX_FLAGS_RELEASE "-D_FORTIFY_SOURCE=2") + append(CMAKE_C_FLAGS_RELEASE "-D_FORTIFY_SOURCE=2") + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9) + append(CMAKE_CCXX_FLAGS "-fstack-protector-all") + else() + append(CMAKE_CCXX_FLAGS "-fstack-protector-strong") + endif() + + # GCC might be very paranoid for partial structure initialization, e.g. + # struct { int a, b; } s = { 0, }; + # However the behavior is triggered by `Wmissing-field-initializers` + # only. To prevent warnings on users' side who use the library and turn + # this warning on, let's use it too. 
Applicable for the library sources + # and interfaces only (tests currently rely on that fact heavily) + append(CMAKE_SRC_CCXX_FLAGS "-Wmissing-field-initializers") + append(CMAKE_EXAMPLE_CCXX_FLAGS "-Wmissing-field-initializers") + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + append(CMAKE_CCXX_FLAGS "-fstack-protector-all") + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") + append(CMAKE_CXX_FLAGS "-fstack-protector") + endif() + append(CMAKE_C_FLAGS "${CMAKE_CCXX_FLAGS}") + append(CMAKE_CXX_FLAGS "${CMAKE_CCXX_FLAGS}") + if(APPLE) + append(CMAKE_SHARED_LINKER_FLAGS "-Wl,-bind_at_load") + append(CMAKE_EXE_LINKER_FLAGS "-Wl,-bind_at_load") + else() + append(CMAKE_EXE_LINKER_FLAGS "-pie") + append(CMAKE_SHARED_LINKER_FLAGS "-Wl,-z,noexecstack -Wl,-z,relro -Wl,-z,now") + append(CMAKE_EXE_LINKER_FLAGS "-Wl,-z,noexecstack -Wl,-z,relro -Wl,-z,now") + endif() +elseif(MSVC AND ${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) + set(CMAKE_CCXX_FLAGS "/guard:cf") +endif() +########### END of copy of cmake/SDL.cmake + +########### copied from cmake/platform.cmake, for STDC* and -msse4.1 +add_definitions(-D__STDC_LIMIT_MACROS -D__STDC_CONSTANT_MACROS) +if(MSVC) + set(USERCONFIG_PLATFORM "x64") + append_if(DNNL_WERROR CMAKE_CCXX_FLAGS "/WX") + if(${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) + append(CMAKE_CCXX_FLAGS "/MP") + # int -> bool + append(CMAKE_CCXX_NOWARN_FLAGS "/wd4800") + # unknown pragma + append(CMAKE_CCXX_NOWARN_FLAGS "/wd4068") + # double -> float + append(CMAKE_CCXX_NOWARN_FLAGS "/wd4305") + # UNUSED(func) + append(CMAKE_CCXX_NOWARN_FLAGS "/wd4551") + # int64_t -> int (tent) + append(CMAKE_CCXX_NOWARN_FLAGS "/wd4244") + endif() + if(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + append(CMAKE_CCXX_FLAGS "/MP") + set(DEF_ARCH_OPT_FLAGS "-QxSSE4.1") + # disable: loop was not vectorized with "simd" + append(CMAKE_CCXX_NOWARN_FLAGS "-Qdiag-disable:13379") + # disable: loop was not vectorized with "simd" + append(CMAKE_CCXX_NOWARN_FLAGS "-Qdiag-disable:15552") + # disable: 
unknown pragma + append(CMAKE_CCXX_NOWARN_FLAGS "-Qdiag-disable:3180") + # disable: foo has been targeted for automatic cpu dispatch + append(CMAKE_CCXX_NOWARN_FLAGS "-Qdiag-disable:15009") + # disable: disabling user-directed function packaging (COMDATs) + append(CMAKE_CCXX_NOWARN_FLAGS "-Qdiag-disable:11031") + # disable: decorated name length exceeded, name was truncated + append(CMAKE_CCXX_NOWARN_FLAGS "-Qdiag-disable:2586") + # disable: disabling optimization; runtime debug checks enabled + append(CMAKE_CXX_FLAGS_DEBUG "-Qdiag-disable:10182") + endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + append(CMAKE_CCXX_NOEXCEPT_FLAGS "-fno-exceptions") + # Clang cannot vectorize some loops with #pragma omp simd and gets + # very upset. Tell it that it's okay and that we love it + # unconditionally. + append(CMAKE_CCXX_FLAGS "-Wno-pass-failed") + # Clang doesn't like the idea of overriding optimization flags. + # We don't want to optimize jit gemm kernels to reduce compile time + append(CMAKE_CCXX_FLAGS "-Wno-overriding-t-option") + endif() +elseif(UNIX OR MINGW) + append(CMAKE_CCXX_FLAGS "-Wall -Wno-unknown-pragmas") + append_if(DNNL_WERROR CMAKE_CCXX_FLAGS "-Werror") + append(CMAKE_CCXX_FLAGS "-fvisibility=internal") + append(CMAKE_CXX_FLAGS "-fvisibility-inlines-hidden") + append(CMAKE_CCXX_NOEXCEPT_FLAGS "-fno-exceptions") + # compiler specific settings + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + if(DNNL_TARGET_ARCH STREQUAL "AARCH64") + set(DEF_ARCH_OPT_FLAGS "-O3") + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) + append(DEF_ARCH_OPT_FLAGS "-mcpu=native") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "PPC64") + set(DEF_ARCH_OPT_FLAGS "-O3") + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) + append(DEF_ARCH_OPT_FLAGS "-mcpu=native") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "S390X") + set(DEF_ARCH_OPT_FLAGS 
"-O3") + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) + append(DEF_ARCH_OPT_FLAGS "-march=native") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "X64") + set(DEF_ARCH_OPT_FLAGS "-msse4.1") + endif() + # Clang cannot vectorize some loops with #pragma omp simd and gets + # very upset. Tell it that it's okay and that we love it + # unconditionally. + append(CMAKE_CCXX_NOWARN_FLAGS "-Wno-pass-failed") + if(DNNL_USE_CLANG_SANITIZER MATCHES "Memory(WithOrigin)?") + if(NOT DNNL_CPU_THREADING_RUNTIME STREQUAL "SEQ") + message(WARNING "Clang OpenMP is not compatible with MSan! " + "Expect a lot of false positives!") + endif() + append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize=memory") + if(DNNL_USE_CLANG_SANITIZER STREQUAL "MemoryWithOrigin") + append(CMAKE_CCXX_SANITIZER_FLAGS + "-fsanitize-memory-track-origins=2") + append(CMAKE_CCXX_SANITIZER_FLAGS + "-fno-omit-frame-pointer") + endif() + set(DNNL_ENABLED_CLANG_SANITIZER "${DNNL_USE_CLANG_SANITIZER}") + elseif(DNNL_USE_CLANG_SANITIZER STREQUAL "Undefined") + append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize=undefined") + append(CMAKE_CCXX_SANITIZER_FLAGS + "-fno-sanitize=function,vptr") # work around linking problems + append(CMAKE_CCXX_SANITIZER_FLAGS "-fno-omit-frame-pointer") + set(DNNL_ENABLED_CLANG_SANITIZER "${DNNL_USE_CLANG_SANITIZER}") + elseif(DNNL_USE_CLANG_SANITIZER STREQUAL "Address") + append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize=address") + set(DNNL_ENABLED_CLANG_SANITIZER "${DNNL_USE_CLANG_SANITIZER}") + elseif(DNNL_USE_CLANG_SANITIZER STREQUAL "Thread") + append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize=thread") + set(DNNL_ENABLED_CLANG_SANITIZER "${DNNL_USE_CLANG_SANITIZER}") + elseif(DNNL_USE_CLANG_SANITIZER STREQUAL "Leak") + append(CMAKE_CCXX_SANITIZER_FLAGS "-fsanitize=leak") + set(DNNL_ENABLED_CLANG_SANITIZER "${DNNL_USE_CLANG_SANITIZER}") + elseif(NOT DNNL_USE_CLANG_SANITIZER STREQUAL "") + message(FATAL_ERROR + "Unsupported Clang 
sanitizer '${DNNL_USE_CLANG_SANITIZER}'") + endif() + if(DNNL_ENABLED_CLANG_SANITIZER) + message(STATUS + "Using Clang ${DNNL_ENABLED_CLANG_SANITIZER} " + "sanitizer (experimental!)") + append(CMAKE_CCXX_SANITIZER_FLAGS "-g -fno-omit-frame-pointer") + endif() + + if (DNNL_USE_CLANG_TIDY MATCHES "(CHECK|FIX)" AND ${CMAKE_VERSION} VERSION_LESS "3.6.0") + message(FATAL_ERROR "Using clang-tidy requires CMake 3.6.0 or newer") + elseif(DNNL_USE_CLANG_TIDY MATCHES "(CHECK|FIX)") + find_program(CLANG_TIDY NAMES clang-tidy) + if(NOT CLANG_TIDY) + message(FATAL_ERROR "Clang-tidy not found") + else() + if(DNNL_USE_CLANG_TIDY STREQUAL "CHECK") + set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY}) + message(STATUS "Using clang-tidy to run checks") + elseif(DNNL_USE_CLANG_TIDY STREQUAL "FIX") + set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY} -fix) + message(STATUS "Using clang-tidy to run checks and fix found issues") + endif() + endif() + endif() + + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(DNNL_TARGET_ARCH STREQUAL "AARCH64") + set(DEF_ARCH_OPT_FLAGS "-O3") + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) + append(DEF_ARCH_OPT_FLAGS "-mcpu=native") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "PPC64") + set(DEF_ARCH_OPT_FLAGS "-O3") + # In GCC, -ftree-vectorize is turned on under -O3 since 2007. + # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) + append(DEF_ARCH_OPT_FLAGS "-mcpu=native") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "S390X") + set(DEF_ARCH_OPT_FLAGS "-O3") + # In GCC, -ftree-vectorize is turned on under -O3 since 2007. 
+ # For native compilation tune for the host processor + if (CMAKE_SYSTEM_PROCESSOR STREQUAL CMAKE_HOST_SYSTEM_PROCESSOR) + append(DEF_ARCH_OPT_FLAGS "-march=native") + endif() + elseif(DNNL_TARGET_ARCH STREQUAL "X64") + set(DEF_ARCH_OPT_FLAGS "-msse4.1") + endif() + # suppress warning on assumptions made regarding overflow (#146) + append(CMAKE_CCXX_NOWARN_FLAGS "-Wno-strict-overflow") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + set(DEF_ARCH_OPT_FLAGS "-xSSE4.1") + # workaround for Intel Compiler that produces error caused + # by pragma omp simd collapse(..) + append(CMAKE_CCXX_NOWARN_FLAGS "-diag-disable:13379") + append(CMAKE_CCXX_NOWARN_FLAGS "-diag-disable:15552") + # disable `was not vectorized: vectorization seems inefficient` remark + append(CMAKE_CCXX_NOWARN_FLAGS "-diag-disable:15335") + # disable: foo has been targeted for automatic cpu dispatch + append(CMAKE_CCXX_NOWARN_FLAGS "-diag-disable:15009") + endif() +endif() + +append(CMAKE_C_FLAGS "${CMAKE_CCXX_FLAGS} ${DEF_ARCH_OPT_FLAGS}") +append(CMAKE_CXX_FLAGS "${CMAKE_CCXX_FLAGS} ${DEF_ARCH_OPT_FLAGS}") + +########### END of copy of cmake/platform.cmake + +########### setting dummy version info +set(DNNL_VERSION_MAJOR 0) +set(DNNL_VERSION_MINOR 0) +set(DNNL_VERSION_PATCH 0) +set(DNNL_VERSION_HASH "N/A") +########### END of setting dummy version info + +add_definitions(-DDNNL_ENABLE_JIT_PROFILING=0) +configure_file( + "${DNNL_PATH}/include/oneapi/dnnl/dnnl_config.h.in" + "${PROJECT_BINARY_DIR}/include/oneapi/dnnl/dnnl_config.h" +) + +configure_file( + "${DNNL_PATH}/include/oneapi/dnnl/dnnl_version.h.in" + "${PROJECT_BINARY_DIR}/include/oneapi/dnnl/dnnl_version.h" +) + +include_directories( + ${PROJECT_BINARY_DIR}/include + ${DNNL_PATH}/src + ${DNNL_PATH}/include + ) + +if(DNNL_ENABLE_MAX_CPU_ISA) + add_definitions(-DDNNL_ENABLE_MAX_CPU_ISA) +endif() + +add_definitions(-DDISABLE_VERBOSE=1) +file(GLOB_RECURSE DNNL_SOURCES + ${DNNL_PATH}/src/cpu/x64/brgemm/*.cpp + 
${DNNL_PATH}/src/cpu/x64/injectors/*.cpp + ${DNNL_PATH}/src/cpu/x64/cpu_isa_traits.cpp + ${DNNL_PATH}/src/cpu/x64/jit_avx512_core_bf16cvt.cpp + ${DNNL_PATH}/src/cpu/x64/amx_tile_configure.[ch]pp + ${DNNL_PATH}/src/cpu/x64/jit_uni_convert_xf16.[ch]pp + ${DNNL_PATH}/src/cpu/jit_utils/jit_utils.cpp + ${DNNL_PATH}/src/cpu/platform.[ch]pp + ${DNNL_PATH}/src/cpu/bfloat16.cpp + ${DNNL_PATH}/src/cpu/binary_injector_utils.cpp + ${DNNL_PATH}/src/common/fpmath_mode.cpp + ${DNNL_PATH}/src/common/utils.cpp + ${DNNL_PATH}/src/common/bfloat16.[ch]pp + ${DNNL_PATH}/src/common/float8.[ch]pp + ${DNNL_PATH}/src/common/memory_debug.cpp + ${DNNL_PATH}/src/common/primitive_attr.cpp + ${DNNL_PATH}/src/common/broadcast_strategy.cpp + ${DNNL_PATH}/src/common/primitive_exec_types.cpp + ${DNNL_PATH}/src/common/memory.cpp + ${DNNL_PATH}/src/common/memory_zero_pad.cpp + ${DNNL_PATH}/src/common/memory_desc_wrapper.cpp + ${DNNL_PATH}/src/common/memory_desc.cpp + ${DNNL_PATH}/src/common/dnnl_thread.cpp + ${DNNL_PATH}/src/common/verbose.cpp + ${DNNL_PATH}/src/common/dnnl_debug.cpp + ${DNNL_PATH}/src/common/dnnl_debug_autogenerated.cpp + ${DNNL_PATH}/src/cpu/x64/jit_avx512_core_fp8cvt.cpp + ) + +add_library(dnnl_brgemm OBJECT ${DNNL_SOURCES}) +set_property(TARGET dnnl_brgemm PROPERTY POSITION_INDEPENDENT_CODE ON + CXX_VISIBILITY_PRESET "hidden" + VISIBILITY_INLINES_HIDDEN 1) + +# install(TARGETS dnnl_brgemm +# EXPORT dnnl_brgemm_export +# RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +# LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +# ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +# set_property(GLOBAL APPEND PROPERTY DNNL_SUBDIR_EXTRA_STATIC_LIBS $) +# set_property(GLOBAL APPEND PROPERTY DNNL_SUBDIR_EXTRA_SHARED_LIBS dnnl_brgemm) +# Currently build objs only +set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS + $) diff --git a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h new file mode 100644 index 000000000..afe2da9b5 --- /dev/null +++ 
b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h @@ -0,0 +1,49 @@ +//===- BrgemmRuntimeUtils.h - Utils for Brgemm Runtime -----------*-C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H +#define GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PatternMatch.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { + +// these strings contain symbols for BRGEMM interfaces used in mlir pass +static const std::string DNNL_BRGEMM_DISPATCH_NAME = "dnnl_brgemm_dispatch"; +static const std::string DNNL_BRGEMM_TILECFG_NAME = "dnnl_brgemm_tileconfig"; +static const std::string DNNL_BRGEMM_TILERELEASE_NAME = + "dnnl_brgemm_tilerelease"; +static const std::string DNNL_BRGEMM_EXECUTE_NAME = "dnnl_brgemm_execute"; + +static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter, + Attribute attr) { + auto context = rewriter.getContext(); + auto tattr = dyn_cast_or_null(attr); + assert(tattr); + if (tattr == TypeAttr::get(FloatType::getF32(context))) { + return static_cast(dnnl_f32); + } else if (tattr == TypeAttr::get(FloatType::getBF16(context))) { + return static_cast(dnnl_bf16); + } else if (tattr == TypeAttr::get( + IntegerType::get(context, 32, IntegerType::Signed))) { + return static_cast(dnnl_s32); + } else if (tattr == + TypeAttr::get(IntegerType::get(context, 8, IntegerType::Signed))) { + return static_cast(dnnl_s8); + } else if (tattr == TypeAttr::get(IntegerType::get(context, 8, + IntegerType::Unsigned))) { + return static_cast(dnnl_u8); + } + return static_cast(dnnl_data_type_undef); +} + +}; // namespace mlir::microkernel + +#endif // GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H diff --git 
a/include/gc/Transforms/Microkernel/CMakeLists.txt b/include/gc/Transforms/Microkernel/CMakeLists.txt new file mode 100644 index 000000000..2e345775c --- /dev/null +++ b/include/gc/Transforms/Microkernel/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS MicrokernelPasses.td) +mlir_tablegen(MicrokernelPasses.h.inc --gen-pass-decls -name Microkernel) +mlir_tablegen(MicrokernelPasses.capi.h.inc -gen-pass-capi-header --prefix Microkernel) +mlir_tablegen(MicrokernelPasses.capi.cpp.inc -gen-pass-capi-impl --prefix Microkernel) +add_public_tablegen_target(MLIRMicrokernelPassesIncGen) +add_mlir_doc(MicrokernelPasses GraphCompilerMicrokernelPasses ./ -gen-pass-doc) diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.h b/include/gc/Transforms/Microkernel/MicrokernelPasses.h new file mode 100644 index 000000000..a053253e6 --- /dev/null +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.h @@ -0,0 +1,27 @@ +//===- MicrokernelPasses.h - Graph Compiler microkerenl passes --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_MICROKERNELPASSES_H +#define GC_MICROKERNELPASSES_H + +#include "gc/Dialect/Microkernel/MicrokernelDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace microkernel { +#define GEN_PASS_DECL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define GEN_PASS_REGISTRATION +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" +} // namespace microkernel +} // namespace mlir + +#endif // GC_MICROKERNELPASSES_H diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td new file mode 100644 index 000000000..c015dddf3 --- /dev/null +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -0,0 +1,75 @@ +//===- MicrokernelPasses.td - Graph Compiler microkernel passes *- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===--------------------------------------------------------------------------===// + +#ifndef GC_DIALECT_MICROKERNELPASSES +#define GC_DIALECT_MICROKERNELPASSES + +include "mlir/Pass/PassBase.td" + +def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::func::FuncOp"> { + let summary = "Lower eligible linalg ops to microkernels"; + let description = [{ + Convert eligible linalg ops to microkernel dialects based on pattern matching. + For example: + ``` + scf.forall { + linalg.fill ins(...) outs(...) -> tensor<...> + linalg.batch_reduce_matmul ins(...) outs(...) -> tensor<...> + } + ``` + Will be changed into + ``` + scf.forall { + linalg.fill ins(...) outs(...) -> tensor<...> + %0 = microkernel.brgemm.dispatch(...) 
+ microkernel.brgemm.prologue(%0) + microkernel.brgemm(%0, ...) + microkernel.brgemm.epilogue(%0) + } + ``` + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "linalg::LinalgDialect", + "microkernel::MicrokernelDialect"]; +} + +def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::mlir::ModuleOp"> { + let summary = "Lower microkernel dialects to dnnl func call"; + let description = [{ + Convert microkernel dialects to runtime function call to oneDNN library. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + +def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> { + let summary = "Early dispatch microkernel during compile time"; + let description = [{ + Early dispatch microkernel during compile time. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + +def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> { + let summary = "Hoist invariant microkernel code to avoid redundant execution"; + let description = [{ + Hoist invariant microkernel code to avoid redundant execution. 
+ }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + +#endif // GC_DIALECT_MICROKERNELPASSES diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 6be58e28f..180413719 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -5,11 +5,36 @@ else() add_definitions("-DGC_NEEDS_OMP_WRAPPER=1") endif() +file(GLOB_RECURSE MICROKERNEL_RUNTIME_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/*.c + ) + +if (GC_MLIR_NAIVE_BRGEMM) + string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmOnednn.cpp;" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") +else() + string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmNaive.cpp;" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") +endif() + +include(onednn) + +get_property(DNNL_INCLUDES GLOBAL PROPERTY DNNL_INCLUDES) +get_property(DNNL_LIB_DEPS GLOBAL PROPERTY DNNL_LIB_DEPS) + +include_directories(${DNNL_INCLUDES}) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") add_mlir_library(GCCpuRuntime SHARED Parallel.cpp + ${MICROKERNEL_RUNTIME_SOURCES} + + DEPENDS + dnnl_brgemm + + LINK_LIBS PRIVATE + ${DNNL_LIB_DEPS} EXCLUDE_FROM_LIBMLIR - ) \ No newline at end of file + ) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp new file mode 100644 index 000000000..3cb585c31 --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -0,0 +1,227 @@ +//===-- BrgemmNaive.cpp - BRGEMM Naive Implementation -----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#include "oneapi/dnnl/dnnl_types.h" + +namespace { + +struct bf16_t { + uint16_t storage_; + union caster_t { + uint32_t vl; + float vf; + }; + operator float() const { + caster_t val; + val.vl = uint32_t(storage_) << 16; + return val.vf; + } + bool operator==(const bf16_t &compare_to) const { + return storage_ == compare_to.storage_; + } + bool operator!=(const bf16_t &compare_to) const { + return storage_ != compare_to.storage_; + } + bf16_t(float v) { + if (std::isnan(v)) { + storage_ = UINT32_C(0x7FC0); + } else { + caster_t caster; + caster.vf = v; + uint32_t rounding_bias = ((caster.vl >> 16) & 1) + UINT32_C(0x7FFF); + storage_ = static_cast((caster.vl + rounding_bias) >> 16); + } + } + bf16_t() : storage_(0) {} + inline static bf16_t from_storage(uint16_t v) { + bf16_t ret; + ret.storage_ = v; + return ret; + } +}; + +struct brgemm_params_t { + int64_t M, N, K; + int64_t LDA, LDB, LDC; + int64_t stride_a, stride_b; + float beta; + int64_t dtypeA, dtypeB; + brgemm_params_t(int64_t m, int64_t n, int64_t k, int64_t lda, int64_t ldb, + int64_t ldc, int64_t sa, int64_t sb, float b, int64_t da, + int64_t db) + : M(m), N(n), K(k), LDA(lda), LDB(ldb), LDC(ldc), stride_a(sa), + stride_b(sb), beta(b), dtypeA(da), dtypeB(db) {} +}; + +}; // namespace + +static int naive_brgemm_execute_fp32(brgemm_params_t params, void *A, + uint64_t A_offset, void *B, + uint64_t B_offset, void *C, + uint64_t C_offset, int num) { + float *Abuf = (float *)A; + float *Bbuf = (float *)B; + float *Cbuf = (float *)C; + Abuf += A_offset; + Bbuf += B_offset; + Cbuf += C_offset; + for (int i = 0; i < num; i++) { + // a is MxK + for (int m = 0; m < params.M; m++) { + for (int n = 0; n < params.N; n++) { + for (int k = 0; k < params.K; k++) { + Cbuf[m * params.LDC + n] += + Abuf[m * 
params.LDA + k] * Bbuf[k * params.LDB + n]; + } + } + } + Abuf += params.stride_a; + Bbuf += params.stride_b; + } + return 0; +} + +static void naive_brgemm_execute_bf16(brgemm_params_t params, void *A, + uint64_t A_offset, void *B, + uint64_t B_offset, void *C, + uint64_t C_offset, int num) { + bf16_t *Abuf = (bf16_t *)A; + bf16_t *Bbuf = (bf16_t *)B; + float *Cbuf = (float *)C; + Abuf += A_offset; + Bbuf += B_offset; + Cbuf += C_offset; + for (int i = 0; i < num; i++) { + // a is MxK + // b is KxNx2k (vnni format) + for (int m = 0; m < params.M; m++) { + for (int n = 0; n < params.N; n++) { + for (int k = 0; k < params.K; k += 2) { + Cbuf[m * params.LDC + n] += + Abuf[m * params.LDA + k] * Bbuf[k * params.LDB + 2 * n]; + if (k + 1 < params.K) { + // simulate vnni padding + Cbuf[m * params.LDC + n] += + Abuf[m * params.LDA + k + 1] * Bbuf[k * params.LDB + 2 * n + 1]; + } + } + } + } + Abuf += params.stride_a; + Bbuf += params.stride_b; + } +} + +template +static void naive_brgemm_execute_int8(brgemm_params_t params, void *A, + uint64_t A_offset, void *B, + uint64_t B_offset, void *C, + uint64_t C_offset, int num) { + TA *Abuf = (TA *)A; + TB *Bbuf = (TB *)B; + int32_t *Cbuf = (int32_t *)C; + Abuf += A_offset; + Bbuf += B_offset; + Cbuf += C_offset; + for (int i = 0; i < num; i++) { + // a is MxK + // b is KxNx4k (vnni format) + for (int m = 0; m < params.M; m++) { + for (int n = 0; n < params.N; n++) { + for (int k = 0; k < params.K; k += 4) { + Cbuf[m * params.LDC + n] += + Abuf[m * params.LDA + k] * Bbuf[k * params.LDB + 4 * n]; + if (k + 1 < params.K) { + // simulate vnni padding + Cbuf[m * params.LDC + n] += + Abuf[m * params.LDA + k + 1] * Bbuf[k * params.LDB + 4 * n + 1]; + } + if (k + 2 < params.K) { + Cbuf[m * params.LDC + n] += + Abuf[m * params.LDA + k + 2] * Bbuf[k * params.LDB + 4 * n + 2]; + } + if (k + 3 < params.K) { + Cbuf[m * params.LDC + n] += + Abuf[m * params.LDA + k + 3] * Bbuf[k * params.LDB + 4 * n + 3]; + } + } + } + } + Abuf += 
params.stride_a; + Bbuf += params.stride_b; + } +} + +static std::vector brgemm_list; + +extern "C" { + +int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, + int64_t LDB, int64_t LDC, int64_t stride_a, + int64_t stride_b, float beta, int64_t dtypeA, + int64_t dtypeB) { + // simply store the given parameters for naive BRGEMM + brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, + stride_b, beta, dtypeA, dtypeB)); + // std::cout << ">>>>> dnnl_brgemm_dispatch: " << brgemm_list.size() - 1 << + // std::endl; + return brgemm_list.size() - 1; +} + +void dnnl_brgemm_tileconfig(int64_t kernel) { return; } + +void dnnl_brgemm_tilerelease() { return; } + +void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, + uint64_t B_offset, void *C, uint64_t C_offset, + int num) { + assert(kernel >= 0 && kernel < (int64_t)brgemm_list.size() && + "Invalid kernel handler"); + brgemm_params_t ¶ms = brgemm_list[kernel]; + if (params.dtypeA == static_cast(dnnl_f32) && + params.dtypeB == static_cast(dnnl_f32)) { + // std::cout << ">>>>> dnnl_brgemm_execute_f32: " << kernel << std::endl; + naive_brgemm_execute_fp32(params, A, A_offset, B, B_offset, C, C_offset, + num); + } else if (params.dtypeA == static_cast(dnnl_bf16) && + params.dtypeB == static_cast(dnnl_bf16)) { + // std::cout << ">>>>> dnnl_brgemm_execute_bf16: " << kernel << std::endl; + naive_brgemm_execute_bf16(params, A, A_offset, B, B_offset, C, C_offset, + num); + } else if (params.dtypeA == static_cast(dnnl_s8) && + params.dtypeB == static_cast(dnnl_s8)) { + // std::cout << ">>>>> dnnl_brgemm_execute_s8s8: " << kernel << std::endl; + naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, + C, C_offset, num); + } else if (params.dtypeA == static_cast(dnnl_s8) && + params.dtypeB == static_cast(dnnl_u8)) { + // std::cout << ">>>>> dnnl_brgemm_execute_s8u8: " << kernel << std::endl; + naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, + C, C_offset, 
num); + } else if (params.dtypeA == static_cast(dnnl_u8) && + params.dtypeB == static_cast(dnnl_u8)) { + // std::cout << ">>>>> dnnl_brgemm_execute_u8u8: " << kernel << std::endl; + naive_brgemm_execute_int8(params, A, A_offset, B, + B_offset, C, C_offset, num); + } else if (params.dtypeA == static_cast(dnnl_u8) && + params.dtypeB == static_cast(dnnl_s8)) { + // std::cout << ">>>>> dnnl_brgemm_execute_u8s8: " << kernel << std::endl; + naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, + C, C_offset, num); + } else { + assert(false && "unsupported input dtypes"); + } +} +} diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp new file mode 100644 index 000000000..e5e84e24d --- /dev/null +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -0,0 +1,136 @@ +//===-- BrgemmNaive.cpp - BRGEMM Naive Implementation -----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#else +#include +#endif + +// manually include xbyak header here to avoid no-exception compile issue +#define XBYAK_NO_EXCEPTION +#include // NOLINT +#undef XBYAK_NO_EXCEPTION + +#include +#include +#include +#include + +using namespace dnnl::impl::cpu::x64; + +namespace dnnl { +namespace impl { +namespace graph { +namespace utils { +// dummy definition for DNNL lite linkage +__attribute__((weak)) void print_verbose_header() {} +} // namespace utils +} // namespace graph +} // namespace impl +} // namespace dnnl + +static constexpr int PALETTE_SIZE = 64; +static std::vector brgemm_desc_list; +static std::vector brgemm_kernel_list; + +extern "C" { + +int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, + int64_t LDB, int64_t LDC, int64_t stride_a, + int64_t stride_b, float beta, int64_t dtypeA, + int64_t dtypeB) { + std::cout << ">>> Brgemm dispatch: " << std::endl; + brgemm_desc_list.emplace_back(brgemm_desc_t()); + brgemm_kernel_list.emplace_back(nullptr); + + brgemm_desc_t &desc = brgemm_desc_list.back(); + auto &kernel = brgemm_kernel_list.back(); + brgemm_strides_t stride_info{stride_a, stride_b}; + + dnnl::impl::status_t status = brgemm_desc_init( + &desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, + static_cast(dtypeA), + static_cast(dtypeB), false, false, + brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K, + &stride_info); + assert(status == dnnl::impl::status::success && + "Failed to initialize BRGEMM descriptor"); + + status = brgemm_kernel_create(&kernel, desc); + assert(status == dnnl::impl::status::success && + "Failed to JIT BRGEMM kernel"); + + return brgemm_desc_list.size() - 1; +} + +void dnnl_brgemm_tileconfig(int64_t kernel_idx) { + assert(kernel_idx >= 
0 && kernel_idx < (int64_t)brgemm_desc_list.size() && + "Invalid kernel handler"); + std::cout << ">>> Brgemm tileconfig: " << kernel_idx << std::endl; + + brgemm_desc_t &desc = brgemm_desc_list[kernel_idx]; + if (!desc.is_tmm) { + return; + } + + char palette_buffer[PALETTE_SIZE]; + dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer); + assert(status == dnnl::impl::status::success && + "Failed to initialize palette for BRGEMM"); + + amx_tile_configure(palette_buffer); +} + +void dnnl_brgemm_tilerelease() { + if (!mayiuse(avx512_core_amx)) { + return; + } + std::cout << ">>> Brgemm tilerelease" << std::endl; + + amx_tile_release(); +} + +void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, + void *B, uint64_t B_offset, void *C, uint64_t C_offset, + int num) { + assert(kernel_idx >= 0 && kernel_idx < (int64_t)brgemm_desc_list.size() && + "Invalid kernel handler"); + + std::cout << ">>> Brgemm Execute: " << kernel_idx << std::endl; + brgemm_desc_t &desc = brgemm_desc_list[kernel_idx]; + brgemm_kernel_t *kernel = brgemm_kernel_list[kernel_idx]; + + size_t A_offset_in_bytes = + dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; + size_t B_offset_in_bytes = + dnnl::impl::types::data_type_size(desc.dt_b) * A_offset; + size_t C_offset_in_bytes = + dnnl::impl::types::data_type_size(desc.dt_c) * A_offset; + +#ifdef _WIN32 + // fix-me: (win32) impl + static size_t scratch_size = 2 * 4096; +#else + static size_t scratch_size = 2 * getpagesize(); +#endif + // TODO(haixin): use thread local buffer for scratch + char *scratch = new char[scratch_size]; + brgemm_kernel_execute(kernel, num, A + A_offset_in_bytes, + B + B_offset_in_bytes, nullptr, C + C_offset_in_bytes, + (void *)scratch); + delete scratch; +} +} diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 1b4f2cb73..f806eb435 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -28,4 +28,6 @@ 
add_mlir_library(GCPasses ) set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCPasses) + add_subdirectory(GPU) +add_subdirectory(Microkernel) diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt new file mode 100644 index 000000000..c2d88e33c --- /dev/null +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -0,0 +1,26 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) + +include(onednn) + +get_property(DNNL_INCLUDES GLOBAL PROPERTY DNNL_INCLUDES) + +include_directories(${DNNL_INCLUDES}) + +add_mlir_dialect_library(MLIRMicrokernelTransforms + ConvertLinalgToMicrokernel.cpp + ConvertMicrokernelToDnnlFunc.cpp + EarlyDispatchMicrokernel.cpp + MicrokernelInvariantCodeMotion.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/ + + DEPENDS + MLIRMicrokernelPassesIncGen + + LINK_LIBS PUBLIC + ${MLIR_LINK_COMPONENTS} + GCMLIRUtils + ) + +set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS MLIRMicrokernelTransforms) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp new file mode 100644 index 000000000..4e46f35ac --- /dev/null +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -0,0 +1,262 @@ +//===- ConvertLinalgToMicrokernel.cpp - Linalg To Microkernel -*- C++ -*--===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/StructuredOpMatcher.h" +#include "gc/Utils/ValueUtils.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "convert-linalg-to-microkernel" + +struct BrgemmInfo { + enum BrgemmMode { STRIDE_MODE, LIST_MODE }; + int64_t m; + int64_t n; + int64_t k; + int64_t batchSize; + int64_t addrLen; + + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t strideA; + int64_t strideB; + + bool isInitOutput; + BrgemmMode mode; +}; + +// Return the position of `dim` in the codomain of `operand`. 
+static std::optional +getPosInCodomain(unsigned dim, OpOperand *operand, linalg::LinalgOp linalgOp) { + assert(operand->getOwner() == linalgOp); + return linalgOp.getMatchingIndexingMap(operand).getResultPosition( + getAffineDimExpr(dim, linalgOp.getContext())); +} + +static FailureOr +inferBrgemmInfo(linalg::LinalgOp linalgOp, + const linalg::ContractionDimensions &dims) { + unsigned mPos = dims.m[0]; + unsigned nPos = dims.n[0]; + unsigned kPos = dims.k.back(); + std::optional batchPos; + if (dims.k.size() == 2) + batchPos = dims.k.front(); + + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] m: " << mPos << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n: " << nPos << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] k: " << kPos << "\n"); + if (batchPos) + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch: " << batchPos << "\n"); + else + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] no batch dim\n"); + + auto checkStridesAndGetLda = [&](unsigned minorDim, unsigned majorDim, + OpOperand *operand) -> FailureOr { + auto minorDimPosInCodomain = getPosInCodomain(minorDim, operand, linalgOp); + auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp); + if (!minorDimPosInCodomain || !majorDimPosInCodomain) + return failure(); + auto stridesOnOperand = gcext::utils::getStaticStrides(operand->get()); + if (failed(stridesOnOperand) || + (*stridesOnOperand)[*minorDimPosInCodomain] != 1) + return failure(); + return (*stridesOnOperand)[*majorDimPosInCodomain]; + }; + + OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; + OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; + OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; + + // A(m, k) + auto lda = checkStridesAndGetLda(kPos, mPos, operandA); + if (failed(lda)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on A: OK\n"); + + // B(k, n) + auto ldb = 
checkStridesAndGetLda(nPos, kPos, operandB); + if (failed(ldb)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on B: OK\n"); + + // C(m, n) + auto ldc = checkStridesAndGetLda(nPos, mPos, operandC); + if (failed(ldc)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on C: OK\n"); + + int64_t strideA = 1; + int64_t strideB = 1; + if (batchPos) { + auto batchPosCodomainA = + getPosInCodomain(batchPos.value(), operandA, linalgOp); + auto stridesOnA = gcext::utils::getStaticStrides(operandA->get()); + strideA = (*stridesOnA)[*batchPosCodomainA]; + + auto batchPosCodomainB = + getPosInCodomain(batchPos.value(), operandB, linalgOp); + auto stridesOnB = gcext::utils::getStaticStrides(operandB->get()); + strideB = (*stridesOnB)[*batchPosCodomainB]; + } + + auto loops = linalgOp.computeStaticLoopSizes(); + int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0; + + BrgemmInfo info{loops[mPos], + loops[nPos], + loops[kPos], + batchVal, + 0 /* addrLen useless under stride mode */, + *lda, + *ldb, + *ldc, + strideA, + strideB}; + info.isInitOutput = false; + info.mode = BrgemmInfo::STRIDE_MODE; + + return info; +} + +static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { + using namespace mlir::gcext::utils::structured_match; + auto validBrgemmMatcher = StructuredOpMatcher::make() + .output(MatchAll(), HasStaticShape()) + .input(MatchAll(), HasStaticShape()) + .output(MatchAll(), HasStaticStrides()) + .input(MatchAll(), HasStaticStrides()) + .operation(NumOfLoops(GreaterThanOrEqualTo(3))); + // clang-format on + if (!validBrgemmMatcher.match(linalgOp)) + return failure(); + + auto contractionDims = linalg::inferContractionDims(linalgOp); + if (failed(contractionDims)) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Not a valid contraction\n"); + return failure(); + } + if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 || + (contractionDims->k.size() != 2 && contractionDims->k.size() != 1) || + 
contractionDims->batch.size() != 0) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n"); + return failure(); + } + unsigned classifiedLoops = + contractionDims->m.size() + contractionDims->n.size() + + contractionDims->k.size() + contractionDims->batch.size(); + if (linalgOp.getNumLoops() != classifiedLoops) { + LLVM_DEBUG(llvm::dbgs() + << "[checkStructure] Not all loops are classified\n"); + return failure(); + } + + return inferBrgemmInfo(linalgOp, *contractionDims); +} + +// Replace linalgOp with a set of microkernel ops +static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, + linalg::LinalgOp linalgOp, + const BrgemmInfo &info) { + assert(linalgOp.getDpsInputs().size() == 2); + OpBuilder::InsertionGuard guard(rewriter); + + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + Location loc = linalgOp.getLoc(); + SmallVector brgemmFlags; + if (info.isInitOutput) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::BETA_0)); + } + if (info.mode == BrgemmInfo::STRIDE_MODE) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::STRIDE)); + } else if (info.mode == BrgemmInfo::LIST_MODE) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::LIST)); + } + + SmallVector brgemmDtypes{ + TypeAttr::get(getElementTypeOrSelf(linalgOp.getDpsInputs()[0].getType())), + TypeAttr::get( + getElementTypeOrSelf(linalgOp.getDpsInputs()[1].getType()))}; + + // create dispatch op + auto flags = rewriter.getArrayAttr(brgemmFlags); + auto dtypes = rewriter.getArrayAttr(brgemmDtypes); + DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( + rewriter.getContext(), + ArrayRef{info.m, info.n, info.k, info.lda, info.ldb, info.ldc, + info.strideA, info.strideB}); + Value dispatched = rewriter.create( + loc, integer64, dims, flags, dtypes); + + // create prologue op + 
rewriter.create(loc, dispatched); + + // create brgemm invoke op + Value batchDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, info.batchSize)); + Value lenDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, info.addrLen)); + SmallVector invokeOperands; + invokeOperands.push_back(dispatched); + invokeOperands.append(linalgOp->getOperands().begin(), + linalgOp->getOperands().end()); + invokeOperands.push_back(batchDim); + invokeOperands.push_back(lenDim); + rewriter.create(loc, invokeOperands); + + // create epilogue op & replace original op + rewriter.replaceOpWithNewOp(linalgOp, + dispatched); +} + +class ConvertBatchReduceMatmulToBrgemmRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp op, + PatternRewriter &rewriter) const final { + auto brgemmInfo = getBrgemmInfo(op); + if (failed(brgemmInfo)) + return failure(); + replaceOpWithMicrokernelOpSet(rewriter, op, *brgemmInfo); + return success(); + } +}; + +class ConvertLinalgToMicrokernel + : public impl::ConvertLinalgToMicrokernelBase { +public: + using impl::ConvertLinalgToMicrokernelBase< + ConvertLinalgToMicrokernel>::ConvertLinalgToMicrokernelBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp new file mode 100644 index 000000000..0d50aee71 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -0,0 +1,222 @@ +//===- ConvertMicrokernelToDnnlFunc.cpp ------------------------*- C++ -*-===// +// +// This file is licensed under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/ValueUtils.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_CONVERTMICROKERNELTODNNLFUNC +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "convert-microkernel-to-dnnl-func" + +static func::CallOp createFuncCall(RewriterBase &rewriter, Location loc, + ModuleOp module, const std::string &funcName, + ArrayRef operands, + ArrayRef operandTypes, + ArrayRef resultTypes) { + FlatSymbolRefAttr fnName = SymbolRefAttr::get(module->getContext(), funcName); + auto fnType = rewriter.getFunctionType(operandTypes, resultTypes); + + if (!module.lookupSymbol(fnName.getAttr())) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. 
+ rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, fnName.getValue(), fnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create(loc, fnName.getValue(), + resultTypes, operands); + return call; +} + +class ConvertBrgemmDispatchOpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // runtime func for dnnl brgemm dispatch: + // int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, + // int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, + // int64_t dtypeA, int64_t dtypeB); + LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + + SmallVector operands; + SmallVector operandTypes; + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + FloatType float32 = FloatType::getF32(rewriter.getContext()); + + // M, N, K, LDA, LDB, LDC, stride_a, stride_b + // they are in the same order with BrgemmDispatchOp inputs + ArrayRef inputs = op.getInputsAttr().asArrayRef(); + for (auto input : inputs) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), input); + operands.push_back( + rewriter.create(loc, integer64, attr)); + operandTypes.push_back(integer64); + } + + // beta + auto flags = op.getFlagsAttr(); + float beta = 1.0f; + for (auto flag : flags) { + auto brgemmFlag = dyn_cast_or_null(flag); + if (!brgemmFlag) + return rewriter.notifyMatchFailure(op, "unknown flag for BRGEMM"); + if (brgemmFlag.getValue() == BrgemmFlags::LIST) + return rewriter.notifyMatchFailure( + op, "addr mode BRGEMM not supported yet"); + if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + beta = 0.0f; + } + auto betaAttr = FloatAttr::get(rewriter.getF32Type(), beta); + operands.push_back( + rewriter.create(loc, float32, betaAttr)); + operandTypes.push_back(float32); + 
+ // dtypeA, dtypeB + auto dtypes = op.getDataType(); + if (dtypes.size() != 2) + return rewriter.notifyMatchFailure( + op, "invalid number of DataType for BRGEMM"); + auto dtypeAAttr = IntegerAttr::get(rewriter.getI64Type(), + getDnnlDataTypeVal(rewriter, dtypes[0])); + auto dtypeBAttr = IntegerAttr::get(rewriter.getI64Type(), + getDnnlDataTypeVal(rewriter, dtypes[1])); + operands.push_back( + rewriter.create(loc, integer64, dtypeAAttr)); + operandTypes.push_back(integer64); + operands.push_back( + rewriter.create(loc, integer64, dtypeBAttr)); + operandTypes.push_back(integer64); + + func::CallOp call = + createFuncCall(rewriter, loc, module, DNNL_BRGEMM_DISPATCH_NAME, + operands, operandTypes, {integer64}); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +class ConvertBrgemmPrologueOpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // dnnl runtime func for brgemm set hw context: + // void dnnl_brgemm_tileconfig(int64_t kernel_idx); + LogicalResult matchAndRewrite(microkernel::BrgemmPrologueOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + func::CallOp call = + createFuncCall(rewriter, loc, module, DNNL_BRGEMM_TILECFG_NAME, + op.getInputs(), {integer64}, {}); + rewriter.replaceOp(op, call); + return success(); + } +}; + +class ConvertBrgemmOpRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // runtime func for stride mode dnnl brgemm execution: + // void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void + // *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) + LogicalResult matchAndRewrite(microkernel::BrgemmOp op, + PatternRewriter &rewriter) const final { + // currently only support stride mode, directly call it + // TODO(haixin): support addr mode execution, through 
detecting dispatch + // target + + auto context = rewriter.getContext(); + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + + SmallVector operands; + SmallVector operandTypes; + + auto raw_operands = op->getOperands(); + size_t raw_op_cnt = 0; + for (Value operand : raw_operands) { + if (raw_op_cnt++ >= 5) { + // drop the last operand for `addr list length` + break; + } + Type operandType = operand.getType(); + if (auto memrefType = dyn_cast(operandType)) { + Type basePtrType = LLVM::LLVMPointerType::get(context); + auto [ptr, offset] = + gcext::utils::getPtrAndOffset(rewriter, operand, loc); + operands.push_back(ptr); + operands.push_back(offset); + operandTypes.push_back(basePtrType); + operandTypes.push_back(rewriter.getIndexType()); // offset + } else { + operands.push_back(operand); + operandTypes.push_back(operand.getType()); + } + } + + createFuncCall(rewriter, loc, module, DNNL_BRGEMM_EXECUTE_NAME, operands, + operandTypes, {}); + rewriter.eraseOp(op); + return success(); + } +}; + +class ConvertBrgemmEpilogueOpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // dnnl runtime func for brgemm release hw context: + // void dnnl_brgemm_tilerelease(); + LogicalResult matchAndRewrite(microkernel::BrgemmEpilogueOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + func::CallOp call = createFuncCall( + rewriter, loc, module, DNNL_BRGEMM_TILERELEASE_NAME, {}, {}, {}); + rewriter.replaceOp(op, call); + return success(); + } +}; + +class ConvertMicrokernelToDnnlFunc + : public impl::ConvertMicrokernelToDnnlFuncBase< + ConvertMicrokernelToDnnlFunc> { +public: + using impl::ConvertMicrokernelToDnnlFuncBase< + ConvertMicrokernelToDnnlFunc>::ConvertMicrokernelToDnnlFuncBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns + .add( + &getContext()); + FrozenRewritePatternSet 
patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp new file mode 100644 index 000000000..a3843e01d --- /dev/null +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -0,0 +1,194 @@ +//===- EarlyDispatchMicrokernel.cpp ----------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_EARLYDISPATCHMICROKERNEL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "early-dispatch-microkernel" + +static constexpr StringRef getGlobalCtorsVarName() { + return "llvm.global_ctors"; +} + +static FailureOr +createGlobalKernelHandleName(RewriterBase &rewriter, + microkernel::BrgemmDispatchOp op) { + // TODO(haixin): Add runtime backend type to global name + std::stringstream ss; + ss << "g_dispatched_microkernel_brgemm"; + + auto flags = op.getFlagsAttr(); + for (auto flag : flags) { + auto brgemmFlag = dyn_cast_or_null(flag); + if (!brgemmFlag) + return 
failure("unknown flag for BRGEMM"); + if (brgemmFlag.getValue() == BrgemmFlags::LIST) + return failure("addr mode BRGEMM not supported yet"); + if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + ss << "_init"; + } + + // M, N, K, LDA, LDB, LDC, stride_a, stride_b + // they are in the same order with BrgemmDispatchOp inputs + ArrayRef inputs = op.getInputsAttr().asArrayRef(); + for (auto input : inputs) { + ss << "_" << input; + } + + // dtypeA, dtypeB + auto dtypes = op.getDataType(); + if (dtypes.size() != 2) + return failure("invalid number of DataType for BRGEMM"); + ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]); + ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[1]); + + return ss.str(); +} + +// get or create global kernel handle with initializer, identified by +// `kernelName` +static FailureOr +getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, + const std::string &kernelName, + microkernel::BrgemmDispatchOp op) { + // Create the global at the entry of the module + LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(kernelName))) { + auto global_type = op.getResults().getType(); + FlatSymbolRefAttr ctorName = + SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); + if (module.lookupSymbol(ctorName.getAttr())) { + return failure("Existing ctor for new global kernel handle"); + } + + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + global = rewriter.create( + module.getLoc(), global_type, /*isConstant=*/false, + LLVM::Linkage::Internal, kernelName, Attribute(), + /*alignment=*/0); + + // create ctor for this global, which needs to be LLVMFuncOp + LLVM::LLVMFuncOp ctorFunc = rewriter.create( + module.getLoc(), ctorName.getValue(), + LLVM::LLVMFunctionType::get(global_type, {}, false)); + + Location loc = ctorFunc.getLoc(); + Block *entryBlock = ctorFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToEnd(entryBlock); + + auto dispatch = op.clone(); + 
rewriter.insert(dispatch); + Value globalPtr = rewriter.create(loc, global); + rewriter.create(loc, dispatch.getResults(), globalPtr); + rewriter.create(loc, dispatch.getResults()); + + // initialize the gloabl with global_ctors, as the initializer of global + // does not allow side effect + rewriter.setInsertionPointToStart(module.getBody()); + LLVM::GlobalCtorsOp global_ctors = + module.lookupSymbol(getGlobalCtorsVarName()); + SmallVector ctorRefs; + SmallVector priorities; + if (global_ctors) { + auto ctorRefsAttr = global_ctors.getCtors(); + auto prioritiesAttr = global_ctors.getPriorities(); + for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) { + ctorRefs.push_back(ctor); + priorities.push_back(prior); + } + } + ctorRefs.push_back(ctorName); + // Set new ctor's priority to lowest + priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); + if (global_ctors) { + // If there's existing ctors + rewriter.replaceOpWithNewOp( + global_ctors, rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); + } else { + rewriter.create(module.getLoc(), + rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); + } + } + return global; +} + +class EarlyDispatchBrgemmRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + func::FuncOp func = op->template getParentOfType(); + + auto globalKernelName = createGlobalKernelHandleName(rewriter, op); + if (failed(globalKernelName)) { + return rewriter.notifyMatchFailure( + op, "Failed to create global kernel handle name"); + } + + // Generate kernel handle global name + auto globalKernel = + getOrCreateGlobalKernelHandle(rewriter, module, *globalKernelName, op); + if (failed(globalKernel)) { + return rewriter.notifyMatchFailure( + op, "Failed to 
create global kernel handle"); + } + + // Inject global val loading into start of function + auto funcBlock = &func.getBody().front(); + rewriter.setInsertionPointToStart(funcBlock); + Value globalPtr = rewriter.create(loc, *globalKernel); + Value globalVal = rewriter.create( + loc, op.getResults().getType(), globalPtr); + rewriter.moveOpAfter(op, funcBlock, funcBlock->begin()); + rewriter.replaceOp(op, globalVal); + return success(); + } +}; + +class EarlyDispatchMicrokernel + : public impl::EarlyDispatchMicrokernelBase { +public: + using impl::EarlyDispatchMicrokernelBase< + EarlyDispatchMicrokernel>::EarlyDispatchMicrokernelBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + // Ignore newly created Ops + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp new file mode 100644 index 000000000..8604e9c59 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -0,0 +1,437 @@ +//===- MicrokernelInvariantCodeMotion.cpp ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +#include + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_MICROKERNELINVARIANTCODEMOTION +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "microkernel-invariant-code-motion" + +enum BrgemmCallType { INAPPLICABLE = -1, DISPATCH, TILECFG, TILERELEASE }; + +static bool isParallelLoop(Operation *op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa(op); +} + +static bool isConcernedCF(Operation *op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa(op); +} + +static BrgemmCallType getBrgemmCallType(Operation *op) { + if (!llvm::isa(op)) { + return BrgemmCallType::INAPPLICABLE; + } + auto callOp = dyn_cast(op); + StringAttr callee = callOp.getCalleeAttr().getAttr(); + + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) { + return BrgemmCallType::DISPATCH; + } + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) { + return BrgemmCallType::TILECFG; + } + if (callee == + StringAttr::get(op->getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) { + return BrgemmCallType::TILERELEASE; + } + return BrgemmCallType::INAPPLICABLE; +} + +// Tree node of structure info tree, each node represents an Op +// This tree contains only concerned Ops +struct BrgemmContextStructInfo { + // Basic structure info 
retrieved by first walk + Operation *contextRoot; // Could be parallel loop or func + Operation *self, *parent; + DenseSet child; + SmallVector containBrgemmCallType; + // Rewrite-time info retrieved by analysing basic structure info + union { + Operation *maxInvariantScope; // Used by BrgemmCallOps for hoisting + bool hasTilereleased; // Used by other Ops as hoisting scopes to + // dedup tilerelease injection + }; + BrgemmContextStructInfo() { + contextRoot = nullptr; + self = nullptr; + parent = nullptr; + containBrgemmCallType = {false, false, false}; + maxInvariantScope = nullptr; + } +}; + +typedef DenseMap OpStructInfoMap; + +class BrgemmTilecfgRewriter : public OpRewritePattern { +private: + OpStructInfoMap &structInfo; + +public: + using OpRewritePattern::OpRewritePattern; + + BrgemmTilecfgRewriter(MLIRContext *context, OpStructInfoMap &si) + : OpRewritePattern(context), structInfo{si} {} + + LogicalResult matchAndRewrite(func::CallOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + + StringAttr callee = op.getCalleeAttr().getAttr(); + if (!module.lookupSymbol(callee)) + return rewriter.notifyMatchFailure(op, + "Invalid CallOp to unknown callee"); + if (callee != + StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILECFG_NAME)) + return rewriter.notifyMatchFailure(op, "Not call to BRGEMM tilecfg"); + auto opInfoIter = structInfo.find(op); + if (opInfoIter == structInfo.end()) { + return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); + } + auto &opStructInfo = opInfoIter->second; + + // Don't hoist if max invariant scope is itself to reduce + // unnecessary movement + if (opStructInfo.maxInvariantScope == op) { + return rewriter.notifyMatchFailure(op, "No need to hoist"); + } + rewriter.moveOpBefore(op, opStructInfo.maxInvariantScope); + // Avoid being hoisted again + opStructInfo.maxInvariantScope = op; + return success(); + } +}; + +static void markScopeAsReleased(OpStructInfoMap 
&structInfo, Operation *op) { + auto iter = structInfo.find(op); + assert(iter != structInfo.end()); + // Don't mark BrgemmCallOps + if (getBrgemmCallType(op) != BrgemmCallType::INAPPLICABLE) + return; + iter->second.hasTilereleased = true; + + for (auto ch : iter->second.child) { + markScopeAsReleased(structInfo, ch); + } +} + +class BrgemmTilereleaseRewriter : public OpRewritePattern { +private: + OpStructInfoMap &structInfo; + +public: + using OpRewritePattern::OpRewritePattern; + + BrgemmTilereleaseRewriter(MLIRContext *context, OpStructInfoMap &si) + : OpRewritePattern(context), structInfo{si} {} + LogicalResult matchAndRewrite(func::CallOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + + StringAttr callee = op.getCalleeAttr().getAttr(); + if (!module.lookupSymbol(callee)) + return rewriter.notifyMatchFailure(op, + "Invalid CallOp to unknown callee"); + if (callee != + StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) + return rewriter.notifyMatchFailure(op, "Not call to BRGEMM tilerelease"); + + auto opInfoIter = structInfo.find(op); + if (opInfoIter == structInfo.end()) { + return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); + } + auto &opStructInfo = opInfoIter->second; + auto targetInfoIter = structInfo.find(opStructInfo.maxInvariantScope); + assert(opStructInfo.maxInvariantScope); + // Don't hoist if max invariant scope is itself to reduce + // unnecessary movement + if (opStructInfo.maxInvariantScope == op) { + return rewriter.notifyMatchFailure(op, "No need to hoist"); + } + assert(targetInfoIter != structInfo.end()); + // move last tilerelease to end of contextRoot, and remove all + // others + if (targetInfoIter->second.hasTilereleased) { + rewriter.eraseOp(op); + } else { + // auto region = opStructInfo.maxInvariantScope->getRegion(0); + // auto block = ®ion.getBlocks().front(); + // auto enditer = block->end(); + // rewriter.moveOpBefore(op, block, 
enditer); + rewriter.moveOpAfter(op, opStructInfo.maxInvariantScope); + // Mark all sub scope as released to avoid duplicate Tilerelease + markScopeAsReleased(structInfo, opStructInfo.maxInvariantScope); + // Avoid being hoisted again + opStructInfo.maxInvariantScope = op; + } + return success(); + } +}; + +class MicrokernelInvariantCodeMotion + : public impl::MicrokernelInvariantCodeMotionBase< + MicrokernelInvariantCodeMotion> { +private: + // This helper create structInfo tree node along the path from input Op(as + // leaf) to contextRoot(FuncOp or any parallel Op) on demand; + // This tree only contains concerned Ops, including BrgemmCall Ops, parallel + // ops and related SCF Ops etc.; + // Input Op should be a BrgemmCall Op or the op + // depent by BrgemmTilecfg/BrgemmRelease + BrgemmContextStructInfo getOrCreateBrgemmContext(OpStructInfoMap &structInfo, + Operation *op) { + auto resIter = structInfo.find(op); + if (resIter != structInfo.end()) { + return resIter->second; + } + + SmallVector createdInfo; + Operation *contextRootForCreatedInfo = nullptr; + + auto doCreateStructInfo = [&](Operation *child, Operation *op) { + BrgemmContextStructInfo info; + info.self = op; + if (child) { + auto iter = structInfo.find(child); + assert(iter != structInfo.end()); + iter->second.parent = op; + info.child.insert(child); + } + structInfo.insert(std::make_pair(op, std::move(info))); + auto iter = structInfo.find(op); + createdInfo.push_back(&iter->second); + return &iter->second; + }; + // Create info for input Op as leaf + auto brgemmInfo = doCreateStructInfo(nullptr, op); + auto callType = getBrgemmCallType(op); + if (callType != BrgemmCallType::INAPPLICABLE) { + brgemmInfo->containBrgemmCallType[callType] = true; + } + + auto last = op; + auto current = op->getParentOp(); + // Traverse up the IR tree, creating structInfo for each concerned Op + while (current) { + bool isParaLoop = isParallelLoop(current); + bool isCCF = isConcernedCF(current); + if 
(!llvm::isa(current) && !isParaLoop && !isCCF) { + // Only care about selected Ops + current = current->getParentOp(); + continue; + } + + auto iter = structInfo.find(current); + if (iter != structInfo.end()) { + // StructInfo exists for current Op, then we don't need to create info + // anymore as all ancestors have been created + // But we still need to propagate containBrgemmCallType if we are + // dealing with BrgemmCall Ops + if (last) { + auto lastIter = structInfo.find(last); + assert(lastIter != structInfo.end()); + lastIter->second.parent = current; + iter->second.child.insert(last); + // Invalidate last as we don't create new info anymore + last = nullptr; + } + if (callType != BrgemmCallType::INAPPLICABLE) { + // Propagate containCallType if needed + iter->second.containBrgemmCallType[callType] = true; + } else + break; + } else { + // StructInfo not exist, then create one for current Op and keep + // Traversing up + auto created = doCreateStructInfo(last, current); + if (callType != BrgemmCallType::INAPPLICABLE) { + created->containBrgemmCallType[callType] = true; + } + last = current; + } + if (llvm::isa(current) || isParaLoop) { + // Encounter `contextRoot`, then record and terminate traversing + contextRootForCreatedInfo = current; + break; + } + current = current->getParentOp(); + } + + // Assign `contextRoot` for newly created structInfo + if (contextRootForCreatedInfo) { + for (auto info : createdInfo) + info->contextRoot = contextRootForCreatedInfo; + } + + resIter = structInfo.find(op); + assert(resIter != structInfo.end()); + return resIter->second; + } + + // This helper expand invariant scope according to two function: + // 1. controlFlowAllow: Whether we can hoist the BrgemmCallOp out of the scope + // of current Op; For example, we won't move TILECFG out of an IfOp as it + // contains underministic control flow. + // 2. 
peerAllow: Whether we can hoist the BrgemmCallOp out of the scope of + // current Op without violating other peer BrgemmCallOp in the same level; For + // example, one scf.ForOp contains two TILECFG in the same level, then we + // cannot hoist any of them. + void expandInvariantScopeWithCond( + OpStructInfoMap &structInfo, Operation *op, + std::function controlFlowAllow, + std::function &)> + peerAllow) { + auto opIter = structInfo.find(op); + assert(opIter != structInfo.end()); + auto contextRoot = opIter->second.contextRoot; + auto current = op; + auto currIter = opIter; + auto parent = opIter->second.parent; + while (parent != contextRoot) { + auto parentIter = structInfo.find(parent); + assert(parentIter != structInfo.end()); + // Verify whether we can expand the scope to direct parent + bool isControlFlowAllow = controlFlowAllow(parent); + bool isPeerAllow = + peerAllow(op, structInfo, current, parentIter->second.child); + if (!isControlFlowAllow || !isPeerAllow) { + break; + } + current = parent; + currIter = parentIter; + parent = parentIter->second.parent; + } + + opIter->second.maxInvariantScope = current; + } + + void expandInvariantScope(OpStructInfoMap &structInfo, Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + assert(brgemmCallType == BrgemmCallType::TILECFG || + brgemmCallType == BrgemmCallType::TILERELEASE); + + if (brgemmCallType == BrgemmCallType::TILECFG) { + expandInvariantScopeWithCond( + structInfo, op, + [](Operation *op) -> bool { + return !llvm::isa(op) && + !llvm::isa(op); + }, + [](Operation *self, const OpStructInfoMap &structInfo, + Operation *current, const DenseSet &peers) -> bool { + for (auto peer : peers) { + if (peer == current) + continue; + if (peer == self->getOperand(0).getDefiningOp()) { + // Don't break operand domination + return false; + } + const auto iter = structInfo.find(peer); + assert(iter != structInfo.end()); + const auto &containType = iter->second.containBrgemmCallType; + if 
(containType[BrgemmCallType::DISPATCH] || + containType[BrgemmCallType::TILECFG]) { + return false; + } + } + return true; + }); + } else { // brgemmCallType == BrgemmCallType::TILERELEASE + expandInvariantScopeWithCond( + structInfo, op, + [](Operation *op) -> bool { + return !llvm::isa(op) && + !llvm::isa(op); + }, + [](Operation *self, const OpStructInfoMap &structInfo, + Operation *current, + const DenseSet &peers) -> bool { return true; }); + } + } + + LogicalResult collectBrgemmContextStructInfo(OpStructInfoMap &structInfo) { + // First walk to collect basic structure + getOperation()->walk( + [this, &structInfo](Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + if (brgemmCallType == BrgemmCallType::INAPPLICABLE) { + return; + } + + // Construct the structInfo tree lazily upon encountering BrgemmCall + // Op + auto info = getOrCreateBrgemmContext(structInfo, op); + structInfo.insert(std::make_pair(op, std::move(info))); + if (brgemmCallType == BrgemmCallType::TILECFG) { + // Also contruct tree node for the input of BrgemmTilecfg for + // dependency check in `expandInvariantScope` + auto dependOp = op->getOperand(0).getDefiningOp(); + auto dependInfo = getOrCreateBrgemmContext(structInfo, dependOp); + structInfo.insert(std::make_pair(dependOp, std::move(dependInfo))); + } + }); + + // Second walk to analyse hoist related info + getOperation()->walk( + [this, &structInfo](Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + if (brgemmCallType != BrgemmCallType::TILECFG && + brgemmCallType != BrgemmCallType::TILERELEASE) { + return; + } + + // find the maximal invariant scope for hoisting + expandInvariantScope(structInfo, op); + }); + + return success(); + } + +public: + using impl::MicrokernelInvariantCodeMotionBase< + MicrokernelInvariantCodeMotion>::MicrokernelInvariantCodeMotionBase; + void runOnOperation() final { + OpStructInfoMap structInfo; + + if (failed(collectBrgemmContextStructInfo(structInfo))) { + 
signalPassFailure(); + } + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), structInfo); + patterns.add(&getContext(), structInfo); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + // Ignore newly created Ops + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::microkernel diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 5aa0bfc1d..827ba7202 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -21,6 +21,7 @@ #include "gc/Dialect/Linalgx/LinalgxDialect.h" #include "gc/Dialect/Microkernel/MicrokernelDialect.h" #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -49,6 +50,8 @@ int main(int argc, char *argv[]) { mlir::gc::registerCPUPipeline(); mlir::gc::registerGraphCompilerPasses(); mlir::cpuruntime::registerCPURuntimePasses(); + mlir::microkernel::registerMicrokernelPasses(); + mlir::DialectRegistry registry; registry.insert(); registry.insert(); diff --git a/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir new file mode 100644 index 000000000..041c238dd --- /dev/null +++ b/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -0,0 +1,44 @@ +// RUN: gc-opt %s -convert-linalg-to-microkernel -split-input-file | FileCheck %s + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @simple_brgemm() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() 
{alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> 
+// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- diff --git a/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir new file mode 100644 index 000000000..bbd018b85 --- /dev/null +++ b/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -0,0 +1,70 @@ +// RUN: gc-opt %s -convert-microkernel-to-dnnl-func -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 
1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } +} + +// CHECK-LABEL: dnnl_brgemm_execute +// CHECK-LABEL: dnnl_brgemm_dispatch +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST0:.+]] = arith.constant 0 : index +// CHECK: %[[CST3:.+]] = arith.constant 3 : i64 +// CHECK: %[[CST1F:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[CST1024:.+]] 
= arith.constant 1024 : i64 +// CHECK: %[[CST32:.+]] = arith.constant 32 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 + +// CHECK: %[[KERNEL:.+]] = func.call @dnnl_brgemm_dispatch(%[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST1024]], %[[CST1024]], %[[CST1F]], %[[CST3]], %[[CST3]]) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 +// CHECK-NOT: microkernel.brgemm.prologue(%[[TMP:.+]]) : (i64) -> () + +// CHECK: %[[bbA:.+]], %[[offA:.+]], %[[szA:.+]]:3, %[[strdA:.+]]:3 = memref.extract_strided_metadata %[[memrefA:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index +// CHECK-NEXT: %[[ptrA:.+]] = memref.extract_aligned_pointer_as_index %[[memrefA]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index +// CHECK-NEXT: %[[idxA:.+]] = arith.index_cast %[[ptrA]] : index to i64 +// CHECK-NEXT: %[[llvmptrA:.+]] = llvm.inttoptr %[[idxA]] : i64 to !llvm.ptr + +// CHECK: %[[bbB:.+]], %[[offB:.+]], %[[szB:.+]]:3, %[[strdB:.+]]:3 = memref.extract_strided_metadata %[[memrefB:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index +// CHECK-NEXT: %[[ptrB:.+]] = memref.extract_aligned_pointer_as_index %[[memrefB]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index +// CHECK-NEXT: %[[idxB:.+]] = arith.index_cast %[[ptrB]] : index to i64 +// CHECK-NEXT: %[[llvmptrB:.+]] = llvm.inttoptr %[[idxB]] : i64 to !llvm.ptr + +// CHECK: %[[ptrC:.+]] = memref.extract_aligned_pointer_as_index %[[memrefC:.+]] : memref<32x32xf32> -> index +// CHECK-NEXT: %[[idxC:.+]] = arith.index_cast %[[ptrC]] : index to i64 +// CHECK-NEXT: %[[llvmptrC:.+]] = llvm.inttoptr %[[idxC]] : i64 to !llvm.ptr + +// CHECK: func.call @dnnl_brgemm_execute(%[[KERNEL]], %[[llvmptrA]], %[[offA]], %[[llvmptrB]], %[[offB]], %[[llvmptrC]], %[[CST0]], %[[CST16]]) : (i64, !llvm.ptr, index, !llvm.ptr, index, 
!llvm.ptr, index, i64) -> () +// CHECK-NOT: microkernel.brgemm.epilogue(%[[KERNEL]]) : (i64) -> () + +// ----- diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir new file mode 100644 index 000000000..ad436da0c --- /dev/null +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -0,0 +1,50 @@ +// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 
1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + // COM: parallelcpu.printf "BRGEMM DONE\n" + return + } + + // COM: CHECK: BRGEMM DONE +} From 6a1260a35009d190ab41db7c5ca7bbedf7268486 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 29 May 2024 23:35:49 -0700 Subject: [PATCH 10/93] remove irrelavant --- .../Microkernel/MicrokernelPasses.td | 22 - lib/gc/Transforms/Microkernel/CMakeLists.txt | 2 - .../Microkernel/EarlyDispatchMicrokernel.cpp | 194 -------- .../MicrokernelInvariantCodeMotion.cpp | 437 
------------------ 4 files changed, 655 deletions(-) delete mode 100644 lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp delete mode 100644 lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index c015dddf3..6e483656f 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -50,26 +50,4 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml "microkernel::MicrokernelDialect"]; } -def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> { - let summary = "Early dispatch microkernel during compile time"; - let description = [{ - Early dispatch microkernel during compile time. - }]; - let dependentDialects = ["func::FuncDialect", - "memref::MemRefDialect", - "LLVM::LLVMDialect", - "microkernel::MicrokernelDialect"]; -} - -def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> { - let summary = "Hoist invariant microkernel code to avoid redundant execution"; - let description = [{ - Hoist invariant microkernel code to avoid redundant execution. 
- }]; - let dependentDialects = ["func::FuncDialect", - "memref::MemRefDialect", - "LLVM::LLVMDialect", - "microkernel::MicrokernelDialect"]; -} - #endif // GC_DIALECT_MICROKERNELPASSES diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index c2d88e33c..462c5d697 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -9,8 +9,6 @@ include_directories(${DNNL_INCLUDES}) add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp ConvertMicrokernelToDnnlFunc.cpp - EarlyDispatchMicrokernel.cpp - MicrokernelInvariantCodeMotion.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp deleted file mode 100644 index a3843e01d..000000000 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ /dev/null @@ -1,194 +0,0 @@ -//===- EarlyDispatchMicrokernel.cpp ----------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===---------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include - -#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" -#include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/ValueUtils.h" -#include "oneapi/dnnl/dnnl_types.h" - -namespace mlir::microkernel { -#define GEN_PASS_DEF_EARLYDISPATCHMICROKERNEL -#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" - -#define DEBUG_TYPE "early-dispatch-microkernel" - -static constexpr StringRef getGlobalCtorsVarName() { - return "llvm.global_ctors"; -} - -static FailureOr -createGlobalKernelHandleName(RewriterBase &rewriter, - microkernel::BrgemmDispatchOp op) { - // TODO(haixin): Add runtime backend type to global name - std::stringstream ss; - ss << "g_dispatched_microkernel_brgemm"; - - auto flags = op.getFlagsAttr(); - for (auto flag : flags) { - auto brgemmFlag = dyn_cast_or_null(flag); - if (!brgemmFlag) - return failure("unknown flag for BRGEMM"); - if (brgemmFlag.getValue() == BrgemmFlags::LIST) - return failure("addr mode BRGEMM not supported yet"); - if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) - ss << "_init"; - } - - // M, N, K, LDA, LDB, LDC, stride_a, stride_b - // they are in the same order with BrgemmDispatchOp inputs - ArrayRef inputs = op.getInputsAttr().asArrayRef(); - for (auto input : inputs) { - ss << "_" << input; - } - - // dtypeA, dtypeB - auto dtypes = op.getDataType(); - if (dtypes.size() != 2) - return failure("invalid number of DataType for BRGEMM"); - ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]); - ss << "_" << getDnnlDataTypeVal(rewriter, 
dtypes[1]); - - return ss.str(); -} - -// get or create global kernel handle with initializer, identified by -// `kernelName` -static FailureOr -getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, - const std::string &kernelName, - microkernel::BrgemmDispatchOp op) { - // Create the global at the entry of the module - LLVM::GlobalOp global; - if (!(global = module.lookupSymbol(kernelName))) { - auto global_type = op.getResults().getType(); - FlatSymbolRefAttr ctorName = - SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); - if (module.lookupSymbol(ctorName.getAttr())) { - return failure("Existing ctor for new global kernel handle"); - } - - OpBuilder::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - global = rewriter.create( - module.getLoc(), global_type, /*isConstant=*/false, - LLVM::Linkage::Internal, kernelName, Attribute(), - /*alignment=*/0); - - // create ctor for this global, which needs to be LLVMFuncOp - LLVM::LLVMFuncOp ctorFunc = rewriter.create( - module.getLoc(), ctorName.getValue(), - LLVM::LLVMFunctionType::get(global_type, {}, false)); - - Location loc = ctorFunc.getLoc(); - Block *entryBlock = ctorFunc.addEntryBlock(rewriter); - rewriter.setInsertionPointToEnd(entryBlock); - - auto dispatch = op.clone(); - rewriter.insert(dispatch); - Value globalPtr = rewriter.create(loc, global); - rewriter.create(loc, dispatch.getResults(), globalPtr); - rewriter.create(loc, dispatch.getResults()); - - // initialize the gloabl with global_ctors, as the initializer of global - // does not allow side effect - rewriter.setInsertionPointToStart(module.getBody()); - LLVM::GlobalCtorsOp global_ctors = - module.lookupSymbol(getGlobalCtorsVarName()); - SmallVector ctorRefs; - SmallVector priorities; - if (global_ctors) { - auto ctorRefsAttr = global_ctors.getCtors(); - auto prioritiesAttr = global_ctors.getPriorities(); - for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) { 
- ctorRefs.push_back(ctor); - priorities.push_back(prior); - } - } - ctorRefs.push_back(ctorName); - // Set new ctor's priority to lowest - priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); - if (global_ctors) { - // If there's existing ctors - rewriter.replaceOpWithNewOp( - global_ctors, rewriter.getArrayAttr(ctorRefs), - rewriter.getArrayAttr(priorities)); - } else { - rewriter.create(module.getLoc(), - rewriter.getArrayAttr(ctorRefs), - rewriter.getArrayAttr(priorities)); - } - } - return global; -} - -class EarlyDispatchBrgemmRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); - ModuleOp module = op->template getParentOfType(); - func::FuncOp func = op->template getParentOfType(); - - auto globalKernelName = createGlobalKernelHandleName(rewriter, op); - if (failed(globalKernelName)) { - return rewriter.notifyMatchFailure( - op, "Failed to create global kernel handle name"); - } - - // Generate kernel handle global name - auto globalKernel = - getOrCreateGlobalKernelHandle(rewriter, module, *globalKernelName, op); - if (failed(globalKernel)) { - return rewriter.notifyMatchFailure( - op, "Failed to create global kernel handle"); - } - - // Inject global val loading into start of function - auto funcBlock = &func.getBody().front(); - rewriter.setInsertionPointToStart(funcBlock); - Value globalPtr = rewriter.create(loc, *globalKernel); - Value globalVal = rewriter.create( - loc, op.getResults().getType(), globalPtr); - rewriter.moveOpAfter(op, funcBlock, funcBlock->begin()); - rewriter.replaceOp(op, globalVal); - return success(); - } -}; - -class EarlyDispatchMicrokernel - : public impl::EarlyDispatchMicrokernelBase { -public: - using impl::EarlyDispatchMicrokernelBase< - EarlyDispatchMicrokernel>::EarlyDispatchMicrokernelBase; - void runOnOperation() final { - 
RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - - // Ignore newly created Ops - GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - if (failed( - applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) - signalPassFailure(); - } -}; - -} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp deleted file mode 100644 index 8604e9c59..000000000 --- a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp +++ /dev/null @@ -1,437 +0,0 @@ -//===- MicrokernelInvariantCodeMotion.cpp ----------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===---------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include -#include - -#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" -#include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/ValueUtils.h" -#include "oneapi/dnnl/dnnl_types.h" - -namespace mlir::microkernel { -#define GEN_PASS_DEF_MICROKERNELINVARIANTCODEMOTION -#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" - -#define DEBUG_TYPE "microkernel-invariant-code-motion" - -enum BrgemmCallType { INAPPLICABLE = -1, DISPATCH, TILECFG, TILERELEASE }; - -static bool isParallelLoop(Operation *op) { - return llvm::isa(op) || llvm::isa(op) || - llvm::isa(op) || 
llvm::isa(op); -} - -static bool isConcernedCF(Operation *op) { - return llvm::isa(op) || llvm::isa(op) || - llvm::isa(op) || llvm::isa(op); -} - -static BrgemmCallType getBrgemmCallType(Operation *op) { - if (!llvm::isa(op)) { - return BrgemmCallType::INAPPLICABLE; - } - auto callOp = dyn_cast(op); - StringAttr callee = callOp.getCalleeAttr().getAttr(); - - if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) { - return BrgemmCallType::DISPATCH; - } - if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) { - return BrgemmCallType::TILECFG; - } - if (callee == - StringAttr::get(op->getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) { - return BrgemmCallType::TILERELEASE; - } - return BrgemmCallType::INAPPLICABLE; -} - -// Tree node of structure info tree, each node represents an Op -// This tree contains only concerned Ops -struct BrgemmContextStructInfo { - // Basic structure info retrieved by first walk - Operation *contextRoot; // Could be parallel loop or func - Operation *self, *parent; - DenseSet child; - SmallVector containBrgemmCallType; - // Rewrite-time info retrieved by analysing basic structure info - union { - Operation *maxInvariantScope; // Used by BrgemmCallOps for hoisting - bool hasTilereleased; // Used by other Ops as hoisting scopes to - // dedup tilerelease injection - }; - BrgemmContextStructInfo() { - contextRoot = nullptr; - self = nullptr; - parent = nullptr; - containBrgemmCallType = {false, false, false}; - maxInvariantScope = nullptr; - } -}; - -typedef DenseMap OpStructInfoMap; - -class BrgemmTilecfgRewriter : public OpRewritePattern { -private: - OpStructInfoMap &structInfo; - -public: - using OpRewritePattern::OpRewritePattern; - - BrgemmTilecfgRewriter(MLIRContext *context, OpStructInfoMap &si) - : OpRewritePattern(context), structInfo{si} {} - - LogicalResult matchAndRewrite(func::CallOp op, - PatternRewriter &rewriter) const final { - ModuleOp module = op->template getParentOfType(); - - 
StringAttr callee = op.getCalleeAttr().getAttr(); - if (!module.lookupSymbol(callee)) - return rewriter.notifyMatchFailure(op, - "Invalid CallOp to unknown callee"); - if (callee != - StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILECFG_NAME)) - return rewriter.notifyMatchFailure(op, "Not call to BRGEMM tilecfg"); - auto opInfoIter = structInfo.find(op); - if (opInfoIter == structInfo.end()) { - return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); - } - auto &opStructInfo = opInfoIter->second; - - // Don't hoist if max invariant scope is itself to reduce - // unnecessary movement - if (opStructInfo.maxInvariantScope == op) { - return rewriter.notifyMatchFailure(op, "No need to hoist"); - } - rewriter.moveOpBefore(op, opStructInfo.maxInvariantScope); - // Avoid being hoisted again - opStructInfo.maxInvariantScope = op; - return success(); - } -}; - -static void markScopeAsReleased(OpStructInfoMap &structInfo, Operation *op) { - auto iter = structInfo.find(op); - assert(iter != structInfo.end()); - // Don't mark BrgemmCallOps - if (getBrgemmCallType(op) != BrgemmCallType::INAPPLICABLE) - return; - iter->second.hasTilereleased = true; - - for (auto ch : iter->second.child) { - markScopeAsReleased(structInfo, ch); - } -} - -class BrgemmTilereleaseRewriter : public OpRewritePattern { -private: - OpStructInfoMap &structInfo; - -public: - using OpRewritePattern::OpRewritePattern; - - BrgemmTilereleaseRewriter(MLIRContext *context, OpStructInfoMap &si) - : OpRewritePattern(context), structInfo{si} {} - LogicalResult matchAndRewrite(func::CallOp op, - PatternRewriter &rewriter) const final { - ModuleOp module = op->template getParentOfType(); - - StringAttr callee = op.getCalleeAttr().getAttr(); - if (!module.lookupSymbol(callee)) - return rewriter.notifyMatchFailure(op, - "Invalid CallOp to unknown callee"); - if (callee != - StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) - return rewriter.notifyMatchFailure(op, "Not call to 
BRGEMM tilerelease"); - - auto opInfoIter = structInfo.find(op); - if (opInfoIter == structInfo.end()) { - return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); - } - auto &opStructInfo = opInfoIter->second; - auto targetInfoIter = structInfo.find(opStructInfo.maxInvariantScope); - assert(opStructInfo.maxInvariantScope); - // Don't hoist if max invariant scope is itself to reduce - // unnecessary movement - if (opStructInfo.maxInvariantScope == op) { - return rewriter.notifyMatchFailure(op, "No need to hoist"); - } - assert(targetInfoIter != structInfo.end()); - // move last tilerelease to end of contextRoot, and remove all - // others - if (targetInfoIter->second.hasTilereleased) { - rewriter.eraseOp(op); - } else { - // auto region = opStructInfo.maxInvariantScope->getRegion(0); - // auto block = ®ion.getBlocks().front(); - // auto enditer = block->end(); - // rewriter.moveOpBefore(op, block, enditer); - rewriter.moveOpAfter(op, opStructInfo.maxInvariantScope); - // Mark all sub scope as released to avoid duplicate Tilerelease - markScopeAsReleased(structInfo, opStructInfo.maxInvariantScope); - // Avoid being hoisted again - opStructInfo.maxInvariantScope = op; - } - return success(); - } -}; - -class MicrokernelInvariantCodeMotion - : public impl::MicrokernelInvariantCodeMotionBase< - MicrokernelInvariantCodeMotion> { -private: - // This helper create structInfo tree node along the path from input Op(as - // leaf) to contextRoot(FuncOp or any parallel Op) on demand; - // This tree only contains concerned Ops, including BrgemmCall Ops, parallel - // ops and related SCF Ops etc.; - // Input Op should be a BrgemmCall Op or the op - // depent by BrgemmTilecfg/BrgemmRelease - BrgemmContextStructInfo getOrCreateBrgemmContext(OpStructInfoMap &structInfo, - Operation *op) { - auto resIter = structInfo.find(op); - if (resIter != structInfo.end()) { - return resIter->second; - } - - SmallVector createdInfo; - Operation *contextRootForCreatedInfo = 
nullptr; - - auto doCreateStructInfo = [&](Operation *child, Operation *op) { - BrgemmContextStructInfo info; - info.self = op; - if (child) { - auto iter = structInfo.find(child); - assert(iter != structInfo.end()); - iter->second.parent = op; - info.child.insert(child); - } - structInfo.insert(std::make_pair(op, std::move(info))); - auto iter = structInfo.find(op); - createdInfo.push_back(&iter->second); - return &iter->second; - }; - // Create info for input Op as leaf - auto brgemmInfo = doCreateStructInfo(nullptr, op); - auto callType = getBrgemmCallType(op); - if (callType != BrgemmCallType::INAPPLICABLE) { - brgemmInfo->containBrgemmCallType[callType] = true; - } - - auto last = op; - auto current = op->getParentOp(); - // Traverse up the IR tree, creating structInfo for each concerned Op - while (current) { - bool isParaLoop = isParallelLoop(current); - bool isCCF = isConcernedCF(current); - if (!llvm::isa(current) && !isParaLoop && !isCCF) { - // Only care about selected Ops - current = current->getParentOp(); - continue; - } - - auto iter = structInfo.find(current); - if (iter != structInfo.end()) { - // StructInfo exists for current Op, then we don't need to create info - // anymore as all ancestors have been created - // But we still need to propagate containBrgemmCallType if we are - // dealing with BrgemmCall Ops - if (last) { - auto lastIter = structInfo.find(last); - assert(lastIter != structInfo.end()); - lastIter->second.parent = current; - iter->second.child.insert(last); - // Invalidate last as we don't create new info anymore - last = nullptr; - } - if (callType != BrgemmCallType::INAPPLICABLE) { - // Propagate containCallType if needed - iter->second.containBrgemmCallType[callType] = true; - } else - break; - } else { - // StructInfo not exist, then create one for current Op and keep - // Traversing up - auto created = doCreateStructInfo(last, current); - if (callType != BrgemmCallType::INAPPLICABLE) { - 
created->containBrgemmCallType[callType] = true; - } - last = current; - } - if (llvm::isa(current) || isParaLoop) { - // Encounter `contextRoot`, then record and terminate traversing - contextRootForCreatedInfo = current; - break; - } - current = current->getParentOp(); - } - - // Assign `contextRoot` for newly created structInfo - if (contextRootForCreatedInfo) { - for (auto info : createdInfo) - info->contextRoot = contextRootForCreatedInfo; - } - - resIter = structInfo.find(op); - assert(resIter != structInfo.end()); - return resIter->second; - } - - // This helper expand invariant scope according to two function: - // 1. controlFlowAllow: Whether we can hoist the BrgemmCallOp out of the scope - // of current Op; For example, we won't move TILECFG out of an IfOp as it - // contains underministic control flow. - // 2. peerAllow: Whether we can hoist the BrgemmCallOp out of the scope of - // current Op without violating other peer BrgemmCallOp in the same level; For - // example, one scf.ForOp contains two TILECFG in the same level, then we - // cannot hoist any of them. 
- void expandInvariantScopeWithCond( - OpStructInfoMap &structInfo, Operation *op, - std::function controlFlowAllow, - std::function &)> - peerAllow) { - auto opIter = structInfo.find(op); - assert(opIter != structInfo.end()); - auto contextRoot = opIter->second.contextRoot; - auto current = op; - auto currIter = opIter; - auto parent = opIter->second.parent; - while (parent != contextRoot) { - auto parentIter = structInfo.find(parent); - assert(parentIter != structInfo.end()); - // Verify whether we can expand the scope to direct parent - bool isControlFlowAllow = controlFlowAllow(parent); - bool isPeerAllow = - peerAllow(op, structInfo, current, parentIter->second.child); - if (!isControlFlowAllow || !isPeerAllow) { - break; - } - current = parent; - currIter = parentIter; - parent = parentIter->second.parent; - } - - opIter->second.maxInvariantScope = current; - } - - void expandInvariantScope(OpStructInfoMap &structInfo, Operation *op) { - BrgemmCallType brgemmCallType = getBrgemmCallType(op); - assert(brgemmCallType == BrgemmCallType::TILECFG || - brgemmCallType == BrgemmCallType::TILERELEASE); - - if (brgemmCallType == BrgemmCallType::TILECFG) { - expandInvariantScopeWithCond( - structInfo, op, - [](Operation *op) -> bool { - return !llvm::isa(op) && - !llvm::isa(op); - }, - [](Operation *self, const OpStructInfoMap &structInfo, - Operation *current, const DenseSet &peers) -> bool { - for (auto peer : peers) { - if (peer == current) - continue; - if (peer == self->getOperand(0).getDefiningOp()) { - // Don't break operand domination - return false; - } - const auto iter = structInfo.find(peer); - assert(iter != structInfo.end()); - const auto &containType = iter->second.containBrgemmCallType; - if (containType[BrgemmCallType::DISPATCH] || - containType[BrgemmCallType::TILECFG]) { - return false; - } - } - return true; - }); - } else { // brgemmCallType == BrgemmCallType::TILERELEASE - expandInvariantScopeWithCond( - structInfo, op, - [](Operation *op) -> bool 
{ - return !llvm::isa(op) && - !llvm::isa(op); - }, - [](Operation *self, const OpStructInfoMap &structInfo, - Operation *current, - const DenseSet &peers) -> bool { return true; }); - } - } - - LogicalResult collectBrgemmContextStructInfo(OpStructInfoMap &structInfo) { - // First walk to collect basic structure - getOperation()->walk( - [this, &structInfo](Operation *op) { - BrgemmCallType brgemmCallType = getBrgemmCallType(op); - if (brgemmCallType == BrgemmCallType::INAPPLICABLE) { - return; - } - - // Construct the structInfo tree lazily upon encountering BrgemmCall - // Op - auto info = getOrCreateBrgemmContext(structInfo, op); - structInfo.insert(std::make_pair(op, std::move(info))); - if (brgemmCallType == BrgemmCallType::TILECFG) { - // Also contruct tree node for the input of BrgemmTilecfg for - // dependency check in `expandInvariantScope` - auto dependOp = op->getOperand(0).getDefiningOp(); - auto dependInfo = getOrCreateBrgemmContext(structInfo, dependOp); - structInfo.insert(std::make_pair(dependOp, std::move(dependInfo))); - } - }); - - // Second walk to analyse hoist related info - getOperation()->walk( - [this, &structInfo](Operation *op) { - BrgemmCallType brgemmCallType = getBrgemmCallType(op); - if (brgemmCallType != BrgemmCallType::TILECFG && - brgemmCallType != BrgemmCallType::TILERELEASE) { - return; - } - - // find the maximal invariant scope for hoisting - expandInvariantScope(structInfo, op); - }); - - return success(); - } - -public: - using impl::MicrokernelInvariantCodeMotionBase< - MicrokernelInvariantCodeMotion>::MicrokernelInvariantCodeMotionBase; - void runOnOperation() final { - OpStructInfoMap structInfo; - - if (failed(collectBrgemmContextStructInfo(structInfo))) { - signalPassFailure(); - } - - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext(), structInfo); - patterns.add(&getContext(), structInfo); - FrozenRewritePatternSet patternSet(std::move(patterns)); - - // Ignore newly created Ops - 
GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - if (failed( - applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) { - signalPassFailure(); - } - } -}; - -} // namespace mlir::microkernel From 1850d60a56ae5c74813eb9a36f8e2cc1a3717126 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 30 May 2024 00:15:18 -0700 Subject: [PATCH 11/93] refine cmake --- include/gc/Transforms/CMakeLists.txt | 2 ++ lib/gc/Dialect/Microkernel/CMakeLists.txt | 1 + 2 files changed, 3 insertions(+) diff --git a/include/gc/Transforms/CMakeLists.txt b/include/gc/Transforms/CMakeLists.txt index fdc68e6a7..501283ab1 100644 --- a/include/gc/Transforms/CMakeLists.txt +++ b/include/gc/Transforms/CMakeLists.txt @@ -4,3 +4,5 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GraphCompiler) mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GraphCompiler) add_public_tablegen_target(GraphCompilerPassIncGen) add_mlir_doc(Passes GraphCompilerPasses ./ -gen-pass-doc) + +add_subdirectory(Microkernel) diff --git a/lib/gc/Dialect/Microkernel/CMakeLists.txt b/lib/gc/Dialect/Microkernel/CMakeLists.txt index 029f00cce..0a1aafa6e 100644 --- a/lib/gc/Dialect/Microkernel/CMakeLists.txt +++ b/lib/gc/Dialect/Microkernel/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRMicrokernel DEPENDS MLIRMicrokernelOpsIncGen + MLIRMicrokernelPassesIncGen LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} From 5bc44e40fbdd1630468603800eb3d10fd4c856e7 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 3 Jun 2024 23:49:03 -0700 Subject: [PATCH 12/93] fix brgemm runtime --- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index e5e84e24d..ff4aee172 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ 
b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -52,7 +52,6 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { - std::cout << ">>> Brgemm dispatch: " << std::endl; brgemm_desc_list.emplace_back(brgemm_desc_t()); brgemm_kernel_list.emplace_back(nullptr); @@ -79,7 +78,6 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, void dnnl_brgemm_tileconfig(int64_t kernel_idx) { assert(kernel_idx >= 0 && kernel_idx < (int64_t)brgemm_desc_list.size() && "Invalid kernel handler"); - std::cout << ">>> Brgemm tileconfig: " << kernel_idx << std::endl; brgemm_desc_t &desc = brgemm_desc_list[kernel_idx]; if (!desc.is_tmm) { @@ -98,7 +96,6 @@ void dnnl_brgemm_tilerelease() { if (!mayiuse(avx512_core_amx)) { return; } - std::cout << ">>> Brgemm tilerelease" << std::endl; amx_tile_release(); } @@ -109,16 +106,15 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, assert(kernel_idx >= 0 && kernel_idx < (int64_t)brgemm_desc_list.size() && "Invalid kernel handler"); - std::cout << ">>> Brgemm Execute: " << kernel_idx << std::endl; brgemm_desc_t &desc = brgemm_desc_list[kernel_idx]; brgemm_kernel_t *kernel = brgemm_kernel_list[kernel_idx]; size_t A_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; size_t B_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_b) * A_offset; + dnnl::impl::types::data_type_size(desc.dt_b) * B_offset; size_t C_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_c) * A_offset; + dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; #ifdef _WIN32 // fix-me: (win32) impl @@ -128,9 +124,12 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, #endif // TODO(haixin): use thread local buffer for scratch char *scratch = new char[scratch_size]; - brgemm_kernel_execute(kernel, num, A + 
A_offset_in_bytes, - B + B_offset_in_bytes, nullptr, C + C_offset_in_bytes, - (void *)scratch); + char *A_arith = (char *)A; + char *B_arith = (char *)B; + char *C_arith = (char *)C; + brgemm_kernel_execute(kernel, num, (void *)(A_arith + A_offset_in_bytes), + (void *)(B_arith + B_offset_in_bytes), nullptr, + (void *)(C_arith + C_offset_in_bytes), (void *)scratch); delete scratch; } } From 1c69ee65cc198221b4566061549870907cf95f50 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 4 Jun 2024 21:18:06 -0700 Subject: [PATCH 13/93] support linalgx::batch_reduce_matmul_vnni --- .../Microkernel/MicrokernelPasses.h | 1 + .../Microkernel/MicrokernelPasses.td | 1 + .../ConvertLinalgToMicrokernel.cpp | 197 +++++++++++++----- .../Microkernel/linalg-to-microkernel.mlir | 43 ++++ 4 files changed, 192 insertions(+), 50 deletions(-) diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.h b/include/gc/Transforms/Microkernel/MicrokernelPasses.h index a053253e6..ee9da8a4e 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.h +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.h @@ -9,6 +9,7 @@ #ifndef GC_MICROKERNELPASSES_H #define GC_MICROKERNELPASSES_H +#include "gc/Dialect/Linalgx/LinalgxDialect.h" #include "gc/Dialect/Microkernel/MicrokernelDialect.h" #include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "mlir/Pass/Pass.h" diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 6e483656f..59726ac0a 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -36,6 +36,7 @@ def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::f let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect", "linalg::LinalgDialect", + "linalgx::LinalgxDialect", "microkernel::MicrokernelDialect"]; } diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp 
b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 4e46f35ac..0f426aeeb 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -6,6 +6,12 @@ // //===---------------------------------------------------------------------===// +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -14,6 +20,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Utils/StructuredOpMatcher.h" #include "gc/Utils/ValueUtils.h" @@ -42,12 +49,68 @@ struct BrgemmInfo { BrgemmMode mode; }; +FailureOr +customInferContractionDims(linalg::LinalgOp linalgOp) { + auto dims = linalg::inferContractionDims(linalgOp); + if (failed(dims)) + return dims; + if (llvm::isa(linalgOp)) { + // For VnniOp, the K reduction dims (dim index 3 & 4) cannot be infered by + // linalg utils because they form complex affine in operand A; Manually add + // them here + dims->k.push_back(3); + dims->k.push_back(4); + } + return dims; +} + +static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, + ArrayRef dimPos) { + if (dimPos.size() > 2) { + return false; + } + auto firstDim = getAffineDimExpr(dimPos[0], linalgOp.getContext()); + if (dimPos.size() == 1) { + if (firstDim == expr) + return true; + else + return false; + } + // If not regular dim affine, check for VNNI format K affine + auto secondKPosDim = getAffineDimExpr(dimPos[1], linalgOp.getContext()); + // An K affine result for VNNI should be this format: + // d{kPos[0]} * s{kPos[1]} + d{kPos[1]} (k0 * K_vnni + k1) + if 
(auto add = dyn_cast(expr)) { + if (add.getKind() == AffineExprKind::Add) { + auto lhs = add.getLHS(); + auto rhs = add.getRHS(); + if (rhs == secondKPosDim) { + auto mul = dyn_cast(lhs); + if (mul && mul.getKind() == AffineExprKind::Mul && + mul.getLHS() == firstDim) { + if (auto cst_affine = dyn_cast(mul.getRHS())) { + if (cst_affine.getValue() == 2 || cst_affine.getValue() == 4) { + return true; + } + } + } + } + } + } + return false; +} + // Return the position of `dim` in the codomain of `operand`. -static std::optional -getPosInCodomain(unsigned dim, OpOperand *operand, linalg::LinalgOp linalgOp) { +static std::optional getPosInCodomain(ArrayRef dimPos, + OpOperand *operand, + linalg::LinalgOp linalgOp) { assert(operand->getOwner() == linalgOp); - return linalgOp.getMatchingIndexingMap(operand).getResultPosition( - getAffineDimExpr(dim, linalgOp.getContext())); + auto map = linalgOp.getMatchingIndexingMap(operand); + for (unsigned i = 0, numResults = map.getNumResults(); i < numResults; i++) { + if (isMatchingAffineResult(linalgOp, map.getResult(i), dimPos)) + return i; + } + return std::nullopt; } static FailureOr @@ -55,32 +118,51 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, const linalg::ContractionDimensions &dims) { unsigned mPos = dims.m[0]; unsigned nPos = dims.n[0]; - unsigned kPos = dims.k.back(); - std::optional batchPos; - if (dims.k.size() == 2) - batchPos = dims.k.front(); + // dims.k could be of 2 cases: + // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] + // 2. 
dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] + unsigned batchPos = dims.k.front(); + SmallVector kPos; + if (dims.k.size() == 2) { + kPos = {dims.k[1]}; + } else if (dims.k.size() == 3) { + kPos = {dims.k[1], dims.k[2]}; + } else { + return failure(); + } LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] m: " << mPos << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n: " << nPos << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] k: " << kPos << "\n"); - if (batchPos) - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch: " << batchPos << "\n"); - else - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] no batch dim\n"); - - auto checkStridesAndGetLda = [&](unsigned minorDim, unsigned majorDim, - OpOperand *operand) -> FailureOr { + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] m pos in affine: " << mPos + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n pos in affine: " << nPos + << "\n"); + for (auto kp : kPos) { + LLVM_DEBUG(llvm::dbgs() + << "[inferBrgemmInfo] k pos in affine: " << kp << "\n"); + } + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch pos in affine: " + << batchPos << "\n"); + + auto checkStridesAndGetLda = + [&](ArrayRef minorDim, ArrayRef majorDim, + OpOperand *operand, bool allowVnni) -> FailureOr { auto minorDimPosInCodomain = getPosInCodomain(minorDim, operand, linalgOp); auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp); if (!minorDimPosInCodomain || !majorDimPosInCodomain) return failure(); auto stridesOnOperand = gcext::utils::getStaticStrides(operand->get()); - if (failed(stridesOnOperand) || - (*stridesOnOperand)[*minorDimPosInCodomain] != 1) + if (failed(stridesOnOperand)) return failure(); - return (*stridesOnOperand)[*majorDimPosInCodomain]; + auto minorDimLd = (*stridesOnOperand)[*minorDimPosInCodomain]; + auto majorDimLd = (*stridesOnOperand)[*majorDimPosInCodomain]; + if (minorDimLd != 1) { + // VNNI 
format exists, special treatment to align LD with non-VNNI format + if (!allowVnni || (minorDimLd != 2 && minorDimLd != 4)) + return failure(); + return majorDimLd / minorDimLd; + } + return majorDimLd; }; OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; @@ -88,53 +170,56 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; // A(m, k) - auto lda = checkStridesAndGetLda(kPos, mPos, operandA); + auto lda = checkStridesAndGetLda(kPos, {mPos}, operandA, false); if (failed(lda)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on A: OK\n"); // B(k, n) - auto ldb = checkStridesAndGetLda(nPos, kPos, operandB); + // note: B does not use VNNI format K affine + auto ldb = checkStridesAndGetLda({nPos}, {kPos[0]}, operandB, true); if (failed(ldb)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on B: OK\n"); // C(m, n) - auto ldc = checkStridesAndGetLda(nPos, mPos, operandC); + auto ldc = checkStridesAndGetLda({nPos}, {mPos}, operandC, false); if (failed(ldc)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on C: OK\n"); int64_t strideA = 1; int64_t strideB = 1; - if (batchPos) { - auto batchPosCodomainA = - getPosInCodomain(batchPos.value(), operandA, linalgOp); - auto stridesOnA = gcext::utils::getStaticStrides(operandA->get()); - strideA = (*stridesOnA)[*batchPosCodomainA]; - - auto batchPosCodomainB = - getPosInCodomain(batchPos.value(), operandB, linalgOp); - auto stridesOnB = gcext::utils::getStaticStrides(operandB->get()); - strideB = (*stridesOnB)[*batchPosCodomainB]; - } + auto batchPosCodomainA = getPosInCodomain(batchPos, operandA, linalgOp); + auto stridesOnA = gcext::utils::getStaticStrides(operandA->get()); + strideA = (*stridesOnA)[*batchPosCodomainA]; - auto loops = linalgOp.computeStaticLoopSizes(); - int64_t batchVal = (batchPos) ? 
loops[batchPos.value()] : 0; + auto batchPosCodomainB = getPosInCodomain(batchPos, operandB, linalgOp); + auto stridesOnB = gcext::utils::getStaticStrides(operandB->get()); + strideB = (*stridesOnB)[*batchPosCodomainB]; + auto loops = linalgOp.computeStaticLoopSizes(); + auto kSize = + kPos.size() == 1 ? loops[kPos[0]] : (loops[kPos[0]] * loops[kPos[1]]); + + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" + << loops[mPos] << "), n(" << loops[nPos] << "), k(" + << kSize << "), batch(" << loops[batchPos] + << "), lda(" << *lda << "), ldb(" << *ldb << "), ldc(" + << *ldc << "), strideA(" << strideA << "), strideB(" + << strideB << ")\n"); BrgemmInfo info{loops[mPos], loops[nPos], - loops[kPos], - batchVal, + kSize, + loops[batchPos], 0 /* addrLen useless under stride mode */, *lda, *ldb, *ldc, strideA, - strideB}; - info.isInitOutput = false; - info.mode = BrgemmInfo::STRIDE_MODE; - + strideB, + false, + BrgemmInfo::STRIDE_MODE}; return info; } @@ -150,15 +235,21 @@ static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { if (!validBrgemmMatcher.match(linalgOp)) return failure(); - auto contractionDims = linalg::inferContractionDims(linalgOp); + auto contractionDims = customInferContractionDims(linalgOp); if (failed(contractionDims)) { LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Not a valid contraction\n"); return failure(); } if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 || - (contractionDims->k.size() != 2 && contractionDims->k.size() != 1) || + // batch-reduce dim for BRGEMM should be identified as one of k dim + // including VNNI & non-VNNI cases + (contractionDims->k.size() != 2 && contractionDims->k.size() != 3) || contractionDims->batch.size() != 0) { LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n"); + LLVM_DEBUG(llvm::dbgs() + << "[checkStructure] " << contractionDims->m.size() << " " + << contractionDims->n.size() << " " << contractionDims->k.size() + << " " << contractionDims->batch.size() << 
"\n"); return failure(); } unsigned classifiedLoops = @@ -231,11 +322,12 @@ static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, dispatched); } -class ConvertBatchReduceMatmulToBrgemmRewriter - : public OpRewritePattern { +template +class ConvertContractionOpToBrgemmRewriter + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ContractionOp op, PatternRewriter &rewriter) const final { auto brgemmInfo = getBrgemmInfo(op); if (failed(brgemmInfo)) @@ -252,7 +344,12 @@ class ConvertLinalgToMicrokernel ConvertLinalgToMicrokernel>::ConvertLinalgToMicrokernelBase; void runOnOperation() final { RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns + .add>( + &getContext()); + patterns.add< + ConvertContractionOpToBrgemmRewriter>( + &getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) signalPassFailure(); diff --git a/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index 041c238dd..fd57e5a00 100644 --- a/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -42,3 +42,46 @@ func.func @simple_brgemm() { // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () // ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @simple_brgemm() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = 
memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] 
[1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- From 2ce6f4ce72b052399effbfbf7e7818b5d4df47e2 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 11 Jun 2024 00:02:21 -0700 Subject: [PATCH 14/93] fix runtime dnnl brgemm correctness --- CMakeLists.txt | 1 + .../ExecutionEngine/CPURuntime/CMakeLists.txt | 6 +++--- .../CPURuntime/Microkernel/BrgemmNaive.cpp | 9 +-------- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 17 +++++++++++++---- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 636b33ad2..ded5a86a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ option(GC_TEST_ENABLE "Build the tests" ON) option(GC_USE_GPU "Enable GPU backend" OFF) option(GC_ENABLE_BINDINGS_PYTHON "Enable Graph Complier Python Binding" ON) option(GC_DEV_LINK_LLVM_DYLIB "Link dynamic libraries of LLVM and MLIR. For developers only. Do not use it in packing the library." OFF) +option(GC_RUNTIME_NAIVE_BRGEMM "Use naive BRGEMM as runtime backend, mainly for debug purpose." 
OFF) if(GC_LEGACY_ENABLE) add_subdirectory(legacy/core) diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 180413719..32d7510f8 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -10,10 +10,10 @@ file(GLOB_RECURSE MICROKERNEL_RUNTIME_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/*.c ) -if (GC_MLIR_NAIVE_BRGEMM) - string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmOnednn.cpp;" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") +if (GC_RUNTIME_NAIVE_BRGEMM) + string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmOnednn.cpp" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") else() - string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmNaive.cpp;" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") + string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmNaive.cpp" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") endif() include(onednn) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 3cb585c31..01710e322 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -90,6 +90,7 @@ static int naive_brgemm_execute_fp32(brgemm_params_t params, void *A, Abuf += params.stride_a; Bbuf += params.stride_b; } + return 0; } @@ -175,8 +176,6 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, // simply store the given parameters for naive BRGEMM brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, stride_b, beta, dtypeA, dtypeB)); - // std::cout << ">>>>> dnnl_brgemm_dispatch: " << brgemm_list.size() - 1 << - // std::endl; return brgemm_list.size() - 1; } @@ -192,32 +191,26 @@ void dnnl_brgemm_execute(int64_t kernel, void *A, 
uint64_t A_offset, void *B, brgemm_params_t ¶ms = brgemm_list[kernel]; if (params.dtypeA == static_cast(dnnl_f32) && params.dtypeB == static_cast(dnnl_f32)) { - // std::cout << ">>>>> dnnl_brgemm_execute_f32: " << kernel << std::endl; naive_brgemm_execute_fp32(params, A, A_offset, B, B_offset, C, C_offset, num); } else if (params.dtypeA == static_cast(dnnl_bf16) && params.dtypeB == static_cast(dnnl_bf16)) { - // std::cout << ">>>>> dnnl_brgemm_execute_bf16: " << kernel << std::endl; naive_brgemm_execute_bf16(params, A, A_offset, B, B_offset, C, C_offset, num); } else if (params.dtypeA == static_cast(dnnl_s8) && params.dtypeB == static_cast(dnnl_s8)) { - // std::cout << ">>>>> dnnl_brgemm_execute_s8s8: " << kernel << std::endl; naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, C, C_offset, num); } else if (params.dtypeA == static_cast(dnnl_s8) && params.dtypeB == static_cast(dnnl_u8)) { - // std::cout << ">>>>> dnnl_brgemm_execute_s8u8: " << kernel << std::endl; naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, C, C_offset, num); } else if (params.dtypeA == static_cast(dnnl_u8) && params.dtypeB == static_cast(dnnl_u8)) { - // std::cout << ">>>>> dnnl_brgemm_execute_u8u8: " << kernel << std::endl; naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, C, C_offset, num); } else if (params.dtypeA == static_cast(dnnl_u8) && params.dtypeB == static_cast(dnnl_s8)) { - // std::cout << ">>>>> dnnl_brgemm_execute_u8s8: " << kernel << std::endl; naive_brgemm_execute_int8(params, A, A_offset, B, B_offset, C, C_offset, num); } else { diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index ff4aee172..3e8d57257 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include 
-#include #include #include #include @@ -57,12 +56,18 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, brgemm_desc_t &desc = brgemm_desc_list.back(); auto &kernel = brgemm_kernel_list.back(); - brgemm_strides_t stride_info{stride_a, stride_b}; + + auto dnnl_dtypeA = static_cast(dtypeA); + auto dnnl_dtypeB = static_cast(dtypeB); + int64_t dtypeA_size = + dnnl::impl::types::data_type_size(dnnl_dtypeA); + int64_t dtypeB_size = + dnnl::impl::types::data_type_size(dnnl_dtypeB); + brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size}; dnnl::impl::status_t status = brgemm_desc_init( &desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, - static_cast(dtypeA), - static_cast(dtypeB), false, false, + dnnl_dtypeA, dnnl_dtypeB, false, false, brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info); assert(status == dnnl::impl::status::success && @@ -72,6 +77,9 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, assert(status == dnnl::impl::status::success && "Failed to JIT BRGEMM kernel"); + brgemm_attr_t dnnl_attrs; + brgemm_desc_set_attr(&desc, dnnl_attrs); + return brgemm_desc_list.size() - 1; } @@ -84,6 +92,7 @@ void dnnl_brgemm_tileconfig(int64_t kernel_idx) { return; } + // TODO(haixin): move to dispatch time char palette_buffer[PALETTE_SIZE]; dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer); assert(status == dnnl::impl::status::success && From e0e8b9449999196237feff57486035e0c87146e5 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 11 Jun 2024 00:05:21 -0700 Subject: [PATCH 15/93] fix format --- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 3e8d57257..4f60e7da1 100644 --- 
a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -59,17 +59,14 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, auto dnnl_dtypeA = static_cast(dtypeA); auto dnnl_dtypeB = static_cast(dtypeB); - int64_t dtypeA_size = - dnnl::impl::types::data_type_size(dnnl_dtypeA); - int64_t dtypeB_size = - dnnl::impl::types::data_type_size(dnnl_dtypeB); + int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA); + int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB); brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size}; dnnl::impl::status_t status = brgemm_desc_init( &desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, - dnnl_dtypeA, dnnl_dtypeB, false, false, - brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K, - &stride_info); + dnnl_dtypeA, dnnl_dtypeB, false, false, brgemm_layout_t::brgemm_row_major, + 1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info); assert(status == dnnl::impl::status::success && "Failed to initialize BRGEMM descriptor"); From 921b0dc90fbb904406fcf78130ca9f398fb37b90 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 14 Jun 2024 01:52:42 -0700 Subject: [PATCH 16/93] support pattern with linalg.fill --- .../ConvertLinalgToMicrokernel.cpp | 33 +++++++ .../Microkernel/linalg-to-microkernel.mlir | 90 ++++++++++++++++++- 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 0f426aeeb..c5b1d6355 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -322,6 +322,22 @@ static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, dispatched); } +bool isZeroArithConstant(arith::ConstantOp op) { + if (!op) + return false; + + if 
(auto intAttr = llvm::dyn_cast(op.getValue())) { + if (intAttr.getInt() != 0) + return false; + } else if (auto floatAttr = llvm::dyn_cast(op.getValue())) { + if (!floatAttr.getValue().isZero()) + return false; + } else + return false; + + return true; +} + template class ConvertContractionOpToBrgemmRewriter : public OpRewritePattern { @@ -332,6 +348,23 @@ class ConvertContractionOpToBrgemmRewriter auto brgemmInfo = getBrgemmInfo(op); if (failed(brgemmInfo)) return failure(); + // Check for immediately preceding linalg::FillOp + auto block = op.getBlock(); + auto opIter = Block::iterator(op); + if (block->begin() != opIter) { + auto prevOp = &(*(--opIter)); + if (auto fillOp = dyn_cast(prevOp)) { + auto inputCst = dyn_cast_or_null( + fillOp.getInputs()[0].getDefiningOp()); + auto fillOperand = fillOp.getOutputs()[0]; + auto contractionOperand = op.getOutputs()[0]; + if (isZeroArithConstant(inputCst) && + contractionOperand == fillOperand) { + brgemmInfo->isInitOutput = true; + rewriter.eraseOp(prevOp); + } + } + } replaceOpWithMicrokernelOpSet(rewriter, op, *brgemmInfo); return success(); } diff --git a/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index fd57e5a00..c721ed2ad 100644 --- a/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -9,7 +9,6 @@ func.func @simple_brgemm() { %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> scf.forall (%arg7, %arg8) in (4, 8) { %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], 
offset: ?>> linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) @@ -52,7 +51,6 @@ func.func @simple_brgemm() { %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> scf.forall (%arg7, %arg8) in (4, 8) { %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) @@ -85,3 +83,91 @@ func.func @simple_brgemm() { // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () // ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @simple_brgemm() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : 
memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) + linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK-NOT: linalg.fill +// CHECK: %[[DIS:.+]] = 
microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @simple_brgemm() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) + linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, 
memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CHECK-NOT: linalg.fill +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- From 6ec1053b8f71db640239612b48e4a0a1ea5cc329 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 17 Jun 2024 00:13:34 
-0700 Subject: [PATCH 17/93] move brgemm init_tiles to dispatch time --- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 4f60e7da1..f22adf7ee 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -44,6 +44,7 @@ __attribute__((weak)) void print_verbose_header() {} static constexpr int PALETTE_SIZE = 64; static std::vector brgemm_desc_list; static std::vector brgemm_kernel_list; +static std::vector brgemm_palette; extern "C" { @@ -77,6 +78,17 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, brgemm_attr_t dnnl_attrs; brgemm_desc_set_attr(&desc, dnnl_attrs); + // TODO(haixin): Reuse identical palettes across kernels + if (desc.is_tmm) { + brgemm_palette.push_back(new char[PALETTE_SIZE]); + dnnl::impl::status_t status = + brgemm_init_tiles(desc, brgemm_palette.back()); + assert(status == dnnl::impl::status::success && + "Failed to initialize palette for BRGEMM"); + } else { + brgemm_palette.push_back(nullptr); + } + return brgemm_desc_list.size() - 1; } @@ -89,13 +101,9 @@ void dnnl_brgemm_tileconfig(int64_t kernel_idx) { return; } - // TODO(haixin): move to dispatch time - char palette_buffer[PALETTE_SIZE]; - dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer); - assert(status == dnnl::impl::status::success && - "Failed to initialize palette for BRGEMM"); - - amx_tile_configure(palette_buffer); + assert(brgemm_palette[kernel_idx] != nullptr && + "Invalid palette for BRGEMM kernel"); + amx_tile_configure(brgemm_palette[kernel_idx]); } void dnnl_brgemm_tilerelease() { From f014e734f26203d9f92be21cd3a12156394422b2 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 17 Jun 2024 00:50:12 -0700 Subject: [PATCH 18/93] 
move mlir tests to right place --- .../test}/gc/Dialect/Microkernel/linalg-to-microkernel.mlir | 0 .../test}/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename test/{ => mlir/test}/gc/Dialect/Microkernel/linalg-to-microkernel.mlir (100%) rename test/{ => mlir/test}/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir (100%) diff --git a/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir similarity index 100% rename from test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir rename to test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir diff --git a/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir similarity index 100% rename from test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir rename to test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir From f586efb6ba04176943bdb6641258361d8d370552 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 25 Jun 2024 01:53:39 -0700 Subject: [PATCH 19/93] use thread_local for scratch buffer --- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index f22adf7ee..6ccb715a6 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -46,6 +46,11 @@ static std::vector brgemm_desc_list; static std::vector brgemm_kernel_list; static std::vector brgemm_palette; +// TODO(haixin): use syscall to determine page size? +static constexpr size_t SCRATCH_SIZE = 2 * 4096; +// TODO(haixin): need to use custom thread management for scratch in the future? 
+static thread_local char scratch[SCRATCH_SIZE] = {0}; + extern "C" { int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, @@ -130,20 +135,11 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, size_t C_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; -#ifdef _WIN32 - // fix-me: (win32) impl - static size_t scratch_size = 2 * 4096; -#else - static size_t scratch_size = 2 * getpagesize(); -#endif - // TODO(haixin): use thread local buffer for scratch - char *scratch = new char[scratch_size]; char *A_arith = (char *)A; char *B_arith = (char *)B; char *C_arith = (char *)C; brgemm_kernel_execute(kernel, num, (void *)(A_arith + A_offset_in_bytes), (void *)(B_arith + B_offset_in_bytes), nullptr, (void *)(C_arith + C_offset_in_bytes), (void *)scratch); - delete scratch; } } From c4e4bcf00ce8bf57a36ba758e5c80ed50e8f10ca Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 26 Jun 2024 01:50:03 -0700 Subject: [PATCH 20/93] refine memref ptr/offset extraction --- include/gc/Transforms/Utils/ValueUtils.h | 3 +- .../ConvertMicrokernelToDnnlFunc.cpp | 30 ++++++++++++++++--- lib/gc/Transforms/Utils/ValueUtils.cpp | 29 ++++++++++++++++++ .../Microkernel/microkernel-to-dnnl-func.mlir | 12 ++++---- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index acffd5642..07013bde4 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -27,8 +27,7 @@ FailureOr> getStaticStrides(Value val); // Return the offset and ptr for `val`. Assert if `val` // is not a memref. 
-std::pair getPtrAndOffset(OpBuilder &builder, Value val, - Location loc); +std::pair getPtrAndOffset(OpBuilder &builder, Value operand); } // namespace utils } // namespace mlir diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index 0d50aee71..ead0ab0d8 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -137,8 +137,15 @@ class ConvertBrgemmPrologueOpRewriter }; class ConvertBrgemmOpRewriter : public OpRewritePattern { +private: + DenseMap> &memrefExtractCache; + public: using OpRewritePattern::OpRewritePattern; + ConvertBrgemmOpRewriter(MLIRContext *context, + DenseMap> &cache) + : OpRewritePattern(context), memrefExtractCache{cache} {} + // runtime func for stride mode dnnl brgemm execution: // void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void // *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) @@ -165,8 +172,20 @@ class ConvertBrgemmOpRewriter : public OpRewritePattern { Type operandType = operand.getType(); if (auto memrefType = dyn_cast(operandType)) { Type basePtrType = LLVM::LLVMPointerType::get(context); - auto [ptr, offset] = - gcext::utils::getPtrAndOffset(rewriter, operand, loc); + + Value ptr, offset; + // Use cache to avoid injecting duplicated extraction Ops + auto memrefExtractIter = memrefExtractCache.find(operand); + if (memrefExtractIter == memrefExtractCache.end()) { + auto res = gcext::utils::getPtrAndOffset(rewriter, operand); + ptr = res.first; + offset = res.second; + memrefExtractCache[operand] = res; + } else { + ptr = memrefExtractIter->second.first; + offset = memrefExtractIter->second.second; + } + operands.push_back(ptr); operands.push_back(offset); operandTypes.push_back(basePtrType); @@ -211,8 +230,11 @@ class ConvertMicrokernelToDnnlFunc RewritePatternSet patterns(&getContext()); patterns .add( - &getContext()); + 
ConvertBrgemmEpilogueOpRewriter>(&getContext()); + + DenseMap> memrefExtractCache; + patterns.add(&getContext(), memrefExtractCache); + FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) signalPassFailure(); diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index d565d0cf8..8750042ee 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -116,5 +116,34 @@ FailureOr> getStaticStrides(Value value) { return strides; } +std::pair getPtrAndOffset(OpBuilder &builder, Value operand) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + + Location loc = operand.getDefiningOp()->getLoc(); + OpBuilder::InsertionGuard guard(builder); + // Insert right after operand producer for better opt chances. + builder.setInsertionPointAfterValue(operand); + + MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); + Type basePtrType = builder.getIndexType(); + Type offsetType = builder.getIndexType(); + SmallVector sizesTypes(memrefType.getRank(), offsetType); + SmallVector stridesTypes(memrefType.getRank(), offsetType); + auto meta = builder.create( + loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); + Value alignedPointerAsIndex = + builder.create(loc, basePtrType, + operand); + Value alignedPointerAsI64 = builder.create( + loc, builder.getIntegerType(64), alignedPointerAsIndex); + // TODO: non-POD will require an LLVMTypeConverter. 
+ Value alignedPointer = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), + alignedPointerAsI64); + Value offset = meta.getOffset(); + return std::make_pair(alignedPointer, offset); +} + } // namespace utils } // namespace mlir diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir index bbd018b85..8dee28b66 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -40,15 +40,16 @@ module { // CHECK-LABEL: dnnl_brgemm_execute // CHECK-LABEL: dnnl_brgemm_dispatch // CHECK-LABEL: simple_brgemm -// CHECK: %[[CST0:.+]] = arith.constant 0 : index // CHECK: %[[CST3:.+]] = arith.constant 3 : i64 // CHECK: %[[CST1F:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[CST1024:.+]] = arith.constant 1024 : i64 // CHECK: %[[CST32:.+]] = arith.constant 32 : i64 +// CHECK: %[[CST0:.+]] = arith.constant 0 : index // CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[KERNEL:.+]] = func.call @dnnl_brgemm_dispatch(%[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST1024]], %[[CST1024]], %[[CST1F]], %[[CST3]], %[[CST3]]) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 -// CHECK-NOT: microkernel.brgemm.prologue(%[[TMP:.+]]) : (i64) -> () +// CHECK: %[[ptrC:.+]] = memref.extract_aligned_pointer_as_index %[[memrefC:.+]] : memref<32x32xf32> -> index +// CHECK-NEXT: %[[idxC:.+]] = arith.index_cast %[[ptrC]] : index to i64 +// CHECK-NEXT: %[[llvmptrC:.+]] = llvm.inttoptr %[[idxC]] : i64 to !llvm.ptr // CHECK: %[[bbA:.+]], %[[offA:.+]], %[[szA:.+]]:3, %[[strdA:.+]]:3 = memref.extract_strided_metadata %[[memrefA:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index // CHECK-NEXT: %[[ptrA:.+]] = memref.extract_aligned_pointer_as_index %[[memrefA]] : 
memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index @@ -60,9 +61,8 @@ module { // CHECK-NEXT: %[[idxB:.+]] = arith.index_cast %[[ptrB]] : index to i64 // CHECK-NEXT: %[[llvmptrB:.+]] = llvm.inttoptr %[[idxB]] : i64 to !llvm.ptr -// CHECK: %[[ptrC:.+]] = memref.extract_aligned_pointer_as_index %[[memrefC:.+]] : memref<32x32xf32> -> index -// CHECK-NEXT: %[[idxC:.+]] = arith.index_cast %[[ptrC]] : index to i64 -// CHECK-NEXT: %[[llvmptrC:.+]] = llvm.inttoptr %[[idxC]] : i64 to !llvm.ptr +// CHECK: %[[KERNEL:.+]] = func.call @dnnl_brgemm_dispatch(%[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST1024]], %[[CST1024]], %[[CST1F]], %[[CST3]], %[[CST3]]) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 +// CHECK-NOT: microkernel.brgemm.prologue(%[[TMP:.+]]) : (i64) -> () // CHECK: func.call @dnnl_brgemm_execute(%[[KERNEL]], %[[llvmptrA]], %[[offA]], %[[llvmptrB]], %[[offB]], %[[llvmptrC]], %[[CST0]], %[[CST16]]) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () // CHECK-NOT: microkernel.brgemm.epilogue(%[[KERNEL]]) : (i64) -> () From f51ea4c4de7005efee63b963981bcfd959ed32d7 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 26 Jun 2024 02:17:19 -0700 Subject: [PATCH 21/93] revert pass change --- .../ConvertMicrokernelToDnnlFunc.cpp | 28 ++----------------- .../Microkernel/microkernel-to-dnnl-func.mlir | 2 +- 2 files changed, 4 insertions(+), 26 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index ead0ab0d8..14aba586c 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -137,15 +137,8 @@ class ConvertBrgemmPrologueOpRewriter }; class ConvertBrgemmOpRewriter : public OpRewritePattern { -private: - DenseMap> &memrefExtractCache; - public: using OpRewritePattern::OpRewritePattern; - 
ConvertBrgemmOpRewriter(MLIRContext *context, - DenseMap> &cache) - : OpRewritePattern(context), memrefExtractCache{cache} {} - // runtime func for stride mode dnnl brgemm execution: // void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void // *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) @@ -172,20 +165,7 @@ class ConvertBrgemmOpRewriter : public OpRewritePattern { Type operandType = operand.getType(); if (auto memrefType = dyn_cast(operandType)) { Type basePtrType = LLVM::LLVMPointerType::get(context); - - Value ptr, offset; - // Use cache to avoid injecting duplicated extraction Ops - auto memrefExtractIter = memrefExtractCache.find(operand); - if (memrefExtractIter == memrefExtractCache.end()) { - auto res = gcext::utils::getPtrAndOffset(rewriter, operand); - ptr = res.first; - offset = res.second; - memrefExtractCache[operand] = res; - } else { - ptr = memrefExtractIter->second.first; - offset = memrefExtractIter->second.second; - } - + auto [ptr, offset] = gcext::utils::getPtrAndOffset(rewriter, operand); operands.push_back(ptr); operands.push_back(offset); operandTypes.push_back(basePtrType); @@ -230,10 +210,8 @@ class ConvertMicrokernelToDnnlFunc RewritePatternSet patterns(&getContext()); patterns .add(&getContext()); - - DenseMap> memrefExtractCache; - patterns.add(&getContext(), memrefExtractCache); + ConvertBrgemmOpRewriter, ConvertBrgemmEpilogueOpRewriter>( + &getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir index 8dee28b66..6d705920c 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s -convert-microkernel-to-dnnl-func -split-input-file | 
FileCheck %s +// RUN: gc-opt %s -convert-microkernel-to-dnnl-func -cse -split-input-file | FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> module { From 6ad33cfcd58008eb834d6edbda4968053259f26a Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 2 Jul 2024 22:51:59 -0700 Subject: [PATCH 22/93] fix op preceding check --- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index c5b1d6355..d0d813499 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -349,8 +349,9 @@ class ConvertContractionOpToBrgemmRewriter if (failed(brgemmInfo)) return failure(); // Check for immediately preceding linalg::FillOp - auto block = op.getBlock(); - auto opIter = Block::iterator(op); + Operation *rawOp = op; + auto block = rawOp->getBlock(); + auto opIter = Block::iterator(rawOp); if (block->begin() != opIter) { auto prevOp = &(*(--opIter)); if (auto fillOp = dyn_cast(prevOp)) { From a9a683a923af3266f55e392755daf68b78794b1f Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 00:32:22 -0700 Subject: [PATCH 23/93] fix utils header --- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 4 ++-- .../Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index d0d813499..2fc1a8866 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -22,8 +22,8 @@ #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include 
"gc/Utils/StructuredOpMatcher.h" -#include "gc/Utils/ValueUtils.h" +#include "gc/Transforms/Utils/StructuredOpMatcher.h" +#include "gc/Transforms/Utils/ValueUtils.h" namespace mlir::microkernel { #define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index 14aba586c..ed3901336 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -16,7 +16,7 @@ #include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/ValueUtils.h" +#include "gc/Transforms/Utils/ValueUtils.h" namespace mlir::microkernel { #define GEN_PASS_DEF_CONVERTMICROKERNELTODNNLFUNC From 619f6705f3d215f7426bfe5a9796f3d413ca239c Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 00:45:55 -0700 Subject: [PATCH 24/93] accommodate to new utils --- .../gc/Transforms/Utils/StructuredOpMatcher.h | 20 +++++++++++++++++++ lib/gc/Transforms/Microkernel/CMakeLists.txt | 2 +- .../ConvertLinalgToMicrokernel.cpp | 8 ++++---- .../ConvertMicrokernelToDnnlFunc.cpp | 2 +- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/include/gc/Transforms/Utils/StructuredOpMatcher.h b/include/gc/Transforms/Utils/StructuredOpMatcher.h index 66bd22b7a..c931472a1 100644 --- a/include/gc/Transforms/Utils/StructuredOpMatcher.h +++ b/include/gc/Transforms/Utils/StructuredOpMatcher.h @@ -217,6 +217,26 @@ template struct EqualsTo { }; template EqualsTo(T) -> EqualsTo; +// Callable object to check if the input is less than or equal to specified +// `value`. 
+struct LessThanOrEqualTo { + LessThanOrEqualTo() = delete; + explicit LessThanOrEqualTo(size_t value) : value(value){}; + const size_t value; + + bool operator()(size_t value) const { return value <= this->value; } +}; + +// Callable object to check if the input is greater than or equal to specified +// `value`. +struct GreaterThanOrEqualTo { + GreaterThanOrEqualTo() = delete; + explicit GreaterThanOrEqualTo(size_t value) : value(value){}; + const size_t value; + + bool operator()(size_t value) const { return value >= this->value; } +}; + // Callable object to validate number of init operands for `op`. struct NumDpsInits { NumDpsInits() = delete; diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index 462c5d697..af6fa30e0 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -18,7 +18,7 @@ add_mlir_dialect_library(MLIRMicrokernelTransforms LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} - GCMLIRUtils + GCUtilsIR ) set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS MLIRMicrokernelTransforms) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 2fc1a8866..8eacdd921 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -151,7 +151,7 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp); if (!minorDimPosInCodomain || !majorDimPosInCodomain) return failure(); - auto stridesOnOperand = gcext::utils::getStaticStrides(operand->get()); + auto stridesOnOperand = utils::getStaticStrides(operand->get()); if (failed(stridesOnOperand)) return failure(); auto minorDimLd = (*stridesOnOperand)[*minorDimPosInCodomain]; @@ -191,11 +191,11 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, int64_t strideA = 1; int64_t strideB = 1; auto 
batchPosCodomainA = getPosInCodomain(batchPos, operandA, linalgOp); - auto stridesOnA = gcext::utils::getStaticStrides(operandA->get()); + auto stridesOnA = utils::getStaticStrides(operandA->get()); strideA = (*stridesOnA)[*batchPosCodomainA]; auto batchPosCodomainB = getPosInCodomain(batchPos, operandB, linalgOp); - auto stridesOnB = gcext::utils::getStaticStrides(operandB->get()); + auto stridesOnB = utils::getStaticStrides(operandB->get()); strideB = (*stridesOnB)[*batchPosCodomainB]; auto loops = linalgOp.computeStaticLoopSizes(); @@ -224,7 +224,7 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, } static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { - using namespace mlir::gcext::utils::structured_match; + using namespace mlir::structured_match; auto validBrgemmMatcher = StructuredOpMatcher::make() .output(MatchAll(), HasStaticShape()) .input(MatchAll(), HasStaticShape()) diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index ed3901336..998500030 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -165,7 +165,7 @@ class ConvertBrgemmOpRewriter : public OpRewritePattern { Type operandType = operand.getType(); if (auto memrefType = dyn_cast(operandType)) { Type basePtrType = LLVM::LLVMPointerType::get(context); - auto [ptr, offset] = gcext::utils::getPtrAndOffset(rewriter, operand); + auto [ptr, offset] = utils::getPtrAndOffset(rewriter, operand); operands.push_back(ptr); operands.push_back(offset); operandTypes.push_back(basePtrType); From e31a6d38588b6932ed8ab690e4154f0b430b5aa9 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 00:55:15 -0700 Subject: [PATCH 25/93] fix licenses --- include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h | 2 +- include/gc/Transforms/Microkernel/MicrokernelPasses.td | 4 ++-- 
.../ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp | 2 +- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 4 ++-- .../Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h index afe2da9b5..ebf177b21 100644 --- a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h +++ b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h @@ -1,4 +1,4 @@ -//===- BrgemmRuntimeUtils.h - Utils for Brgemm Runtime -----------*-C++ -*-===// +//===-- BrgemmRuntimeUtils.h - Utils for Brgemm Runtime ---------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 59726ac0a..bf9e3c61d 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -1,10 +1,10 @@ -//===- MicrokernelPasses.td - Graph Compiler microkernel passes *- tablegen -*-===// +//===-- MicrokernelPasses.td - microkernel passes ----------*- tablegen -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===--------------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #ifndef GC_DIALECT_MICROKERNELPASSES #define GC_DIALECT_MICROKERNELPASSES diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 6ccb715a6..61cfee6c1 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -1,4 +1,4 @@ -//===-- BrgemmNaive.cpp - BRGEMM Naive Implementation -----------*- C++ -*-===// +//===-- BrgemmOnednn.cpp - BRGEMM Onednn Implementation ---------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 8eacdd921..8b96f59e5 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -1,10 +1,10 @@ -//===- ConvertLinalgToMicrokernel.cpp - Linalg To Microkernel -*- C++ -*--===// +//===-- ConvertLinalgToMicrokernel.cpp - Linalg To Microkernel --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index 998500030..966f87c06 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -1,10 +1,10 @@ -//===- ConvertMicrokernelToDnnlFunc.cpp ------------------------*- C++ -*-===// +//===-- ConvertMicrokernelToDnnlFunc.cpp - Lower to dnnl funcs --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" From f8100e18606f07665d6a0bd20502641c7b70ed4c Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 01:16:20 -0700 Subject: [PATCH 26/93] update clang-tidy workflow --- .github/workflows/clang-tidy.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/clang-tidy.yml b/.github/workflows/clang-tidy.yml index ae8479fbc..f65296c92 100644 --- a/.github/workflows/clang-tidy.yml +++ b/.github/workflows/clang-tidy.yml @@ -81,7 +81,8 @@ jobs: -DCMAKE_EXPORT_COMPILE_COMMANDS=True \ -DCMAKE_C_COMPILER=$(which clang) \ -DCMAKE_CXX_COMPILER=$(which clang++) \ - -DLLVM_EXTERNAL_LIT=$(which lit) + -DLLVM_EXTERNAL_LIT=$(which lit) \ + -DDNNL_USE_CLANG_SANITIZER="Undefined" - name: Prepare inc file run: | From 
ae3e9f8555f62f1067bdb5c1c2ff45bec789124d Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 01:51:11 -0700 Subject: [PATCH 27/93] fix tidy --- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 8b96f59e5..d2074086c 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -73,8 +73,7 @@ static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, if (dimPos.size() == 1) { if (firstDim == expr) return true; - else - return false; + return false; } // If not regular dim affine, check for VNNI format K affine auto secondKPosDim = getAffineDimExpr(dimPos[1], linalgOp.getContext()); @@ -244,7 +243,7 @@ static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { // batch-reduce dim for BRGEMM should be identified as one of k dim // including VNNI & non-VNNI cases (contractionDims->k.size() != 2 && contractionDims->k.size() != 3) || - contractionDims->batch.size() != 0) { + !contractionDims->batch.empty()) { LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n"); LLVM_DEBUG(llvm::dbgs() << "[checkStructure] " << contractionDims->m.size() << " " From c5cbbd3c6115e9280b02076f14f209d79d26a592 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 01:57:06 -0700 Subject: [PATCH 28/93] fix tidy --- .../Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index d2074086c..dec7c4577 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -70,11 
+70,9 @@ static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, return false; } auto firstDim = getAffineDimExpr(dimPos[0], linalgOp.getContext()); - if (dimPos.size() == 1) { - if (firstDim == expr) - return true; - return false; - } + if (dimPos.size() == 1) + return firstDim == expr; + // If not regular dim affine, check for VNNI format K affine auto secondKPosDim = getAffineDimExpr(dimPos[1], linalgOp.getContext()); // An K affine result for VNNI should be this format: From 43b0c2858e3993d329d8a6946e394241ba59cfc6 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 20:59:14 -0700 Subject: [PATCH 29/93] fix tidy --- lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 61cfee6c1..983ed46fb 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -57,7 +57,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { - brgemm_desc_list.emplace_back(brgemm_desc_t()); + brgemm_desc_list.emplace_back(); brgemm_kernel_list.emplace_back(nullptr); brgemm_desc_t &desc = brgemm_desc_list.back(); From 334be08e63ddb67e699ae9180c109103d4ba7df6 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 30 Jul 2024 00:28:08 -0700 Subject: [PATCH 30/93] give teste better names --- .../Microkernel/linalg-to-microkernel.mlir | 16 ++++++++-------- .../Microkernel/microkernel-to-dnnl-func.mlir | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index c721ed2ad..224329f62 
100644 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -1,7 +1,7 @@ // RUN: gc-opt %s -convert-linalg-to-microkernel -split-input-file | FileCheck %s #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @simple_brgemm() { +func.func @basic_linalg_to_microkernel() { %cst = arith.constant 0.000000e+00 : f32 %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> @@ -29,7 +29,7 @@ func.func @simple_brgemm() { return } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: basic_linalg_to_microkernel // CHECK: %[[CST0:.+]] = arith.constant 0 : i64 // CHECK: %[[CST16:.+]] = arith.constant 16 : i64 // CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> @@ -43,7 +43,7 @@ func.func @simple_brgemm() { // ----- #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @simple_brgemm() { +func.func @vnni_linalg_to_microkernel() { %cst = arith.constant 0.000000e+00 : f32 %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> @@ -71,7 +71,7 @@ func.func @simple_brgemm() { return } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: vnni_linalg_to_microkernel // CHECK: %[[CST0:.+]] = arith.constant 0 : i64 // CHECK: %[[CST16:.+]] = arith.constant 16 : i64 // CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> @@ -85,7 +85,7 @@ func.func @simple_brgemm() { // ----- #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @simple_brgemm() { +func.func @basic_linalg_to_microkernel_fusing_fill() { %cst = arith.constant 0.000000e+00 : f32 %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> @@ -114,7 +114,7 @@ func.func @simple_brgemm() { return } -// CHECK-LABEL: simple_brgemm +// 
CHECK-LABEL: basic_linalg_to_microkernel_fusing_fill // CHECK: %[[CST0:.+]] = arith.constant 0 : i64 // CHECK: %[[CST16:.+]] = arith.constant 16 : i64 // CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> @@ -129,7 +129,7 @@ func.func @simple_brgemm() { // ----- #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @simple_brgemm() { +func.func @vnni_linalg_to_microkernel_fusing_fill() { %cst = arith.constant 0.000000e+00 : f32 %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> @@ -158,7 +158,7 @@ func.func @simple_brgemm() { return } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: vnni_linalg_to_microkernel_fusing_fill // CHECK: %[[CST0:.+]] = arith.constant 0 : i64 // CHECK: %[[CST16:.+]] = arith.constant 16 : i64 // CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir index 6d705920c..1520ae069 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -2,7 +2,7 @@ #map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { + func.func @basic_convert() { %c0_i64 = arith.constant 0 : i64 %c16_i64 = arith.constant 16 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -39,7 +39,7 @@ module { // CHECK-LABEL: dnnl_brgemm_execute // CHECK-LABEL: dnnl_brgemm_dispatch -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: basic_convert // CHECK: %[[CST3:.+]] = arith.constant 3 : i64 // CHECK: %[[CST1F:.+]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[CST1024:.+]] = arith.constant 1024 : i64 From bb21eef3c7692cfc9633b2aa5581ac196480488c Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 31 Jul 2024 19:43:36 -0700 Subject: [PATCH 31/93] fix clang-format --- 
lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp | 2 +- .../ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 01710e322..01243c02e 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -90,7 +90,7 @@ static int naive_brgemm_execute_fp32(brgemm_params_t params, void *A, Abuf += params.stride_a; Bbuf += params.stride_b; } - + return 0; } diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 983ed46fb..96c95e960 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -62,8 +62,8 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, brgemm_desc_t &desc = brgemm_desc_list.back(); auto &kernel = brgemm_kernel_list.back(); - - auto dnnl_dtypeA = static_cast(dtypeA); + + auto dnnl_dtypeA = static_cast(dtypeA); auto dnnl_dtypeB = static_cast(dtypeB); int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA); int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB); From f96bcaf6e2e79ebbc3855b4934000efbdc02dc4a Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 5 Aug 2024 22:32:32 -0700 Subject: [PATCH 32/93] minor fixes as per reviews --- .../Microkernel/BrgemmRuntimeUtils.h | 2 + .../CPURuntime/Microkernel/BrgemmNaive.cpp | 53 +++++++++--- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 85 +++++++++++-------- 3 files changed, 91 insertions(+), 49 deletions(-) diff --git a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h index ebf177b21..adb214e10 100644 --- 
a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h +++ b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h @@ -31,6 +31,8 @@ static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter, return static_cast(dnnl_f32); } else if (tattr == TypeAttr::get(FloatType::getBF16(context))) { return static_cast(dnnl_bf16); + } else if (tattr == TypeAttr::get(FloatType::getF16(context))) { + return static_cast(dnnl_f16); } else if (tattr == TypeAttr::get( IntegerType::get(context, 32, IntegerType::Signed))) { return static_cast(dnnl_s32); diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 01243c02e..60e2c7a29 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -67,10 +68,20 @@ struct brgemm_params_t { }; // namespace -static int naive_brgemm_execute_fp32(brgemm_params_t params, void *A, - uint64_t A_offset, void *B, - uint64_t B_offset, void *C, - uint64_t C_offset, int num) { +template +static void naive_brgemm_init(const brgemm_params_t ¶ms, void *C, + uint64_t C_offset) { + C_type *Cbuf = (C_type *)C; + Cbuf += C_offset; + for (int i = 0; i < params.M * params.N; i++) { + Cbuf[i] = C_type(0); + } +} + +static void naive_brgemm_execute_fp32(const brgemm_params_t ¶ms, void *A, + uint64_t A_offset, void *B, + uint64_t B_offset, void *C, + uint64_t C_offset, int num) { float *Abuf = (float *)A; float *Bbuf = (float *)B; float *Cbuf = (float *)C; @@ -90,11 +101,9 @@ static int naive_brgemm_execute_fp32(brgemm_params_t params, void *A, Abuf += params.stride_a; Bbuf += params.stride_b; } - - return 0; } -static void naive_brgemm_execute_bf16(brgemm_params_t params, void *A, +static void naive_brgemm_execute_bf16(const brgemm_params_t ¶ms, void *A, uint64_t A_offset, void *B, uint64_t B_offset, void *C, uint64_t 
C_offset, int num) { @@ -126,7 +135,7 @@ static void naive_brgemm_execute_bf16(brgemm_params_t params, void *A, } template -static void naive_brgemm_execute_int8(brgemm_params_t params, void *A, +static void naive_brgemm_execute_int8(const brgemm_params_t ¶ms, void *A, uint64_t A_offset, void *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) { @@ -165,7 +174,8 @@ static void naive_brgemm_execute_int8(brgemm_params_t params, void *A, } } -static std::vector brgemm_list; +static std::mutex g_brgemm_mutex; +static std::vector g_brgemm_list; extern "C" { @@ -173,6 +183,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { + std::lock_guard g(g_brgemm_list); // simply store the given parameters for naive BRGEMM brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, stride_b, beta, dtypeA, dtypeB)); @@ -186,9 +197,27 @@ void dnnl_brgemm_tilerelease() { return; } void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) { - assert(kernel >= 0 && kernel < (int64_t)brgemm_list.size() && - "Invalid kernel handler"); - brgemm_params_t ¶ms = brgemm_list[kernel]; + brgemm_params_t params; + { + std::lock_guard g(g_brgemm_list); + assert(kernel >= 0 && kernel < (int64_t)g_brgemm_list.size() && + "Invalid kernel handler"); + params = brgemm_list[kernel]; + } + + if (params.beta == 0.0f) { + if ((params.dtypeA == static_cast(dnnl_f32) || + params.dtypeA == static_cast(dnnl_bf16)) && + (params.dtypeB == static_cast(dnnl_f32) || + params.dtypeB == static_cast(dnnl_bf16))) + naive_brgemm_init(params, C, C_offset); + else if ((params.dtypeA == static_cast(dnnl_s8) || + params.dtypeA == static_cast(dnnl_u8)) && + (params.dtypeB == static_cast(dnnl_s8) || + params.dtypeB == static_cast(dnnl_u8))) + naive_brgemm_init(params, C, C_offset); + } + if 
(params.dtypeA == static_cast(dnnl_f32) && params.dtypeB == static_cast(dnnl_f32)) { naive_brgemm_execute_fp32(params, A, A_offset, B, B_offset, C, C_offset, diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 96c95e960..56f6b099e 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -42,9 +43,10 @@ __attribute__((weak)) void print_verbose_header() {} } // namespace dnnl static constexpr int PALETTE_SIZE = 64; -static std::vector brgemm_desc_list; -static std::vector brgemm_kernel_list; -static std::vector brgemm_palette; +static std::mutex g_brgemm_mutex; +static std::vector g_brgemm_desc_list; +static std::vector g_brgemm_kernel_list; +static std::vector g_brgemm_palette; // TODO(haixin): use syscall to determine page size? static constexpr size_t SCRATCH_SIZE = 2 * 4096; @@ -57,11 +59,8 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { - brgemm_desc_list.emplace_back(); - brgemm_kernel_list.emplace_back(nullptr); - - brgemm_desc_t &desc = brgemm_desc_list.back(); - auto &kernel = brgemm_kernel_list.back(); + brgemm_desc_t desc; + brgemm_kernel_t *kernel; auto dnnl_dtypeA = static_cast(dtypeA); auto dnnl_dtypeB = static_cast(dtypeB); @@ -71,8 +70,9 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, dnnl::impl::status_t status = brgemm_desc_init( &desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, - dnnl_dtypeA, dnnl_dtypeB, false, false, brgemm_layout_t::brgemm_row_major, - 1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info); + dnnl_dtypeA, dnnl_dtypeB, /*transA=*/false, /*transB=*/false, + brgemm_layout_t::brgemm_row_major, 1.0f, beta, 
LDA, LDB, LDC, M, N, K, + &stride_info); assert(status == dnnl::impl::status::success && "Failed to initialize BRGEMM descriptor"); @@ -84,31 +84,37 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, brgemm_desc_set_attr(&desc, dnnl_attrs); // TODO(haixin): Reuse identical palettes across kernels + char *palette_buffer = nullptr; if (desc.is_tmm) { - brgemm_palette.push_back(new char[PALETTE_SIZE]); - dnnl::impl::status_t status = - brgemm_init_tiles(desc, brgemm_palette.back()); + palette_buffer = new char[PALETTE_SIZE]; + dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer); assert(status == dnnl::impl::status::success && "Failed to initialize palette for BRGEMM"); - } else { - brgemm_palette.push_back(nullptr); } - return brgemm_desc_list.size() - 1; + std::lock_guard g(g_brgemm_mutex); + g_brgemm_desc_list.push_back(desc); + g_brgemm_kernel_list.push_back(kernel); + g_brgemm_palette.push_back(palette_buffer); + + return g_brgemm_desc_list.size() - 1; } void dnnl_brgemm_tileconfig(int64_t kernel_idx) { - assert(kernel_idx >= 0 && kernel_idx < (int64_t)brgemm_desc_list.size() && - "Invalid kernel handler"); - - brgemm_desc_t &desc = brgemm_desc_list[kernel_idx]; - if (!desc.is_tmm) { - return; + char *palette_buffer = nullptr; + { + std::lock_guard g(g_brgemm_mutex); + assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && + "Invalid kernel handler"); + brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx]; + if (!desc.is_tmm) { + return; + } + palette_buffer = g_brgemm_palette[kernel_idx]; } - assert(brgemm_palette[kernel_idx] != nullptr && - "Invalid palette for BRGEMM kernel"); - amx_tile_configure(brgemm_palette[kernel_idx]); + assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel"); + amx_tile_configure(palette_buffer); } void dnnl_brgemm_tilerelease() { @@ -122,19 +128,24 @@ void dnnl_brgemm_tilerelease() { void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t 
A_offset, void *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) { - assert(kernel_idx >= 0 && kernel_idx < (int64_t)brgemm_desc_list.size() && - "Invalid kernel handler"); - - brgemm_desc_t &desc = brgemm_desc_list[kernel_idx]; - brgemm_kernel_t *kernel = brgemm_kernel_list[kernel_idx]; - - size_t A_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; - size_t B_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_b) * B_offset; - size_t C_offset_in_bytes = - dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; + brgemm_kernel_t *kernel = nullptr; + size_t A_offset_in_bytes; + size_t B_offset_in_bytes; + size_t C_offset_in_bytes; + { + std::lock_guard g(g_brgemm_mutex); + assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && + "Invalid kernel handler"); + + brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx]; + kernel = g_brgemm_kernel_list[kernel_idx]; + + A_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; + B_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_b) * B_offset; + C_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; + } + assert(kernel && "Invalid brgemm kernel pointer"); char *A_arith = (char *)A; char *B_arith = (char *)B; char *C_arith = (char *)C; From 10759184b4441d3a109ba3b89055f253358877b5 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 6 Aug 2024 02:04:22 -0700 Subject: [PATCH 33/93] merge main CMake change --- cmake/onednn_lite_config.cmake | 4 ++-- lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt | 10 +++++----- .../CPURuntime/Microkernel/BrgemmNaive.cpp | 13 +++++++------ lib/gc/Transforms/Microkernel/CMakeLists.txt | 11 +++++------ 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cmake/onednn_lite_config.cmake b/cmake/onednn_lite_config.cmake index 848c3c292..c4f68036a 100644 --- a/cmake/onednn_lite_config.cmake +++ b/cmake/onednn_lite_config.cmake @@ -47,7 +47,7 @@ endif() if(UNIX OR 
MINGW) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") endif() ########## from cmake/options.cmake @@ -349,5 +349,5 @@ set_property(TARGET dnnl_brgemm PROPERTY POSITION_INDEPENDENT_CODE ON # set_property(GLOBAL APPEND PROPERTY DNNL_SUBDIR_EXTRA_STATIC_LIBS $) # set_property(GLOBAL APPEND PROPERTY DNNL_SUBDIR_EXTRA_SHARED_LIBS dnnl_brgemm) # Currently build objs only -set_property(GLOBAL APPEND PROPERTY DNNL_LIB_DEPS +set_property(GLOBAL APPEND PROPERTY GC_DNNL_LIB_DEPS $) diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 6027b0e06..7d47bb825 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -5,7 +5,7 @@ file(GLOB_RECURSE MICROKERNEL_RUNTIME_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/*.c ) -if (GC_RUNTIME_NAIVE_BRGEMM) +if (GC_ENABLE_RUNTIME_NAIVE_BRGEMM) string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmOnednn.cpp" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") else() string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmNaive.cpp" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") @@ -13,10 +13,10 @@ endif() include(onednn) -get_property(DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) -get_property(DNNL_LIB_DEPS GLOBAL PROPERTY GC_DNNL_LIB_DEPS) +get_property(GC_DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) +get_property(GC_DNNL_LIB_DEPS GLOBAL PROPERTY GC_DNNL_LIB_DEPS) -include_directories(${DNNL_INCLUDES}) +include_directories(${GC_DNNL_INCLUDES}) gc_add_mlir_library(GcCpuRuntime SHARED @@ -27,7 +27,7 @@ gc_add_mlir_library(GcCpuRuntime dnnl_brgemm LINK_LIBS PRIVATE - ${DNNL_LIB_DEPS} + ${GC_DNNL_LIB_DEPS} LINK_LIBS PUBLIC GcInterface diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp 
b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 60e2c7a29..07c7be0c0 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -59,6 +59,7 @@ struct brgemm_params_t { int64_t stride_a, stride_b; float beta; int64_t dtypeA, dtypeB; + brgemm_params_t() {} brgemm_params_t(int64_t m, int64_t n, int64_t k, int64_t lda, int64_t ldb, int64_t ldc, int64_t sa, int64_t sb, float b, int64_t da, int64_t db) @@ -183,11 +184,11 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { - std::lock_guard g(g_brgemm_list); + std::lock_guard g(g_brgemm_mutex); // simply store the given parameters for naive BRGEMM - brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, - stride_b, beta, dtypeA, dtypeB)); - return brgemm_list.size() - 1; + g_brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, + stride_b, beta, dtypeA, dtypeB)); + return g_brgemm_list.size() - 1; } void dnnl_brgemm_tileconfig(int64_t kernel) { return; } @@ -199,10 +200,10 @@ void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, int num) { brgemm_params_t params; { - std::lock_guard g(g_brgemm_list); + std::lock_guard g(g_brgemm_mutex); assert(kernel >= 0 && kernel < (int64_t)g_brgemm_list.size() && "Invalid kernel handler"); - params = brgemm_list[kernel]; + params = g_brgemm_list[kernel]; } if (params.beta == 0.0f) { diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index af6fa30e0..8c571bfce 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -2,11 +2,11 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) include(onednn) -get_property(DNNL_INCLUDES GLOBAL PROPERTY DNNL_INCLUDES) 
+get_property(GC_DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) -include_directories(${DNNL_INCLUDES}) +include_directories(${GC_DNNL_INCLUDES}) -add_mlir_dialect_library(MLIRMicrokernelTransforms +gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp ConvertMicrokernelToDnnlFunc.cpp @@ -18,7 +18,6 @@ add_mlir_dialect_library(MLIRMicrokernelTransforms LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} - GCUtilsIR + GcInterface + GcUtilsIR ) - -set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS MLIRMicrokernelTransforms) From 181cbf09373b5058067aaab1331c1b23e5331cf8 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 7 Aug 2024 02:00:52 -0700 Subject: [PATCH 34/93] minor fixes & change GC_ENABLE_DNNL to GC_ENABLE_DNNL_API --- CMakeLists.txt | 6 ++-- include/gc/Dialect/OneDNNGraph/CMakeLists.txt | 2 +- include/gc/Transforms/CMakeLists.txt | 2 +- lib/gc/Dialect/OneDNNGraph/CMakeLists.txt | 2 +- lib/gc/Transforms/Microkernel/CMakeLists.txt | 8 ++--- .../ConvertLinalgToMicrokernel.cpp | 36 +++++++++---------- python/CMakeLists.txt | 4 +-- src/dnnl/CMakeLists.txt | 18 +++++----- test/dnnl/CMakeLists.txt | 6 ++-- test/mlir/test/lit.site.cfg.py.in | 2 +- 10 files changed, 42 insertions(+), 44 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ce02b7975..4afef5d7a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,9 +36,9 @@ endif() ############################ Build options ##################################### option(GC_ENABLE_LEGACY ON) -option(GC_ENABLE_DNNL "Enable the oneDNN library integration" ON) +option(GC_ENABLE_DNNL_API "Enable the oneDNN library API integration" ON) option(GC_ENABLE_TEST "Build the tests" ON) -option(GC_ENABLE_TEST_DNNL "Build the dnnl tests" ${GC_ENABLE_DNNL}) +option(GC_ENABLE_TEST_DNNL "Build the dnnl tests" ${GC_ENABLE_DNNL_API}) option(GC_ENABLE_TEST_MLIR "Build the mlir tests" ON) option(GC_ENABLE_TOOLS "Build the tools" ON) option(GC_ENABLE_OPT "Build gc-opt" ${GC_ENABLE_TOOLS}) @@ -51,7 +51,7 
@@ if(GC_ENABLE_LEGACY) add_subdirectory(legacy/core) endif() -if(GC_ENABLE_DNNL) +if(GC_ENABLE_DNNL_API) set(GC_ONEDNN_DIALECT_LIB_NAME MLIROneDNNGraph) endif() ################################################################################ diff --git a/include/gc/Dialect/OneDNNGraph/CMakeLists.txt b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt index 1b1f00222..08bad8334 100644 --- a/include/gc/Dialect/OneDNNGraph/CMakeLists.txt +++ b/include/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -1,4 +1,4 @@ -if (NOT GC_ENABLE_DNNL) +if (NOT GC_ENABLE_DNNL_API) message(STATUS "OneDNNGraphDialect is not enabled.") return() endif () diff --git a/include/gc/Transforms/CMakeLists.txt b/include/gc/Transforms/CMakeLists.txt index 6d9ec1ac6..08443020b 100644 --- a/include/gc/Transforms/CMakeLists.txt +++ b/include/gc/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GC_ENABLE_DNNL) +if(GC_ENABLE_DNNL_API) list(APPEND TABLEGEN_MACROS -DGC_HAS_ONEDNN_DIALECT) endif() if(GC_ENABLE_IMEX) diff --git a/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt index 56bb58f95..2f6434453 100644 --- a/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt +++ b/lib/gc/Dialect/OneDNNGraph/CMakeLists.txt @@ -1,4 +1,4 @@ -if (NOT GC_ENABLE_DNNL) +if (NOT GC_ENABLE_DNNL_API) return() endif () diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index 8c571bfce..e33db7185 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -2,10 +2,6 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) include(onednn) -get_property(GC_DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) - -include_directories(${GC_DNNL_INCLUDES}) - gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp ConvertMicrokernelToDnnlFunc.cpp @@ -21,3 +17,7 @@ gc_add_mlir_dialect_library(MLIRMicrokernelTransforms GcInterface GcUtilsIR ) + +get_property(GC_DNNL_INCLUDES GLOBAL PROPERTY 
GC_DNNL_INCLUDES) + +target_include_directories(MLIRMicrokernelTransforms PUBLIC ${GC_DNNL_INCLUDES}) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index dec7c4577..ba66bf895 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -53,7 +53,7 @@ FailureOr customInferContractionDims(linalg::LinalgOp linalgOp) { auto dims = linalg::inferContractionDims(linalgOp); if (failed(dims)) - return dims; + return failure(); if (llvm::isa(linalgOp)) { // For VnniOp, the K reduction dims (dim index 3 & 4) cannot be infered by // linalg utils because they form complex affine in operand A; Manually add @@ -77,24 +77,22 @@ static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, auto secondKPosDim = getAffineDimExpr(dimPos[1], linalgOp.getContext()); // An K affine result for VNNI should be this format: // d{kPos[0]} * s{kPos[1]} + d{kPos[1]} (k0 * K_vnni + k1) - if (auto add = dyn_cast(expr)) { - if (add.getKind() == AffineExprKind::Add) { - auto lhs = add.getLHS(); - auto rhs = add.getRHS(); - if (rhs == secondKPosDim) { - auto mul = dyn_cast(lhs); - if (mul && mul.getKind() == AffineExprKind::Mul && - mul.getLHS() == firstDim) { - if (auto cst_affine = dyn_cast(mul.getRHS())) { - if (cst_affine.getValue() == 2 || cst_affine.getValue() == 4) { - return true; - } - } - } - } - } - } - return false; + auto add = dyn_cast(expr); + if (!add) + return false; + if (add.getKind() != AffineExprKind::Add) + return false; + auto lhs = add.getLHS(); + auto rhs = add.getRHS(); + if (rhs != secondKPosDim) + return false; + auto mul = dyn_cast(lhs); + if (!mul || mul.getKind() != AffineExprKind::Mul || mul.getLHS() != firstDim) + return false; + auto cst_affine = dyn_cast(mul.getRHS()); + if (!cst_affine || (cst_affine.getValue() != 2 && cst_affine.getValue() != 4)) + return false; + return 
true; } // Return the position of `dim` in the codomain of `operand`. diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 68d522d26..02e3626a7 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -56,7 +56,7 @@ declare_mlir_python_sources(GcPythonSources.Common # Dialect bindings ################################################################################ -if(GC_ENABLE_DNNL) +if(GC_ENABLE_DNNL_API) declare_mlir_dialect_python_bindings( ADD_TO_PARENT GcPythonSources ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gc_mlir" @@ -114,4 +114,4 @@ add_mlir_python_modules(GcPythonModules MLIRPythonSources COMMON_CAPI_LINK_LIBS GcPythonCAPI - ) \ No newline at end of file + ) diff --git a/src/dnnl/CMakeLists.txt b/src/dnnl/CMakeLists.txt index 69e8cb29f..375077006 100644 --- a/src/dnnl/CMakeLists.txt +++ b/src/dnnl/CMakeLists.txt @@ -13,7 +13,7 @@ # and limitations under the License. # SPDX-License-Identifier: Apache-2.0 -if(NOT GC_ENABLE_DNNL) +if(NOT GC_ENABLE_DNNL_API) message(STATUS "oneDNN library integration is not enabled.") return() endif() @@ -28,21 +28,21 @@ set(GC_DNNL_LINKED_LIBS GcJitWrapper GcCpuRuntime ) -gc_add_mlir_library(GcDnnl SHARED +gc_add_mlir_library(GcDnnlApi SHARED ${GC_DNNL_SOURCES} LINK_LIBS PRIVATE ${GC_DNNL_LINKED_LIBS} ) -target_link_libraries(GcDnnl PUBLIC GcInterface) -target_include_directories(GcDnnl PUBLIC ${GC_DNNL_INCLUDES}) -target_compile_options(GcDnnl PRIVATE -fvisibility=hidden -fexceptions) -target_link_options(GcDnnl PRIVATE -Wl,--gc-sections) +target_link_libraries(GcDnnlApi PUBLIC GcInterface) +target_include_directories(GcDnnlApi PUBLIC ${GC_DNNL_INCLUDES}) +target_compile_options(GcDnnlApi PRIVATE -fvisibility=hidden -fexceptions) +target_link_options(GcDnnlApi PRIVATE -Wl,--gc-sections) if(GC_ENABLE_TEST_DNNL) # Static graph compiler library to be used in tests - gc_add_mlir_library(GcDnnlStatic STATIC + gc_add_mlir_library(GcDnnlApiStatic STATIC ${GC_DNNL_SOURCES} LINK_LIBS PUBLIC GcInterface 
${GC_DNNL_LINKED_LIBS} ) - target_compile_options(obj.GcDnnlStatic PUBLIC -fexceptions) - target_include_directories(GcDnnlStatic PUBLIC ${GC_DNNL_INCLUDES}) + target_compile_options(obj.GcDnnlApiStatic PUBLIC -fexceptions) + target_include_directories(GcDnnlApiStatic PUBLIC ${GC_DNNL_INCLUDES}) endif() diff --git a/test/dnnl/CMakeLists.txt b/test/dnnl/CMakeLists.txt index 18d7d421a..6166fda21 100644 --- a/test/dnnl/CMakeLists.txt +++ b/test/dnnl/CMakeLists.txt @@ -1,4 +1,4 @@ -if (NOT GC_ENABLE_TEST OR NOT GC_ENABLE_TEST_DNNL OR NOT GC_ENABLE_DNNL) +if (NOT GC_ENABLE_TEST OR NOT GC_ENABLE_TEST_DNNL OR NOT GC_ENABLE_DNNL_API) message(STATUS "The dnnl tests are not enabled.") return() endif () @@ -15,10 +15,10 @@ foreach (TEST_SOURCE ${TEST_SOURCES}) add_executable(${TEST_NAME} ${TEST_SOURCE}) if (${TEST_NAME} MATCHES "^TestApi.*") # The API tests are linked with the shared lib - target_link_libraries(${TEST_NAME} PRIVATE LLVMSupport GcDnnl) + target_link_libraries(${TEST_NAME} PRIVATE LLVMSupport GcDnnlApi) else () # The other tests are linked with the static lib and have non-public includes - target_link_libraries(${TEST_NAME} PRIVATE GcDnnlStatic) + target_link_libraries(${TEST_NAME} PRIVATE GcDnnlApiStatic) target_include_directories(${TEST_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/src/dnnl) endif () target_link_libraries(${TEST_NAME} PRIVATE gtest gtest_main) diff --git a/test/mlir/test/lit.site.cfg.py.in b/test/mlir/test/lit.site.cfg.py.in index 37c0d01d2..0c189d743 100644 --- a/test/mlir/test/lit.site.cfg.py.in +++ b/test/mlir/test/lit.site.cfg.py.in @@ -43,7 +43,7 @@ config.mlir_runner_utils = os.path.normpath(os.path.join(config.mlir_runner_util config.mlir_c_runner_utils = os.path.normpath(os.path.join(config.mlir_runner_utils_dir, config.shlib_prefix + "mlir_c_runner_utils" + config.llvm_shlib_ext)) config.opencl_runtime = os.path.normpath(os.path.join(config.gc_lib_dir, config.shlib_prefix + "GcOpenclRuntime" + config.llvm_shlib_ext)) -config.gc_use_dnnl = 
"@GC_ENABLE_DNNL@" in ["ON", "1"] +config.gc_use_dnnl = "@GC_ENABLE_DNNL_API@" in ["ON", "1"] import lit.llvm lit.llvm.initialize(lit_config, config) From 9e69362581bafaa96530253b374d9ca2861493fb Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 00:53:35 -0700 Subject: [PATCH 35/93] remove comments in cmake --- cmake/onednn_lite_config.cmake | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cmake/onednn_lite_config.cmake b/cmake/onednn_lite_config.cmake index c4f68036a..74e8548a2 100644 --- a/cmake/onednn_lite_config.cmake +++ b/cmake/onednn_lite_config.cmake @@ -340,14 +340,6 @@ set_property(TARGET dnnl_brgemm PROPERTY POSITION_INDEPENDENT_CODE ON CXX_VISIBILITY_PRESET "hidden" VISIBILITY_INLINES_HIDDEN 1) -# install(TARGETS dnnl_brgemm -# EXPORT dnnl_brgemm_export -# RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} -# LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} -# ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) - -# set_property(GLOBAL APPEND PROPERTY DNNL_SUBDIR_EXTRA_STATIC_LIBS $) -# set_property(GLOBAL APPEND PROPERTY DNNL_SUBDIR_EXTRA_SHARED_LIBS dnnl_brgemm) # Currently build objs only set_property(GLOBAL APPEND PROPERTY GC_DNNL_LIB_DEPS $) From 9163da339864d8415e1b037eb81e69b5478061c0 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 02:05:08 -0700 Subject: [PATCH 36/93] add runtime brgemm entry point --- .../CPURuntime/Microkernel/BrgemmInterface.h | 59 +++++++++++++++++++ .../CPURuntime/Microkernel/BrgemmNaive.cpp | 2 + .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 2 + 3 files changed, 63 insertions(+) create mode 100644 include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h diff --git a/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h b/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h new file mode 100644 index 000000000..ce2f19085 --- /dev/null +++ b/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h @@ -0,0 +1,59 @@ +//===-- BrgemmInterface.h - The 
interfaces of runtime Brgemm ----*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_EXECUTIONENGINE_CPURUNTIME_MICROKERNEL_BRGEMMINTERFACE_H +#define GC_EXECUTIONENGINE_CPURUNTIME_MICROKERNEL_BRGEMMINTERFACE_H + +extern "C" { + +/** + * Dispatch (JIT) the Brgemm kernel based on given parameters using DNNL + * Inputs: + * M, N, K: The size of Brgemm dims, given in element size; + * LDA, LDB, LDC: The stride of leading dim of + * each Brgemm matrix, given in element size; + * stride_a, stride_b: The stride of batch of Brgemm + * input A & B, given in element size; + * dtypeA, dtypeB: The dtype of Brgemm input A and B, + * given in dnnl type value. + * Output: A handle of dispatched kernel. + */ +int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, + int64_t LDB, int64_t LDC, int64_t stride_a, + int64_t stride_b, float beta, int64_t dtypeA, + int64_t dtypeB); + +/** + * Config the AMX tile context for given kernel. + * Inputs: A handle of dispatched kernel. + * Output: None. + */ +void dnnl_brgemm_tileconfig(int64_t kernel); + +/** + * Release the current AMX tile context. + * Inputs: None. + * Output: None. + */ +void dnnl_brgemm_tilerelease(); + +/** + * Execute the given kernel with given parameters. + * Inputs: + * kernel: A handle of dispatched kernel; + * A, A_offset, B, B_offset, C, C_offset: + * Pointers and starting offset of each Brgemm matrix; + * num: Batch size of Brgemm. + * Output: None. 
+ */ +void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, + uint64_t B_offset, void *C, uint64_t C_offset, + int num); +} + +#endif diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 07c7be0c0..75fddea27 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -16,6 +16,8 @@ #include "oneapi/dnnl/dnnl_types.h" +#include "gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h" + namespace { struct bf16_t { diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 56f6b099e..f37be394e 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -29,6 +29,8 @@ #include #include +#include "gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h" + using namespace dnnl::impl::cpu::x64; namespace dnnl { From 29525c910729e58bdf09340eb0ba140c072047ee Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 02:05:43 -0700 Subject: [PATCH 37/93] change cmake option name GC_ENABLE_TEST_DNNL to GC_ENABLE_TEST_DNNL_API --- CMakeLists.txt | 2 +- src/dnnl/CMakeLists.txt | 2 +- test/dnnl/CMakeLists.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4afef5d7a..4b3da8f7b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ endif() option(GC_ENABLE_LEGACY ON) option(GC_ENABLE_DNNL_API "Enable the oneDNN library API integration" ON) option(GC_ENABLE_TEST "Build the tests" ON) -option(GC_ENABLE_TEST_DNNL "Build the dnnl tests" ${GC_ENABLE_DNNL_API}) +option(GC_ENABLE_TEST_DNNL_API "Build the dnnl tests" ${GC_ENABLE_DNNL_API}) option(GC_ENABLE_TEST_MLIR "Build the mlir tests" ON) option(GC_ENABLE_TOOLS "Build the 
tools" ON) option(GC_ENABLE_OPT "Build gc-opt" ${GC_ENABLE_TOOLS}) diff --git a/src/dnnl/CMakeLists.txt b/src/dnnl/CMakeLists.txt index 375077006..1b4af9fca 100644 --- a/src/dnnl/CMakeLists.txt +++ b/src/dnnl/CMakeLists.txt @@ -37,7 +37,7 @@ target_include_directories(GcDnnlApi PUBLIC ${GC_DNNL_INCLUDES}) target_compile_options(GcDnnlApi PRIVATE -fvisibility=hidden -fexceptions) target_link_options(GcDnnlApi PRIVATE -Wl,--gc-sections) -if(GC_ENABLE_TEST_DNNL) +if(GC_ENABLE_TEST_DNNL_API) # Static graph compiler library to be used in tests gc_add_mlir_library(GcDnnlApiStatic STATIC ${GC_DNNL_SOURCES} diff --git a/test/dnnl/CMakeLists.txt b/test/dnnl/CMakeLists.txt index 6166fda21..8a6b537e2 100644 --- a/test/dnnl/CMakeLists.txt +++ b/test/dnnl/CMakeLists.txt @@ -1,4 +1,4 @@ -if (NOT GC_ENABLE_TEST OR NOT GC_ENABLE_TEST_DNNL OR NOT GC_ENABLE_DNNL_API) +if (NOT GC_ENABLE_TEST OR NOT GC_ENABLE_TEST_DNNL_API OR NOT GC_ENABLE_DNNL_API) message(STATUS "The dnnl tests are not enabled.") return() endif () From cb9ac0924a232ae1e98fc5f2e7bce92944fa5db1 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 02:16:17 -0700 Subject: [PATCH 38/93] use smart ptr to manage palette buffer --- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index f37be394e..cddec46a8 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -48,7 +49,7 @@ static constexpr int PALETTE_SIZE = 64; static std::mutex g_brgemm_mutex; static std::vector g_brgemm_desc_list; static std::vector g_brgemm_kernel_list; -static std::vector g_brgemm_palette; +static std::vector> g_brgemm_palette; // TODO(haixin): use syscall to determine page size? 
static constexpr size_t SCRATCH_SIZE = 2 * 4096; @@ -97,7 +98,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, std::lock_guard g(g_brgemm_mutex); g_brgemm_desc_list.push_back(desc); g_brgemm_kernel_list.push_back(kernel); - g_brgemm_palette.push_back(palette_buffer); + g_brgemm_palette.emplace_back(palette_buffer); return g_brgemm_desc_list.size() - 1; } @@ -112,7 +113,7 @@ void dnnl_brgemm_tileconfig(int64_t kernel_idx) { if (!desc.is_tmm) { return; } - palette_buffer = g_brgemm_palette[kernel_idx]; + palette_buffer = g_brgemm_palette[kernel_idx].get(); } assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel"); From 51caf8dcbdba26e0013b7aa53bf166d566263c3a Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 22:51:54 -0700 Subject: [PATCH 39/93] fix clang format --- .../ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp | 6 +++--- .../Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index cddec46a8..eb9bc6ba6 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -95,7 +95,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, "Failed to initialize palette for BRGEMM"); } - std::lock_guard g(g_brgemm_mutex); + std::lock_guard g(g_brgemm_mutex); g_brgemm_desc_list.push_back(desc); g_brgemm_kernel_list.push_back(kernel); g_brgemm_palette.emplace_back(palette_buffer); @@ -106,7 +106,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, void dnnl_brgemm_tileconfig(int64_t kernel_idx) { char *palette_buffer = nullptr; { - std::lock_guard g(g_brgemm_mutex); + std::lock_guard g(g_brgemm_mutex); assert(kernel_idx >= 0 && kernel_idx < 
(int64_t)g_brgemm_desc_list.size() && "Invalid kernel handler"); brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx]; @@ -136,7 +136,7 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, size_t B_offset_in_bytes; size_t C_offset_in_bytes; { - std::lock_guard g(g_brgemm_mutex); + std::lock_guard g(g_brgemm_mutex); assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && "Invalid kernel handler"); diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index ba66bf895..45380a19e 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -89,10 +89,10 @@ static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, auto mul = dyn_cast(lhs); if (!mul || mul.getKind() != AffineExprKind::Mul || mul.getLHS() != firstDim) return false; + auto cst_affine = dyn_cast(mul.getRHS()); - if (!cst_affine || (cst_affine.getValue() != 2 && cst_affine.getValue() != 4)) - return false; - return true; + return cst_affine && + (cst_affine.getValue() == 2 || cst_affine.getValue() == 4); } // Return the position of `dim` in the codomain of `operand`. 
From 44937e4e76b85977bf3f0753e20894fb4e597365 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 23:24:47 -0700 Subject: [PATCH 40/93] Remove pass ConvertMicrokernelToDnnlFunc --- .../Microkernel/MicrokernelPasses.td | 11 - include/gc/Transforms/Utils/ValueUtils.h | 3 +- lib/gc/Transforms/Microkernel/CMakeLists.txt | 1 - .../ConvertMicrokernelToDnnlFunc.cpp | 222 ------------------ lib/gc/Transforms/Utils/ValueUtils.cpp | 29 --- .../Microkernel/microkernel-to-dnnl-func.mlir | 70 ------ .../test/gc/cpu-runner/brgemm-parallel.mlir | 50 ---- 7 files changed, 2 insertions(+), 384 deletions(-) delete mode 100644 lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp delete mode 100644 test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir delete mode 100644 test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index bf9e3c61d..16e11532b 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -40,15 +40,4 @@ def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::f "microkernel::MicrokernelDialect"]; } -def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::mlir::ModuleOp"> { - let summary = "Lower microkernel dialects to dnnl func call"; - let description = [{ - Convert microkernel dialects to runtime function call to oneDNN library. 
- }]; - let dependentDialects = ["func::FuncDialect", - "memref::MemRefDialect", - "LLVM::LLVMDialect", - "microkernel::MicrokernelDialect"]; -} - #endif // GC_DIALECT_MICROKERNELPASSES diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index 07013bde4..acffd5642 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -27,7 +27,8 @@ FailureOr> getStaticStrides(Value val); // Return the offset and ptr for `val`. Assert if `val` // is not a memref. -std::pair getPtrAndOffset(OpBuilder &builder, Value operand); +std::pair getPtrAndOffset(OpBuilder &builder, Value val, + Location loc); } // namespace utils } // namespace mlir diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index e33db7185..642eaa6ca 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -4,7 +4,6 @@ include(onednn) gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp - ConvertMicrokernelToDnnlFunc.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp deleted file mode 100644 index 966f87c06..000000000 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ /dev/null @@ -1,222 +0,0 @@ -//===-- ConvertMicrokernelToDnnlFunc.cpp - Lower to dnnl funcs --*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" -#include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Transforms/Utils/ValueUtils.h" - -namespace mlir::microkernel { -#define GEN_PASS_DEF_CONVERTMICROKERNELTODNNLFUNC -#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" - -#define DEBUG_TYPE "convert-microkernel-to-dnnl-func" - -static func::CallOp createFuncCall(RewriterBase &rewriter, Location loc, - ModuleOp module, const std::string &funcName, - ArrayRef operands, - ArrayRef operandTypes, - ArrayRef resultTypes) { - FlatSymbolRefAttr fnName = SymbolRefAttr::get(module->getContext(), funcName); - auto fnType = rewriter.getFunctionType(operandTypes, resultTypes); - - if (!module.lookupSymbol(fnName.getAttr())) { - OpBuilder::InsertionGuard guard(rewriter); - // Insert before module terminator. 
- rewriter.setInsertionPoint(module.getBody(), - std::prev(module.getBody()->end())); - func::FuncOp funcOp = - rewriter.create(loc, fnName.getValue(), fnType); - funcOp.setPrivate(); - } - - func::CallOp call = rewriter.create(loc, fnName.getValue(), - resultTypes, operands); - return call; -} - -class ConvertBrgemmDispatchOpRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - // runtime func for dnnl brgemm dispatch: - // int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, - // int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, - // int64_t dtypeA, int64_t dtypeB); - LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); - ModuleOp module = op->template getParentOfType(); - - SmallVector operands; - SmallVector operandTypes; - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - FloatType float32 = FloatType::getF32(rewriter.getContext()); - - // M, N, K, LDA, LDB, LDC, stride_a, stride_b - // they are in the same order with BrgemmDispatchOp inputs - ArrayRef inputs = op.getInputsAttr().asArrayRef(); - for (auto input : inputs) { - auto attr = IntegerAttr::get(rewriter.getI64Type(), input); - operands.push_back( - rewriter.create(loc, integer64, attr)); - operandTypes.push_back(integer64); - } - - // beta - auto flags = op.getFlagsAttr(); - float beta = 1.0f; - for (auto flag : flags) { - auto brgemmFlag = dyn_cast_or_null(flag); - if (!brgemmFlag) - return rewriter.notifyMatchFailure(op, "unknown flag for BRGEMM"); - if (brgemmFlag.getValue() == BrgemmFlags::LIST) - return rewriter.notifyMatchFailure( - op, "addr mode BRGEMM not supported yet"); - if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) - beta = 0.0f; - } - auto betaAttr = FloatAttr::get(rewriter.getF32Type(), beta); - operands.push_back( - rewriter.create(loc, float32, betaAttr)); - operandTypes.push_back(float32); - 
- // dtypeA, dtypeB - auto dtypes = op.getDataType(); - if (dtypes.size() != 2) - return rewriter.notifyMatchFailure( - op, "invalid number of DataType for BRGEMM"); - auto dtypeAAttr = IntegerAttr::get(rewriter.getI64Type(), - getDnnlDataTypeVal(rewriter, dtypes[0])); - auto dtypeBAttr = IntegerAttr::get(rewriter.getI64Type(), - getDnnlDataTypeVal(rewriter, dtypes[1])); - operands.push_back( - rewriter.create(loc, integer64, dtypeAAttr)); - operandTypes.push_back(integer64); - operands.push_back( - rewriter.create(loc, integer64, dtypeBAttr)); - operandTypes.push_back(integer64); - - func::CallOp call = - createFuncCall(rewriter, loc, module, DNNL_BRGEMM_DISPATCH_NAME, - operands, operandTypes, {integer64}); - rewriter.replaceOp(op, call.getResult(0)); - return success(); - } -}; - -class ConvertBrgemmPrologueOpRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - // dnnl runtime func for brgemm set hw context: - // void dnnl_brgemm_tileconfig(int64_t kernel_idx); - LogicalResult matchAndRewrite(microkernel::BrgemmPrologueOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); - ModuleOp module = op->template getParentOfType(); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - func::CallOp call = - createFuncCall(rewriter, loc, module, DNNL_BRGEMM_TILECFG_NAME, - op.getInputs(), {integer64}, {}); - rewriter.replaceOp(op, call); - return success(); - } -}; - -class ConvertBrgemmOpRewriter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - // runtime func for stride mode dnnl brgemm execution: - // void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void - // *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) - LogicalResult matchAndRewrite(microkernel::BrgemmOp op, - PatternRewriter &rewriter) const final { - // currently only support stride mode, directly call it - // TODO(haixin): support addr mode execution, through 
detecting dispatch - // target - - auto context = rewriter.getContext(); - Location loc = op.getLoc(); - ModuleOp module = op->template getParentOfType(); - - SmallVector operands; - SmallVector operandTypes; - - auto raw_operands = op->getOperands(); - size_t raw_op_cnt = 0; - for (Value operand : raw_operands) { - if (raw_op_cnt++ >= 5) { - // drop the last operand for `addr list length` - break; - } - Type operandType = operand.getType(); - if (auto memrefType = dyn_cast(operandType)) { - Type basePtrType = LLVM::LLVMPointerType::get(context); - auto [ptr, offset] = utils::getPtrAndOffset(rewriter, operand); - operands.push_back(ptr); - operands.push_back(offset); - operandTypes.push_back(basePtrType); - operandTypes.push_back(rewriter.getIndexType()); // offset - } else { - operands.push_back(operand); - operandTypes.push_back(operand.getType()); - } - } - - createFuncCall(rewriter, loc, module, DNNL_BRGEMM_EXECUTE_NAME, operands, - operandTypes, {}); - rewriter.eraseOp(op); - return success(); - } -}; - -class ConvertBrgemmEpilogueOpRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - // dnnl runtime func for brgemm release hw context: - // void dnnl_brgemm_tilerelease(); - LogicalResult matchAndRewrite(microkernel::BrgemmEpilogueOp op, - PatternRewriter &rewriter) const final { - Location loc = op.getLoc(); - ModuleOp module = op->template getParentOfType(); - func::CallOp call = createFuncCall( - rewriter, loc, module, DNNL_BRGEMM_TILERELEASE_NAME, {}, {}, {}); - rewriter.replaceOp(op, call); - return success(); - } -}; - -class ConvertMicrokernelToDnnlFunc - : public impl::ConvertMicrokernelToDnnlFuncBase< - ConvertMicrokernelToDnnlFunc> { -public: - using impl::ConvertMicrokernelToDnnlFuncBase< - ConvertMicrokernelToDnnlFunc>::ConvertMicrokernelToDnnlFuncBase; - void runOnOperation() final { - RewritePatternSet patterns(&getContext()); - patterns - .add( - &getContext()); - - FrozenRewritePatternSet 
patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) - signalPassFailure(); - } -}; - -} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 8750042ee..d565d0cf8 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -116,34 +116,5 @@ FailureOr> getStaticStrides(Value value) { return strides; } -std::pair getPtrAndOffset(OpBuilder &builder, Value operand) { - auto memrefType = dyn_cast(operand.getType()); - assert(memrefType && "Expect a memref value"); - - Location loc = operand.getDefiningOp()->getLoc(); - OpBuilder::InsertionGuard guard(builder); - // Insert right after operand producer for better opt chances. - builder.setInsertionPointAfterValue(operand); - - MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); - Type basePtrType = builder.getIndexType(); - Type offsetType = builder.getIndexType(); - SmallVector sizesTypes(memrefType.getRank(), offsetType); - SmallVector stridesTypes(memrefType.getRank(), offsetType); - auto meta = builder.create( - loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); - Value alignedPointerAsIndex = - builder.create(loc, basePtrType, - operand); - Value alignedPointerAsI64 = builder.create( - loc, builder.getIntegerType(64), alignedPointerAsIndex); - // TODO: non-POD will require an LLVMTypeConverter. 
- Value alignedPointer = builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), - alignedPointerAsI64); - Value offset = meta.getOffset(); - return std::make_pair(alignedPointer, offset); -} - } // namespace utils } // namespace mlir diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir deleted file mode 100644 index 1520ae069..000000000 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: gc-opt %s -convert-microkernel-to-dnnl-func -cse -split-input-file | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0, d1)> -module { - func.func @basic_convert() { - %c0_i64 = arith.constant 0 : i64 - %c16_i64 = arith.constant 16 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg0, %arg1) in (4, 8) { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 
1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_7: f32, %out: f32): - %1 = arith.addf %in, %in_7 : f32 - linalg.yield %1 : f32 - } - %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %1 = arith.maximumf %in, %cst : f32 - linalg.yield %1 : f32 - } - memref.dealloc %alloc_3 : memref<32x32xf32> - } - return - } -} - -// CHECK-LABEL: dnnl_brgemm_execute -// CHECK-LABEL: dnnl_brgemm_dispatch -// CHECK-LABEL: basic_convert -// CHECK: %[[CST3:.+]] = arith.constant 3 : i64 -// CHECK: %[[CST1F:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[CST1024:.+]] = arith.constant 1024 : i64 -// CHECK: %[[CST32:.+]] = arith.constant 32 : i64 -// CHECK: %[[CST0:.+]] = arith.constant 0 : index -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 - -// CHECK: %[[ptrC:.+]] = memref.extract_aligned_pointer_as_index %[[memrefC:.+]] : memref<32x32xf32> -> index -// CHECK-NEXT: %[[idxC:.+]] = arith.index_cast %[[ptrC]] : index to i64 -// CHECK-NEXT: %[[llvmptrC:.+]] = llvm.inttoptr %[[idxC]] : i64 to !llvm.ptr - -// CHECK: %[[bbA:.+]], %[[offA:.+]], %[[szA:.+]]:3, %[[strdA:.+]]:3 = memref.extract_strided_metadata %[[memrefA:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, 
index, index, index, index, index, index -// CHECK-NEXT: %[[ptrA:.+]] = memref.extract_aligned_pointer_as_index %[[memrefA]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index -// CHECK-NEXT: %[[idxA:.+]] = arith.index_cast %[[ptrA]] : index to i64 -// CHECK-NEXT: %[[llvmptrA:.+]] = llvm.inttoptr %[[idxA]] : i64 to !llvm.ptr - -// CHECK: %[[bbB:.+]], %[[offB:.+]], %[[szB:.+]]:3, %[[strdB:.+]]:3 = memref.extract_strided_metadata %[[memrefB:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index -// CHECK-NEXT: %[[ptrB:.+]] = memref.extract_aligned_pointer_as_index %[[memrefB]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index -// CHECK-NEXT: %[[idxB:.+]] = arith.index_cast %[[ptrB]] : index to i64 -// CHECK-NEXT: %[[llvmptrB:.+]] = llvm.inttoptr %[[idxB]] : i64 to !llvm.ptr - -// CHECK: %[[KERNEL:.+]] = func.call @dnnl_brgemm_dispatch(%[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST1024]], %[[CST1024]], %[[CST1F]], %[[CST3]], %[[CST3]]) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 -// CHECK-NOT: microkernel.brgemm.prologue(%[[TMP:.+]]) : (i64) -> () - -// CHECK: func.call @dnnl_brgemm_execute(%[[KERNEL]], %[[llvmptrA]], %[[offA]], %[[llvmptrB]], %[[offB]], %[[llvmptrC]], %[[CST0]], %[[CST16]]) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () -// CHECK-NOT: microkernel.brgemm.epilogue(%[[KERNEL]]) : (i64) -> () - -// ----- diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir deleted file mode 100644 index ad436da0c..000000000 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm 
--convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void - -#map = affine_map<(d0, d1) -> (d0, d1)> -module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c16_i64 = arith.constant 16 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) - %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) - scf.forall (%arg0, %arg1) in (4, 8) { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to 
memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_7: f32, %out: f32): - %1 = arith.addf %in, %in_7 : f32 - linalg.yield %1 : f32 - } - %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %1 = arith.maximumf %in, %cst : f32 - linalg.yield %1 : f32 - } - memref.dealloc %alloc_3 : memref<32x32xf32> - } - return - } - - func.func @main() { - call @simple_brgemm() : ()->() - // COM: parallelcpu.printf "BRGEMM DONE\n" - return - } - - // COM: CHECK: BRGEMM DONE -} From c01ad17bf2c03e80fdbc984ce5b0223a64f87b10 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 23:42:10 -0700 Subject: [PATCH 41/93] remove pass ConvertLinalgToMicrokernel --- include/gc/Transforms/CMakeLists.txt | 2 - .../Microkernel/BrgemmRuntimeUtils.h | 51 --- .../gc/Transforms/Microkernel/CMakeLists.txt | 6 - .../Microkernel/MicrokernelPasses.h | 28 -- .../Microkernel/MicrokernelPasses.td | 43 -- .../gc/Transforms/Utils/StructuredOpMatcher.h | 20 - lib/gc/Dialect/Microkernel/CMakeLists.txt | 1 - lib/gc/Transforms/CMakeLists.txt | 2 - lib/gc/Transforms/Microkernel/CMakeLists.txt | 22 - .../ConvertLinalgToMicrokernel.cpp | 388 ------------------ src/gc-opt/gc-opt.cpp | 2 - .../Microkernel/linalg-to-microkernel.mlir | 173 -------- 12 files changed, 738 deletions(-) delete mode 100644 include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h delete mode 100644 include/gc/Transforms/Microkernel/CMakeLists.txt delete mode 100644 
include/gc/Transforms/Microkernel/MicrokernelPasses.h delete mode 100644 include/gc/Transforms/Microkernel/MicrokernelPasses.td delete mode 100644 lib/gc/Transforms/Microkernel/CMakeLists.txt delete mode 100644 lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp delete mode 100644 test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir diff --git a/include/gc/Transforms/CMakeLists.txt b/include/gc/Transforms/CMakeLists.txt index 08443020b..8014cba72 100644 --- a/include/gc/Transforms/CMakeLists.txt +++ b/include/gc/Transforms/CMakeLists.txt @@ -11,5 +11,3 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GraphCompiler) mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GraphCompiler) add_public_tablegen_target(GraphCompilerPassIncGen) add_mlir_doc(Passes GraphCompilerPasses ./ -gen-pass-doc) - -add_subdirectory(Microkernel) diff --git a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h deleted file mode 100644 index adb214e10..000000000 --- a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h +++ /dev/null @@ -1,51 +0,0 @@ -//===-- BrgemmRuntimeUtils.h - Utils for Brgemm Runtime ---------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H -#define GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/PatternMatch.h" -#include "oneapi/dnnl/dnnl_types.h" - -namespace mlir::microkernel { - -// these strings contain symbols for BRGEMM interfaces used in mlir pass -static const std::string DNNL_BRGEMM_DISPATCH_NAME = "dnnl_brgemm_dispatch"; -static const std::string DNNL_BRGEMM_TILECFG_NAME = "dnnl_brgemm_tileconfig"; -static const std::string DNNL_BRGEMM_TILERELEASE_NAME = - "dnnl_brgemm_tilerelease"; -static const std::string DNNL_BRGEMM_EXECUTE_NAME = "dnnl_brgemm_execute"; - -static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter, - Attribute attr) { - auto context = rewriter.getContext(); - auto tattr = dyn_cast_or_null(attr); - assert(tattr); - if (tattr == TypeAttr::get(FloatType::getF32(context))) { - return static_cast(dnnl_f32); - } else if (tattr == TypeAttr::get(FloatType::getBF16(context))) { - return static_cast(dnnl_bf16); - } else if (tattr == TypeAttr::get(FloatType::getF16(context))) { - return static_cast(dnnl_f16); - } else if (tattr == TypeAttr::get( - IntegerType::get(context, 32, IntegerType::Signed))) { - return static_cast(dnnl_s32); - } else if (tattr == - TypeAttr::get(IntegerType::get(context, 8, IntegerType::Signed))) { - return static_cast(dnnl_s8); - } else if (tattr == TypeAttr::get(IntegerType::get(context, 8, - IntegerType::Unsigned))) { - return static_cast(dnnl_u8); - } - return static_cast(dnnl_data_type_undef); -} - -}; // namespace mlir::microkernel - -#endif // GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H diff --git a/include/gc/Transforms/Microkernel/CMakeLists.txt b/include/gc/Transforms/Microkernel/CMakeLists.txt deleted file mode 100644 index 2e345775c..000000000 --- a/include/gc/Transforms/Microkernel/CMakeLists.txt +++ /dev/null @@ 
-1,6 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS MicrokernelPasses.td) -mlir_tablegen(MicrokernelPasses.h.inc --gen-pass-decls -name Microkernel) -mlir_tablegen(MicrokernelPasses.capi.h.inc -gen-pass-capi-header --prefix Microkernel) -mlir_tablegen(MicrokernelPasses.capi.cpp.inc -gen-pass-capi-impl --prefix Microkernel) -add_public_tablegen_target(MLIRMicrokernelPassesIncGen) -add_mlir_doc(MicrokernelPasses GraphCompilerMicrokernelPasses ./ -gen-pass-doc) diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.h b/include/gc/Transforms/Microkernel/MicrokernelPasses.h deleted file mode 100644 index ee9da8a4e..000000000 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.h +++ /dev/null @@ -1,28 +0,0 @@ -//===- MicrokernelPasses.h - Graph Compiler microkerenl passes --*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef GC_MICROKERNELPASSES_H -#define GC_MICROKERNELPASSES_H - -#include "gc/Dialect/Linalgx/LinalgxDialect.h" -#include "gc/Dialect/Microkernel/MicrokernelDialect.h" -#include "gc/Dialect/Microkernel/MicrokernelOps.h" -#include "mlir/Pass/Pass.h" -#include - -namespace mlir { -namespace microkernel { -#define GEN_PASS_DECL -#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" - -#define GEN_PASS_REGISTRATION -#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" -} // namespace microkernel -} // namespace mlir - -#endif // GC_MICROKERNELPASSES_H diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td deleted file mode 100644 index 16e11532b..000000000 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ /dev/null @@ -1,43 +0,0 @@ -//===-- MicrokernelPasses.td - microkernel passes ----------*- 
tablegen -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef GC_DIALECT_MICROKERNELPASSES -#define GC_DIALECT_MICROKERNELPASSES - -include "mlir/Pass/PassBase.td" - -def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::func::FuncOp"> { - let summary = "Lower eligible linalg ops to microkernels"; - let description = [{ - Convert eligible linalg ops to microkernel dialects based on pattern matching. - For example: - ``` - scf.forall { - linalg.fill ins(...) outs(...) -> tensor<...> - linalg.batch_reduce_matmul ins(...) outs(...) -> tensor<...> - } - ``` - Will be changed into - ``` - scf.forall { - linalg.fill ins(...) outs(...) -> tensor<...> - %0 = microkernel.brgemm.dispatch(...) - microkernel.brgemm.prologue(%0) - microkernel.brgemm(%0, ...) - microkernel.brgemm.epilogue(%0) - } - ``` - }]; - let dependentDialects = ["func::FuncDialect", - "memref::MemRefDialect", - "linalg::LinalgDialect", - "linalgx::LinalgxDialect", - "microkernel::MicrokernelDialect"]; -} - -#endif // GC_DIALECT_MICROKERNELPASSES diff --git a/include/gc/Transforms/Utils/StructuredOpMatcher.h b/include/gc/Transforms/Utils/StructuredOpMatcher.h index c931472a1..66bd22b7a 100644 --- a/include/gc/Transforms/Utils/StructuredOpMatcher.h +++ b/include/gc/Transforms/Utils/StructuredOpMatcher.h @@ -217,26 +217,6 @@ template struct EqualsTo { }; template EqualsTo(T) -> EqualsTo; -// Callable object to check if the input is less than or equal to specified -// `value`. 
-struct LessThanOrEqualTo { - LessThanOrEqualTo() = delete; - explicit LessThanOrEqualTo(size_t value) : value(value){}; - const size_t value; - - bool operator()(size_t value) const { return value <= this->value; } -}; - -// Callable object to check if the input is greater than or equal to specified -// `value`. -struct GreaterThanOrEqualTo { - GreaterThanOrEqualTo() = delete; - explicit GreaterThanOrEqualTo(size_t value) : value(value){}; - const size_t value; - - bool operator()(size_t value) const { return value >= this->value; } -}; - // Callable object to validate number of init operands for `op`. struct NumDpsInits { NumDpsInits() = delete; diff --git a/lib/gc/Dialect/Microkernel/CMakeLists.txt b/lib/gc/Dialect/Microkernel/CMakeLists.txt index 3a9099d4f..33e420f2e 100644 --- a/lib/gc/Dialect/Microkernel/CMakeLists.txt +++ b/lib/gc/Dialect/Microkernel/CMakeLists.txt @@ -10,7 +10,6 @@ gc_add_mlir_dialect_library(MLIRMicrokernel DEPENDS MLIRMicrokernelOpsIncGen - MLIRMicrokernelPassesIncGen LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index a5388381b..d240f28c1 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -31,5 +31,3 @@ gc_add_mlir_library(GcPasses if(GC_ENABLE_IMEX) add_subdirectory(GPU) endif() - -add_subdirectory(Microkernel) diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt deleted file mode 100644 index 642eaa6ca..000000000 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) - -include(onednn) - -gc_add_mlir_dialect_library(MLIRMicrokernelTransforms - ConvertLinalgToMicrokernel.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/ - - DEPENDS - MLIRMicrokernelPassesIncGen - - LINK_LIBS PUBLIC - ${MLIR_LINK_COMPONENTS} - GcInterface - GcUtilsIR - ) - -get_property(GC_DNNL_INCLUDES GLOBAL 
PROPERTY GC_DNNL_INCLUDES) - -target_include_directories(MLIRMicrokernelTransforms PUBLIC ${GC_DNNL_INCLUDES}) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp deleted file mode 100644 index 45380a19e..000000000 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ /dev/null @@ -1,388 +0,0 @@ -//===-- ConvertLinalgToMicrokernel.cpp - Linalg To Microkernel --*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/AffineExprVisitor.h" -#include "mlir/IR/AffineMap.h" -#include "llvm/ADT/SetOperations.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallVector.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "gc/Dialect/Linalgx/LinalgxOps.h" -#include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Transforms/Utils/StructuredOpMatcher.h" -#include "gc/Transforms/Utils/ValueUtils.h" - -namespace mlir::microkernel { -#define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL -#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" - -#define DEBUG_TYPE "convert-linalg-to-microkernel" - -struct BrgemmInfo { - enum BrgemmMode { STRIDE_MODE, LIST_MODE }; - int64_t m; - int64_t n; - int64_t k; - int64_t batchSize; - int64_t addrLen; - - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t strideA; - int64_t strideB; - - bool isInitOutput; - BrgemmMode mode; -}; - -FailureOr 
-customInferContractionDims(linalg::LinalgOp linalgOp) { - auto dims = linalg::inferContractionDims(linalgOp); - if (failed(dims)) - return failure(); - if (llvm::isa(linalgOp)) { - // For VnniOp, the K reduction dims (dim index 3 & 4) cannot be infered by - // linalg utils because they form complex affine in operand A; Manually add - // them here - dims->k.push_back(3); - dims->k.push_back(4); - } - return dims; -} - -static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, - ArrayRef dimPos) { - if (dimPos.size() > 2) { - return false; - } - auto firstDim = getAffineDimExpr(dimPos[0], linalgOp.getContext()); - if (dimPos.size() == 1) - return firstDim == expr; - - // If not regular dim affine, check for VNNI format K affine - auto secondKPosDim = getAffineDimExpr(dimPos[1], linalgOp.getContext()); - // An K affine result for VNNI should be this format: - // d{kPos[0]} * s{kPos[1]} + d{kPos[1]} (k0 * K_vnni + k1) - auto add = dyn_cast(expr); - if (!add) - return false; - if (add.getKind() != AffineExprKind::Add) - return false; - auto lhs = add.getLHS(); - auto rhs = add.getRHS(); - if (rhs != secondKPosDim) - return false; - auto mul = dyn_cast(lhs); - if (!mul || mul.getKind() != AffineExprKind::Mul || mul.getLHS() != firstDim) - return false; - - auto cst_affine = dyn_cast(mul.getRHS()); - return cst_affine && - (cst_affine.getValue() == 2 || cst_affine.getValue() == 4); -} - -// Return the position of `dim` in the codomain of `operand`. 
-static std::optional getPosInCodomain(ArrayRef dimPos, - OpOperand *operand, - linalg::LinalgOp linalgOp) { - assert(operand->getOwner() == linalgOp); - auto map = linalgOp.getMatchingIndexingMap(operand); - for (unsigned i = 0, numResults = map.getNumResults(); i < numResults; i++) { - if (isMatchingAffineResult(linalgOp, map.getResult(i), dimPos)) - return i; - } - return std::nullopt; -} - -static FailureOr -inferBrgemmInfo(linalg::LinalgOp linalgOp, - const linalg::ContractionDimensions &dims) { - unsigned mPos = dims.m[0]; - unsigned nPos = dims.n[0]; - // dims.k could be of 2 cases: - // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] - // 2. dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] - unsigned batchPos = dims.k.front(); - SmallVector kPos; - if (dims.k.size() == 2) { - kPos = {dims.k[1]}; - } else if (dims.k.size() == 3) { - kPos = {dims.k[1], dims.k[2]}; - } else { - return failure(); - } - - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] m pos in affine: " << mPos - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n pos in affine: " << nPos - << "\n"); - for (auto kp : kPos) { - LLVM_DEBUG(llvm::dbgs() - << "[inferBrgemmInfo] k pos in affine: " << kp << "\n"); - } - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch pos in affine: " - << batchPos << "\n"); - - auto checkStridesAndGetLda = - [&](ArrayRef minorDim, ArrayRef majorDim, - OpOperand *operand, bool allowVnni) -> FailureOr { - auto minorDimPosInCodomain = getPosInCodomain(minorDim, operand, linalgOp); - auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp); - if (!minorDimPosInCodomain || !majorDimPosInCodomain) - return failure(); - auto stridesOnOperand = utils::getStaticStrides(operand->get()); - if (failed(stridesOnOperand)) - return failure(); - auto minorDimLd = (*stridesOnOperand)[*minorDimPosInCodomain]; - auto majorDimLd = (*stridesOnOperand)[*majorDimPosInCodomain]; - if 
(minorDimLd != 1) { - // VNNI format exists, special treatment to align LD with non-VNNI format - if (!allowVnni || (minorDimLd != 2 && minorDimLd != 4)) - return failure(); - return majorDimLd / minorDimLd; - } - return majorDimLd; - }; - - OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; - OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; - - // A(m, k) - auto lda = checkStridesAndGetLda(kPos, {mPos}, operandA, false); - if (failed(lda)) - return failure(); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on A: OK\n"); - - // B(k, n) - // note: B does not use VNNI format K affine - auto ldb = checkStridesAndGetLda({nPos}, {kPos[0]}, operandB, true); - if (failed(ldb)) - return failure(); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on B: OK\n"); - - // C(m, n) - auto ldc = checkStridesAndGetLda({nPos}, {mPos}, operandC, false); - if (failed(ldc)) - return failure(); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on C: OK\n"); - - int64_t strideA = 1; - int64_t strideB = 1; - auto batchPosCodomainA = getPosInCodomain(batchPos, operandA, linalgOp); - auto stridesOnA = utils::getStaticStrides(operandA->get()); - strideA = (*stridesOnA)[*batchPosCodomainA]; - - auto batchPosCodomainB = getPosInCodomain(batchPos, operandB, linalgOp); - auto stridesOnB = utils::getStaticStrides(operandB->get()); - strideB = (*stridesOnB)[*batchPosCodomainB]; - - auto loops = linalgOp.computeStaticLoopSizes(); - auto kSize = - kPos.size() == 1 ? 
loops[kPos[0]] : (loops[kPos[0]] * loops[kPos[1]]); - - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" - << loops[mPos] << "), n(" << loops[nPos] << "), k(" - << kSize << "), batch(" << loops[batchPos] - << "), lda(" << *lda << "), ldb(" << *ldb << "), ldc(" - << *ldc << "), strideA(" << strideA << "), strideB(" - << strideB << ")\n"); - BrgemmInfo info{loops[mPos], - loops[nPos], - kSize, - loops[batchPos], - 0 /* addrLen useless under stride mode */, - *lda, - *ldb, - *ldc, - strideA, - strideB, - false, - BrgemmInfo::STRIDE_MODE}; - return info; -} - -static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { - using namespace mlir::structured_match; - auto validBrgemmMatcher = StructuredOpMatcher::make() - .output(MatchAll(), HasStaticShape()) - .input(MatchAll(), HasStaticShape()) - .output(MatchAll(), HasStaticStrides()) - .input(MatchAll(), HasStaticStrides()) - .operation(NumOfLoops(GreaterThanOrEqualTo(3))); - // clang-format on - if (!validBrgemmMatcher.match(linalgOp)) - return failure(); - - auto contractionDims = customInferContractionDims(linalgOp); - if (failed(contractionDims)) { - LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Not a valid contraction\n"); - return failure(); - } - if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 || - // batch-reduce dim for BRGEMM should be identified as one of k dim - // including VNNI & non-VNNI cases - (contractionDims->k.size() != 2 && contractionDims->k.size() != 3) || - !contractionDims->batch.empty()) { - LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n"); - LLVM_DEBUG(llvm::dbgs() - << "[checkStructure] " << contractionDims->m.size() << " " - << contractionDims->n.size() << " " << contractionDims->k.size() - << " " << contractionDims->batch.size() << "\n"); - return failure(); - } - unsigned classifiedLoops = - contractionDims->m.size() + contractionDims->n.size() + - contractionDims->k.size() + contractionDims->batch.size(); - if 
(linalgOp.getNumLoops() != classifiedLoops) { - LLVM_DEBUG(llvm::dbgs() - << "[checkStructure] Not all loops are classified\n"); - return failure(); - } - - return inferBrgemmInfo(linalgOp, *contractionDims); -} - -// Replace linalgOp with a set of microkernel ops -static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, - linalg::LinalgOp linalgOp, - const BrgemmInfo &info) { - assert(linalgOp.getDpsInputs().size() == 2); - OpBuilder::InsertionGuard guard(rewriter); - - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - Location loc = linalgOp.getLoc(); - SmallVector brgemmFlags; - if (info.isInitOutput) { - brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( - rewriter.getContext(), microkernel::BrgemmFlags::BETA_0)); - } - if (info.mode == BrgemmInfo::STRIDE_MODE) { - brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( - rewriter.getContext(), microkernel::BrgemmFlags::STRIDE)); - } else if (info.mode == BrgemmInfo::LIST_MODE) { - brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( - rewriter.getContext(), microkernel::BrgemmFlags::LIST)); - } - - SmallVector brgemmDtypes{ - TypeAttr::get(getElementTypeOrSelf(linalgOp.getDpsInputs()[0].getType())), - TypeAttr::get( - getElementTypeOrSelf(linalgOp.getDpsInputs()[1].getType()))}; - - // create dispatch op - auto flags = rewriter.getArrayAttr(brgemmFlags); - auto dtypes = rewriter.getArrayAttr(brgemmDtypes); - DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( - rewriter.getContext(), - ArrayRef{info.m, info.n, info.k, info.lda, info.ldb, info.ldc, - info.strideA, info.strideB}); - Value dispatched = rewriter.create( - loc, integer64, dims, flags, dtypes); - - // create prologue op - rewriter.create(loc, dispatched); - - // create brgemm invoke op - Value batchDim = rewriter.create( - loc, integer64, rewriter.getIntegerAttr(integer64, info.batchSize)); - Value lenDim = rewriter.create( - loc, integer64, rewriter.getIntegerAttr(integer64, info.addrLen)); - SmallVector 
invokeOperands; - invokeOperands.push_back(dispatched); - invokeOperands.append(linalgOp->getOperands().begin(), - linalgOp->getOperands().end()); - invokeOperands.push_back(batchDim); - invokeOperands.push_back(lenDim); - rewriter.create(loc, invokeOperands); - - // create epilogue op & replace original op - rewriter.replaceOpWithNewOp(linalgOp, - dispatched); -} - -bool isZeroArithConstant(arith::ConstantOp op) { - if (!op) - return false; - - if (auto intAttr = llvm::dyn_cast(op.getValue())) { - if (intAttr.getInt() != 0) - return false; - } else if (auto floatAttr = llvm::dyn_cast(op.getValue())) { - if (!floatAttr.getValue().isZero()) - return false; - } else - return false; - - return true; -} - -template -class ConvertContractionOpToBrgemmRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ContractionOp op, - PatternRewriter &rewriter) const final { - auto brgemmInfo = getBrgemmInfo(op); - if (failed(brgemmInfo)) - return failure(); - // Check for immediately preceding linalg::FillOp - Operation *rawOp = op; - auto block = rawOp->getBlock(); - auto opIter = Block::iterator(rawOp); - if (block->begin() != opIter) { - auto prevOp = &(*(--opIter)); - if (auto fillOp = dyn_cast(prevOp)) { - auto inputCst = dyn_cast_or_null( - fillOp.getInputs()[0].getDefiningOp()); - auto fillOperand = fillOp.getOutputs()[0]; - auto contractionOperand = op.getOutputs()[0]; - if (isZeroArithConstant(inputCst) && - contractionOperand == fillOperand) { - brgemmInfo->isInitOutput = true; - rewriter.eraseOp(prevOp); - } - } - } - replaceOpWithMicrokernelOpSet(rewriter, op, *brgemmInfo); - return success(); - } -}; - -class ConvertLinalgToMicrokernel - : public impl::ConvertLinalgToMicrokernelBase { -public: - using impl::ConvertLinalgToMicrokernelBase< - ConvertLinalgToMicrokernel>::ConvertLinalgToMicrokernelBase; - void runOnOperation() final { - RewritePatternSet patterns(&getContext()); - patterns - .add>( - 
&getContext()); - patterns.add< - ConvertContractionOpToBrgemmRewriter>( - &getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) - signalPassFailure(); - } -}; - -} // namespace mlir::microkernel diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 89bacbdd8..5754f0269 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -23,7 +23,6 @@ #ifdef GC_HAS_ONEDNN_DIALECT #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif -#include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -52,7 +51,6 @@ int main(int argc, char *argv[]) { mlir::gc::registerCPUPipeline(); mlir::gc::registerGraphCompilerPasses(); mlir::cpuruntime::registerCPURuntimePasses(); - mlir::microkernel::registerMicrokernelPasses(); mlir::DialectRegistry registry; #ifdef GC_HAS_ONEDNN_DIALECT diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir deleted file mode 100644 index 224329f62..000000000 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ /dev/null @@ -1,173 +0,0 @@ -// RUN: gc-opt %s -convert-linalg-to-microkernel -split-input-file | FileCheck %s - -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @basic_linalg_to_microkernel() { - %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] 
: memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> - } - return -} - -// CHECK-LABEL: basic_linalg_to_microkernel -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : 
memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () - -// ----- - -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @vnni_linalg_to_microkernel() { - %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} 
ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> - } - return -} - -// CHECK-LABEL: vnni_linalg_to_microkernel -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () - -// ----- - -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @basic_linalg_to_microkernel_fusing_fill() { 
- %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) - linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - 
linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> - } - return -} - -// CHECK-LABEL: basic_linalg_to_microkernel_fusing_fill -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK-NOT: linalg.fill -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (f32, f32) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () - -// ----- - -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @vnni_linalg_to_microkernel_fusing_fill() { - %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_11 = 
memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) - linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> - } - return -} - -// CHECK-LABEL: vnni_linalg_to_microkernel_fusing_fill -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 
1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -// CHECK-NOT: linalg.fill -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () - -// ----- From 59a63665b34d9aeb2817a3fb4a4419e50741eeda Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 8 Aug 2024 23:44:47 -0700 Subject: [PATCH 42/93] add cmake error message --- cmake/onednn_lite_config.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/onednn_lite_config.cmake b/cmake/onednn_lite_config.cmake index 74e8548a2..d1df62836 100644 --- a/cmake/onednn_lite_config.cmake +++ b/cmake/onednn_lite_config.cmake @@ -3,7 +3,7 @@ include_guard() get_property(DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) get_property(DNNL_PATH GLOBAL PROPERTY GC_DNNL_SOURCE_DIR) if (NOT DEFINED DNNL_INCLUDES) - return() + message(FATAL_ERROR "DNNL contents not fetched yet, CMake will exit." 
) endif () ########## This cmake build lite version of onednn, containing only microkernel related codes From 18ff855176087c93f62f96053777fae391468cfc Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 9 Aug 2024 00:44:11 -0700 Subject: [PATCH 43/93] use rw lock --- .../CPURuntime/Microkernel/BrgemmNaive.cpp | 10 ++++++---- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 14 +++++++++----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 75fddea27..c4aafa544 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -8,7 +8,7 @@ #include #include -#include +#include #include #include #include @@ -177,7 +177,9 @@ static void naive_brgemm_execute_int8(const brgemm_params_t &params, void *A, } } -static std::mutex g_brgemm_mutex; +using read_lock_guard_t = std::shared_lock<std::shared_mutex>; +using write_lock_guard_t = std::unique_lock<std::shared_mutex>; +static std::shared_mutex g_brgemm_lock; static std::vector g_brgemm_list; extern "C" { @@ -186,7 +188,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, int64_t dtypeA, int64_t dtypeB) { - std::lock_guard g(g_brgemm_mutex); + write_lock_guard_t g(g_brgemm_lock); // simply store the given parameters for naive BRGEMM g_brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, stride_b, beta, dtypeA, dtypeB)); @@ -202,7 +204,7 @@ void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, int num) { brgemm_params_t params; { - std::lock_guard g(g_brgemm_mutex); + read_lock_guard_t g(g_brgemm_lock); assert(kernel >= 0 && kernel < (int64_t)g_brgemm_list.size() && "Invalid kernel handler"); params = g_brgemm_list[kernel]; diff --git 
a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index eb9bc6ba6..751ead17f 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -46,7 +46,11 @@ __attribute__((weak)) void print_verbose_header() {} } // namespace dnnl static constexpr int PALETTE_SIZE = 64; -static std::mutex g_brgemm_mutex; + +using read_lock_guard_t = std::shared_lock; +using write_lock_guard_t = std::unique_lock; +static std::shared_mutex g_brgemm_lock; + static std::vector g_brgemm_desc_list; static std::vector g_brgemm_kernel_list; static std::vector> g_brgemm_palette; @@ -95,7 +99,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, "Failed to initialize palette for BRGEMM"); } - std::lock_guard g(g_brgemm_mutex); + write_lock_guard_t g(g_brgemm_lock); g_brgemm_desc_list.push_back(desc); g_brgemm_kernel_list.push_back(kernel); g_brgemm_palette.emplace_back(palette_buffer); @@ -106,7 +110,7 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, void dnnl_brgemm_tileconfig(int64_t kernel_idx) { char *palette_buffer = nullptr; { - std::lock_guard g(g_brgemm_mutex); + read_lock_guard_t g(g_brgemm_lock); assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && "Invalid kernel handler"); brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx]; @@ -136,7 +140,7 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, size_t B_offset_in_bytes; size_t C_offset_in_bytes; { - std::lock_guard g(g_brgemm_mutex); + read_lock_guard_t g(g_brgemm_lock); assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && "Invalid kernel handler"); From d7e1509769c95afdb54ecb51dd18310c0977fca9 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 9 Aug 
2024 01:03:26 -0700 Subject: [PATCH 44/93] fix naive lock --- .../ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index c4aafa544..7ed688917 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -177,8 +178,8 @@ static void naive_brgemm_execute_int8(const brgemm_params_t ¶ms, void *A, } } -using read_lock_gurad_t = std::shared_lock; -using write_lock_gurad_t = std::unique_lock; +using read_lock_guard_t = std::shared_lock; +using write_lock_guard_t = std::unique_lock; static std::shared_mutex g_brgemm_lock; static std::vector g_brgemm_list; From 8e631c743a9f9cc607a0ed6cbfec9279e852885a Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 13 Aug 2024 20:45:45 -0700 Subject: [PATCH 45/93] add ut for brgemm runtime --- .../CPURuntime/Microkernel/BrgemmInterface.h | 49 +++++ .../ExecutionEngine/CPURuntime/CMakeLists.txt | 8 +- .../CPURuntime/Microkernel/BrgemmNaive.cpp | 80 +++----- .../CPURuntime/Microkernel/BrgemmOnednn.cpp | 4 + .../ExecutionEngine/BrgemmRuntime.cpp | 182 ++++++++++++++++++ .../unittests/ExecutionEngine/CMakeLists.txt | 2 + 6 files changed, 272 insertions(+), 53 deletions(-) create mode 100644 test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp diff --git a/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h b/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h index ce2f19085..cb2e08093 100644 --- a/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h +++ b/include/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h @@ -9,7 +9,10 @@ #ifndef GC_EXECUTIONENGINE_CPURUNTIME_MICROKERNEL_BRGEMMINTERFACE_H #define 
GC_EXECUTIONENGINE_CPURUNTIME_MICROKERNEL_BRGEMMINTERFACE_H +#include + extern "C" { +// Runtime interfaces /** * Dispatch (JIT) the Brgemm kernel based on given parameters using DNNL @@ -56,4 +59,50 @@ void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, int num); } +struct bf16_t { + uint16_t storage_; + union caster_t { + uint32_t vl; + float vf; + }; + operator float() const { + caster_t val; + val.vl = uint32_t(storage_) << 16; + return val.vf; + } + bool operator==(const bf16_t &compare_to) const { + return storage_ == compare_to.storage_; + } + bool operator!=(const bf16_t &compare_to) const { + return storage_ != compare_to.storage_; + } + bf16_t(float v) { + if (std::isnan(v)) { + storage_ = UINT32_C(0x7FC0); + } else { + caster_t caster; + caster.vf = v; + uint32_t rounding_bias = ((caster.vl >> 16) & 1) + UINT32_C(0x7FFF); + storage_ = static_cast((caster.vl + rounding_bias) >> 16); + } + } + bf16_t() : storage_(0) {} + inline static bf16_t from_storage(uint16_t v) { + bf16_t ret; + ret.storage_ = v; + return ret; + } +}; + +// Naive implementation of `dnnl_brgemm_dispatch` +int64_t dnnl_brgemm_dispatch_naive(int64_t M, int64_t N, int64_t K, int64_t LDA, + int64_t LDB, int64_t LDC, int64_t stride_a, + int64_t stride_b, float beta, int64_t dtypeA, + int64_t dtypeB); + +// Naive implementation of `dnnl_brgemm_execute` +void dnnl_brgemm_execute_naive(int64_t kernel, void *A, uint64_t A_offset, + void *B, uint64_t B_offset, void *C, + uint64_t C_offset, int num); + #endif diff --git a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt index 7d47bb825..68f597d3b 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CPURuntime/CMakeLists.txt @@ -6,9 +6,7 @@ file(GLOB_RECURSE MICROKERNEL_RUNTIME_SOURCES ) if (GC_ENABLE_RUNTIME_NAIVE_BRGEMM) - string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmOnednn.cpp" "" MICROKERNEL_RUNTIME_SOURCES 
"${MICROKERNEL_RUNTIME_SOURCES}") -else() - string(REPLACE "${CMAKE_CURRENT_SOURCE_DIR}/Microkernel/BrgemmNaive.cpp" "" MICROKERNEL_RUNTIME_SOURCES "${MICROKERNEL_RUNTIME_SOURCES}") + add_definitions("-DGC_ENABLE_RUNTIME_NAIVE_BRGEMM=1") endif() include(onednn) @@ -16,8 +14,6 @@ include(onednn) get_property(GC_DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) get_property(GC_DNNL_LIB_DEPS GLOBAL PROPERTY GC_DNNL_LIB_DEPS) -include_directories(${GC_DNNL_INCLUDES}) - gc_add_mlir_library(GcCpuRuntime SHARED Parallel.cpp @@ -35,6 +31,8 @@ gc_add_mlir_library(GcCpuRuntime EXCLUDE_FROM_LIBMLIR ) +target_include_directories(GcCpuRuntime PRIVATE ${GC_DNNL_INCLUDES}) + if ("iomp" IN_LIST OpenMP_C_LIB_NAMES OR "omp" IN_LIST OpenMP_C_LIB_NAMES OR "omp5" IN_LIST OpenMP_C_LIB_NAMES) else() target_compile_options(GcCpuRuntime PRIVATE "-DGC_NEEDS_OMP_WRAPPER") diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 7ed688917..55eb9a50f 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -21,41 +21,6 @@ namespace { -struct bf16_t { - uint16_t storage_; - union caster_t { - uint32_t vl; - float vf; - }; - operator float() const { - caster_t val; - val.vl = uint32_t(storage_) << 16; - return val.vf; - } - bool operator==(const bf16_t &compare_to) const { - return storage_ == compare_to.storage_; - } - bool operator!=(const bf16_t &compare_to) const { - return storage_ != compare_to.storage_; - } - bf16_t(float v) { - if (std::isnan(v)) { - storage_ = UINT32_C(0x7FC0); - } else { - caster_t caster; - caster.vf = v; - uint32_t rounding_bias = ((caster.vl >> 16) & 1) + UINT32_C(0x7FFF); - storage_ = static_cast((caster.vl + rounding_bias) >> 16); - } - } - bf16_t() : storage_(0) {} - inline static bf16_t from_storage(uint16_t v) { - bf16_t ret; - ret.storage_ = v; - return ret; - } -}; - struct 
brgemm_params_t { int64_t M, N, K; int64_t LDA, LDB, LDC; @@ -183,12 +148,10 @@ using write_lock_guard_t = std::unique_lock; static std::shared_mutex g_brgemm_lock; static std::vector g_brgemm_list; -extern "C" { - -int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, - int64_t LDB, int64_t LDC, int64_t stride_a, - int64_t stride_b, float beta, int64_t dtypeA, - int64_t dtypeB) { +int64_t dnnl_brgemm_dispatch_naive(int64_t M, int64_t N, int64_t K, int64_t LDA, + int64_t LDB, int64_t LDC, int64_t stride_a, + int64_t stride_b, float beta, int64_t dtypeA, + int64_t dtypeB) { write_lock_guard_t g(g_brgemm_lock); // simply store the given parameters for naive BRGEMM g_brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, @@ -196,13 +159,9 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, return g_brgemm_list.size() - 1; } -void dnnl_brgemm_tileconfig(int64_t kernel) { return; } - -void dnnl_brgemm_tilerelease() { return; } - -void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, - uint64_t B_offset, void *C, uint64_t C_offset, - int num) { +void dnnl_brgemm_execute_naive(int64_t kernel, void *A, uint64_t A_offset, + void *B, uint64_t B_offset, void *C, + uint64_t C_offset, int num) { brgemm_params_t params; { read_lock_guard_t g(g_brgemm_lock); @@ -252,4 +211,29 @@ void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, assert(false && "unsupported input dtypes"); } } + +#if defined(GC_ENABLE_RUNTIME_NAIVE_BRGEMM) + +extern "C" { + +int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, + int64_t LDB, int64_t LDC, int64_t stride_a, + int64_t stride_b, float beta, int64_t dtypeA, + int64_t dtypeB) { + return dnnl_brgemm_dispatch_naive(M, N, K, LDA, LDB, LDC, stride_a, stride_b, + beta, dtypeA, dtypeB); +} + +void dnnl_brgemm_tileconfig(int64_t kernel) { return; } + +void dnnl_brgemm_tilerelease() { return; } + +void 
dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void *B, + uint64_t B_offset, void *C, uint64_t C_offset, + int num) { + return dnnl_brgemm_execute_naive(kernel, A, A_offset, B, B_offset, C, + C_offset, num); } +} + +#endif diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp index 751ead17f..d7dd2ccf9 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp @@ -32,6 +32,8 @@ #include "gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h" +#if !defined(GC_ENABLE_RUNTIME_NAIVE_BRGEMM) + using namespace dnnl::impl::cpu::x64; namespace dnnl { @@ -161,3 +163,5 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, (void *)(C_arith + C_offset_in_bytes), (void *)scratch); } } + +#endif diff --git a/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp b/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp new file mode 100644 index 000000000..b6abc1f3f --- /dev/null +++ b/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp @@ -0,0 +1,182 @@ +//===-- BrgemmRuntime.cpp - Brgemm Runtime ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gtest/gtest.h" +#include +#include + +#include "gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmInterface.h" + +extern "C" { +extern int gc_runtime_keep_alive; +} + +template inline T randomFP(int range = 1, float delta = 0.0f) { + float fraction = (float)(rand()) / (float)(RAND_MAX); + return T(fraction * range - delta); +} + +template +inline void randomInitWithFP(T *buffer, size_t size, int range = 1, + float delta = 0.0f) { + for (size_t index = 0; index < size; index++) + buffer[index] = randomFP(range, delta); +} + +template +inline bool compareDataFP(T *ref, T *dst, size_t size, float rtol = 1e-4f, + float atol = 1e-6f) { + for (size_t index = 0; index < size; index++) { + const float ref_f32 = static_cast(ref[index]); + const float dst_f32 = static_cast(dst[index]); + const double diff_f32 = dst_f32 - ref_f32; + const double gap = double(rtol) * (std::abs(ref_f32) > std::abs(dst_f32) + ? 
std::abs(ref_f32) + : std::abs(dst_f32)) + + atol; + bool good = std::abs(diff_f32) <= gap; + EXPECT_TRUE(good) << "Index: " << index << ", ref_f32=" << ref_f32 + << ", dst_f32=" << dst_f32; + if (!good) + return false; + } + return true; +} + +template +inline void testBrgemmRuntimeFP(int batch, int M, int N, int K, int LDA, + int LDB, int LDC, int strideA, int strideB, + float beta) { + using dnnl_f32_enum_val_t = std::integral_constant; + using dnnl_bf16_enum_val_t = std::integral_constant; + constexpr int dtypeA = + std::conditional::value, dnnl_f32_enum_val_t, + dnnl_bf16_enum_val_t>::type::value; + constexpr int dtypeB = dtypeA; + + T A[batch * M * K]; + T B[batch * K * N]; + float refC[M * N]; + + randomInitWithFP(A, batch * M * K, 10, 10.0f); + randomInitWithFP(B, batch * K * N, 10, 10.0f); + randomInitWithFP(refC, M * N, 10, 10.0f); + + float dstC[M * N]; + memcpy(dstC, refC, sizeof(float) * M * N); + + // Calculate reference + auto refHandle = dnnl_brgemm_dispatch_naive(M, N, K, LDA, LDB, LDC, strideA, + strideB, beta, dtypeA, dtypeB); + dnnl_brgemm_execute_naive(refHandle, A, 0, B, 0, refC, 0, batch); + + // Calculate destination + auto dstHandle = dnnl_brgemm_dispatch(M, N, K, LDA, LDB, LDC, strideA, + strideB, beta, dtypeA, dtypeB); + dnnl_brgemm_tileconfig(dstHandle); + dnnl_brgemm_execute(dstHandle, A, 0, B, 0, dstC, 0, batch); + dnnl_brgemm_tilerelease(); + + ASSERT_TRUE(compareDataFP(refC, dstC, M * N)); +} + +template +inline void randomInitWithInt(T *buffer, size_t size, int range, + int delta = 0) { + for (size_t index = 0; index < size; index++) + buffer[index] = rand() % range - delta; +} + +template inline bool compareDataInt(T *ref, T *dst, size_t size) { + for (size_t index = 0; index < size; index++) { + bool good = ref[index] == dst[index]; + EXPECT_TRUE(good) << "Index: " << index << ", ref=" << ref[index] + << ", dst=" << dst[index]; + if (!good) + return false; + } + return true; +} + +inline void testBrgemmRuntimeInt(int batch, int M, 
int N, int K, int LDA, + int LDB, int LDC, int strideA, int strideB, + float beta) { + constexpr int dtypeA = 6; // dnnl_u8 enum val + constexpr int dtypeB = 5; // dnnl_s8 enum val + + uint8_t A[batch * M * K]; + int8_t B[batch * K * N]; + int32_t refC[M * N]; + + randomInitWithInt(A, batch * M * K, 100, 100); + randomInitWithInt(B, batch * K * N, 100, 100); + randomInitWithInt(refC, M * N, 500, 500); + + int32_t dstC[M * N]; + memcpy(dstC, refC, sizeof(int32_t) * M * N); + + // Calculate reference + auto refHandle = dnnl_brgemm_dispatch_naive(M, N, K, LDA, LDB, LDC, strideA, + strideB, beta, dtypeA, dtypeB); + dnnl_brgemm_execute_naive(refHandle, A, 0, B, 0, refC, 0, batch); + + // Calculate destination + auto dstHandle = dnnl_brgemm_dispatch(M, N, K, LDA, LDB, LDC, strideA, + strideB, beta, dtypeA, dtypeB); + dnnl_brgemm_tileconfig(dstHandle); + dnnl_brgemm_execute(dstHandle, A, 0, B, 0, dstC, 0, batch); + dnnl_brgemm_tilerelease(); + + ASSERT_TRUE(compareDataInt(refC, dstC, M * N)); +} + +TEST(ExecutionEngine, TestBrgemmRuntimeF32) { + gc_runtime_keep_alive = 0; + + srand(static_cast(time(0))); + + constexpr int batch = 4; + constexpr int M = 32, N = 32, K = 32; + constexpr int LDA = 32, LDB = 32, LDC = 32; + constexpr int strideA = 1024, strideB = 1024; + + testBrgemmRuntimeFP(batch, M, N, K, LDA, LDB, LDC, strideA, strideB, + 0.0f); + testBrgemmRuntimeFP(batch, M, N, K, LDA, LDB, LDC, strideA, strideB, + 1.0f); +} + +TEST(ExecutionEngine, TestBrgemmRuntimeBF16) { + gc_runtime_keep_alive = 0; + + srand(static_cast(time(0))); + + constexpr int batch = 4; + constexpr int M = 32, N = 32, K = 32; + constexpr int LDA = 32, LDB = 32, LDC = 32; + constexpr int strideA = 1024, strideB = 1024; + + testBrgemmRuntimeFP(batch, M, N, K, LDA, LDB, LDC, strideA, strideB, + 0.0f); + testBrgemmRuntimeFP(batch, M, N, K, LDA, LDB, LDC, strideA, strideB, + 1.0f); +} + +TEST(ExecutionEngine, TestBrgemmRuntimeU8S8) { + gc_runtime_keep_alive = 0; + + srand(static_cast(time(0))); + + 
constexpr int batch = 4; + constexpr int M = 32, N = 32, K = 32; + constexpr int LDA = 32, LDB = 32, LDC = 32; + constexpr int strideA = 1024, strideB = 1024; + + testBrgemmRuntimeInt(batch, M, N, K, LDA, LDB, LDC, strideA, strideB, 0.0f); + testBrgemmRuntimeInt(batch, M, N, K, LDA, LDB, LDC, strideA, strideB, 1.0f); +} diff --git a/test/mlir/unittests/ExecutionEngine/CMakeLists.txt b/test/mlir/unittests/ExecutionEngine/CMakeLists.txt index 2cfe3f77e..2735abb28 100644 --- a/test/mlir/unittests/ExecutionEngine/CMakeLists.txt +++ b/test/mlir/unittests/ExecutionEngine/CMakeLists.txt @@ -1,6 +1,8 @@ add_mlir_unittest(GCExecutionEngineTests JitWrapper.cpp + BrgemmRuntime.cpp ) + target_link_libraries(GCExecutionEngineTests PRIVATE GcJitWrapper From 784760d2727325fe9281b715e798e307d6cfc973 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 13 Aug 2024 21:00:17 -0700 Subject: [PATCH 46/93] fix clang-tidy --- test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp b/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp index b6abc1f3f..b6e458172 100644 --- a/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp +++ b/test/mlir/unittests/ExecutionEngine/BrgemmRuntime.cpp @@ -138,7 +138,7 @@ inline void testBrgemmRuntimeInt(int batch, int M, int N, int K, int LDA, TEST(ExecutionEngine, TestBrgemmRuntimeF32) { gc_runtime_keep_alive = 0; - srand(static_cast(time(0))); + srand(static_cast(time(nullptr))); constexpr int batch = 4; constexpr int M = 32, N = 32, K = 32; @@ -154,7 +154,7 @@ TEST(ExecutionEngine, TestBrgemmRuntimeF32) { TEST(ExecutionEngine, TestBrgemmRuntimeBF16) { gc_runtime_keep_alive = 0; - srand(static_cast(time(0))); + srand(static_cast(time(nullptr))); constexpr int batch = 4; constexpr int M = 32, N = 32, K = 32; @@ -170,7 +170,7 @@ TEST(ExecutionEngine, TestBrgemmRuntimeBF16) { TEST(ExecutionEngine, TestBrgemmRuntimeU8S8) { 
gc_runtime_keep_alive = 0; - srand(static_cast(time(0))); + srand(static_cast(time(nullptr))); constexpr int batch = 4; constexpr int M = 32, N = 32, K = 32; From 9bfbf7b742ddd699ff0e71688282e40d979d03e1 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 13 Aug 2024 21:07:06 -0700 Subject: [PATCH 47/93] fix clang-tidy --- .../ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 55eb9a50f..86e75cee6 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -27,7 +27,7 @@ struct brgemm_params_t { int64_t stride_a, stride_b; float beta; int64_t dtypeA, dtypeB; - brgemm_params_t() {} + brgemm_params_t() = default; brgemm_params_t(int64_t m, int64_t n, int64_t k, int64_t lda, int64_t ldb, int64_t ldc, int64_t sa, int64_t sb, float b, int64_t da, int64_t db) @@ -154,8 +154,8 @@ int64_t dnnl_brgemm_dispatch_naive(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t dtypeB) { write_lock_guard_t g(g_brgemm_lock); // simply store the given parameters for naive BRGEMM - g_brgemm_list.emplace_back(brgemm_params_t(M, N, K, LDA, LDB, LDC, stride_a, - stride_b, beta, dtypeA, dtypeB)); + g_brgemm_list.emplace_back(M, N, K, LDA, LDB, LDC, stride_a, stride_b, beta, + dtypeA, dtypeB); return g_brgemm_list.size() - 1; } From 7d3cd4b8d8f7bb857566e918ae463a2a991b6122 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 9 Aug 2024 00:06:53 -0700 Subject: [PATCH 48/93] Revert "remove pass ConvertLinalgToMicrokernel" This reverts commit c01ad17bf2c03e80fdbc984ce5b0223a64f87b10. 
--- include/gc/Transforms/CMakeLists.txt | 2 + .../Microkernel/BrgemmRuntimeUtils.h | 51 +++ .../gc/Transforms/Microkernel/CMakeLists.txt | 6 + .../Microkernel/MicrokernelPasses.h | 28 ++ .../Microkernel/MicrokernelPasses.td | 43 ++ .../gc/Transforms/Utils/StructuredOpMatcher.h | 20 + lib/gc/Dialect/Microkernel/CMakeLists.txt | 1 + lib/gc/Transforms/CMakeLists.txt | 2 + lib/gc/Transforms/Microkernel/CMakeLists.txt | 22 + .../ConvertLinalgToMicrokernel.cpp | 388 ++++++++++++++++++ src/gc-opt/gc-opt.cpp | 2 + .../Microkernel/linalg-to-microkernel.mlir | 173 ++++++++ 12 files changed, 738 insertions(+) create mode 100644 include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h create mode 100644 include/gc/Transforms/Microkernel/CMakeLists.txt create mode 100644 include/gc/Transforms/Microkernel/MicrokernelPasses.h create mode 100644 include/gc/Transforms/Microkernel/MicrokernelPasses.td create mode 100644 lib/gc/Transforms/Microkernel/CMakeLists.txt create mode 100644 lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp create mode 100644 test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir diff --git a/include/gc/Transforms/CMakeLists.txt b/include/gc/Transforms/CMakeLists.txt index 8014cba72..08443020b 100644 --- a/include/gc/Transforms/CMakeLists.txt +++ b/include/gc/Transforms/CMakeLists.txt @@ -11,3 +11,5 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GraphCompiler) mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GraphCompiler) add_public_tablegen_target(GraphCompilerPassIncGen) add_mlir_doc(Passes GraphCompilerPasses ./ -gen-pass-doc) + +add_subdirectory(Microkernel) diff --git a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h new file mode 100644 index 000000000..adb214e10 --- /dev/null +++ b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h @@ -0,0 +1,51 @@ +//===-- BrgemmRuntimeUtils.h - Utils for Brgemm Runtime ---------*- C++ -*-===// 
+// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H +#define GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PatternMatch.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { + +// these strings contain symbols for BRGEMM interfaces used in mlir pass +static const std::string DNNL_BRGEMM_DISPATCH_NAME = "dnnl_brgemm_dispatch"; +static const std::string DNNL_BRGEMM_TILECFG_NAME = "dnnl_brgemm_tileconfig"; +static const std::string DNNL_BRGEMM_TILERELEASE_NAME = + "dnnl_brgemm_tilerelease"; +static const std::string DNNL_BRGEMM_EXECUTE_NAME = "dnnl_brgemm_execute"; + +static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter, + Attribute attr) { + auto context = rewriter.getContext(); + auto tattr = dyn_cast_or_null(attr); + assert(tattr); + if (tattr == TypeAttr::get(FloatType::getF32(context))) { + return static_cast(dnnl_f32); + } else if (tattr == TypeAttr::get(FloatType::getBF16(context))) { + return static_cast(dnnl_bf16); + } else if (tattr == TypeAttr::get(FloatType::getF16(context))) { + return static_cast(dnnl_f16); + } else if (tattr == TypeAttr::get( + IntegerType::get(context, 32, IntegerType::Signed))) { + return static_cast(dnnl_s32); + } else if (tattr == + TypeAttr::get(IntegerType::get(context, 8, IntegerType::Signed))) { + return static_cast(dnnl_s8); + } else if (tattr == TypeAttr::get(IntegerType::get(context, 8, + IntegerType::Unsigned))) { + return static_cast(dnnl_u8); + } + return static_cast(dnnl_data_type_undef); +} + +}; // namespace mlir::microkernel + +#endif // GC_MICROKERNEL_BRGEMMRUNTIMEUTILS_H diff --git a/include/gc/Transforms/Microkernel/CMakeLists.txt 
b/include/gc/Transforms/Microkernel/CMakeLists.txt new file mode 100644 index 000000000..2e345775c --- /dev/null +++ b/include/gc/Transforms/Microkernel/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS MicrokernelPasses.td) +mlir_tablegen(MicrokernelPasses.h.inc --gen-pass-decls -name Microkernel) +mlir_tablegen(MicrokernelPasses.capi.h.inc -gen-pass-capi-header --prefix Microkernel) +mlir_tablegen(MicrokernelPasses.capi.cpp.inc -gen-pass-capi-impl --prefix Microkernel) +add_public_tablegen_target(MLIRMicrokernelPassesIncGen) +add_mlir_doc(MicrokernelPasses GraphCompilerMicrokernelPasses ./ -gen-pass-doc) diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.h b/include/gc/Transforms/Microkernel/MicrokernelPasses.h new file mode 100644 index 000000000..ee9da8a4e --- /dev/null +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.h @@ -0,0 +1,28 @@ +//===- MicrokernelPasses.h - Graph Compiler microkernel passes --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_MICROKERNELPASSES_H +#define GC_MICROKERNELPASSES_H + +#include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace microkernel { +#define GEN_PASS_DECL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define GEN_PASS_REGISTRATION +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" +} // namespace microkernel +} // namespace mlir + +#endif // GC_MICROKERNELPASSES_H diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td new file mode 100644 index 000000000..16e11532b --- /dev/null +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -0,0 +1,43 @@ +//===-- MicrokernelPasses.td - microkernel passes ----------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_DIALECT_MICROKERNELPASSES +#define GC_DIALECT_MICROKERNELPASSES + +include "mlir/Pass/PassBase.td" + +def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::func::FuncOp"> { + let summary = "Lower eligible linalg ops to microkernels"; + let description = [{ + Convert eligible linalg ops to microkernel dialects based on pattern matching. + For example: + ``` + scf.forall { + linalg.fill ins(...) outs(...) -> tensor<...> + linalg.batch_reduce_matmul ins(...) outs(...) -> tensor<...> + } + ``` + Will be changed into + ``` + scf.forall { + linalg.fill ins(...) outs(...) 
-> tensor<...> + %0 = microkernel.brgemm.dispatch(...) + microkernel.brgemm.prologue(%0) + microkernel.brgemm(%0, ...) + microkernel.brgemm.epilogue(%0) + } + ``` + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "linalg::LinalgDialect", + "linalgx::LinalgxDialect", + "microkernel::MicrokernelDialect"]; +} + +#endif // GC_DIALECT_MICROKERNELPASSES diff --git a/include/gc/Transforms/Utils/StructuredOpMatcher.h b/include/gc/Transforms/Utils/StructuredOpMatcher.h index 66bd22b7a..c931472a1 100644 --- a/include/gc/Transforms/Utils/StructuredOpMatcher.h +++ b/include/gc/Transforms/Utils/StructuredOpMatcher.h @@ -217,6 +217,26 @@ template struct EqualsTo { }; template EqualsTo(T) -> EqualsTo; +// Callable object to check if the input is less than or equal to specified +// `value`. +struct LessThanOrEqualTo { + LessThanOrEqualTo() = delete; + explicit LessThanOrEqualTo(size_t value) : value(value){}; + const size_t value; + + bool operator()(size_t value) const { return value <= this->value; } +}; + +// Callable object to check if the input is greater than or equal to specified +// `value`. +struct GreaterThanOrEqualTo { + GreaterThanOrEqualTo() = delete; + explicit GreaterThanOrEqualTo(size_t value) : value(value){}; + const size_t value; + + bool operator()(size_t value) const { return value >= this->value; } +}; + // Callable object to validate number of init operands for `op`. 
struct NumDpsInits { NumDpsInits() = delete; diff --git a/lib/gc/Dialect/Microkernel/CMakeLists.txt b/lib/gc/Dialect/Microkernel/CMakeLists.txt index 33e420f2e..3a9099d4f 100644 --- a/lib/gc/Dialect/Microkernel/CMakeLists.txt +++ b/lib/gc/Dialect/Microkernel/CMakeLists.txt @@ -10,6 +10,7 @@ gc_add_mlir_dialect_library(MLIRMicrokernel DEPENDS MLIRMicrokernelOpsIncGen + MLIRMicrokernelPassesIncGen LINK_LIBS PUBLIC ${MLIR_LINK_COMPONENTS} diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 2be85ebea..0b6d83ae0 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -35,3 +35,5 @@ gc_add_mlir_library(GcPasses if(GC_ENABLE_IMEX) add_subdirectory(GPU) endif() + +add_subdirectory(Microkernel) diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt new file mode 100644 index 000000000..642eaa6ca --- /dev/null +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -0,0 +1,22 @@ +gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) + +include(onednn) + +gc_add_mlir_dialect_library(MLIRMicrokernelTransforms + ConvertLinalgToMicrokernel.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/ + + DEPENDS + MLIRMicrokernelPassesIncGen + + LINK_LIBS PUBLIC + ${MLIR_LINK_COMPONENTS} + GcInterface + GcUtilsIR + ) + +get_property(GC_DNNL_INCLUDES GLOBAL PROPERTY GC_DNNL_INCLUDES) + +target_include_directories(MLIRMicrokernelTransforms PUBLIC ${GC_DNNL_INCLUDES}) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp new file mode 100644 index 000000000..45380a19e --- /dev/null +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -0,0 +1,388 @@ +//===-- ConvertLinalgToMicrokernel.cpp - Linalg To Microkernel --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Transforms/Utils/StructuredOpMatcher.h" +#include "gc/Transforms/Utils/ValueUtils.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "convert-linalg-to-microkernel" + +struct BrgemmInfo { + enum BrgemmMode { STRIDE_MODE, LIST_MODE }; + int64_t m; + int64_t n; + int64_t k; + int64_t batchSize; + int64_t addrLen; + + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t strideA; + int64_t strideB; + + bool isInitOutput; + BrgemmMode mode; +}; + +FailureOr +customInferContractionDims(linalg::LinalgOp linalgOp) { + auto dims = linalg::inferContractionDims(linalgOp); + if (failed(dims)) + return failure(); + if (llvm::isa(linalgOp)) { + // For VnniOp, the K reduction dims (dim index 3 & 4) cannot be inferred by + // linalg utils because they form complex affine in operand A; manually add + // them here + dims->k.push_back(3); + dims->k.push_back(4); + } + return dims; +} + +static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, + ArrayRef dimPos) { + if (dimPos.size() > 2) { + return false; + } + auto firstDim =
getAffineDimExpr(dimPos[0], linalgOp.getContext()); + if (dimPos.size() == 1) + return firstDim == expr; + + // If not regular dim affine, check for VNNI format K affine + auto secondKPosDim = getAffineDimExpr(dimPos[1], linalgOp.getContext()); + // A K affine result for VNNI should be in this format: + // d{kPos[0]} * s{kPos[1]} + d{kPos[1]} (k0 * K_vnni + k1) + auto add = dyn_cast(expr); + if (!add) + return false; + if (add.getKind() != AffineExprKind::Add) + return false; + auto lhs = add.getLHS(); + auto rhs = add.getRHS(); + if (rhs != secondKPosDim) + return false; + auto mul = dyn_cast(lhs); + if (!mul || mul.getKind() != AffineExprKind::Mul || mul.getLHS() != firstDim) + return false; + + auto cst_affine = dyn_cast(mul.getRHS()); + return cst_affine && + (cst_affine.getValue() == 2 || cst_affine.getValue() == 4); +} + +// Return the position of `dim` in the codomain of `operand`. +static std::optional getPosInCodomain(ArrayRef dimPos, + OpOperand *operand, + linalg::LinalgOp linalgOp) { + assert(operand->getOwner() == linalgOp); + auto map = linalgOp.getMatchingIndexingMap(operand); + for (unsigned i = 0, numResults = map.getNumResults(); i < numResults; i++) { + if (isMatchingAffineResult(linalgOp, map.getResult(i), dimPos)) + return i; + } + return std::nullopt; +} + +static FailureOr +inferBrgemmInfo(linalg::LinalgOp linalgOp, + const linalg::ContractionDimensions &dims) { + unsigned mPos = dims.m[0]; + unsigned nPos = dims.n[0]; + // dims.k could be of 2 cases: + // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] + // 2.
dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] + unsigned batchPos = dims.k.front(); + SmallVector kPos; + if (dims.k.size() == 2) { + kPos = {dims.k[1]}; + } else if (dims.k.size() == 3) { + kPos = {dims.k[1], dims.k[2]}; + } else { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] m pos in affine: " << mPos + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n pos in affine: " << nPos + << "\n"); + for (auto kp : kPos) { + LLVM_DEBUG(llvm::dbgs() + << "[inferBrgemmInfo] k pos in affine: " << kp << "\n"); + } + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch pos in affine: " + << batchPos << "\n"); + + auto checkStridesAndGetLda = + [&](ArrayRef minorDim, ArrayRef majorDim, + OpOperand *operand, bool allowVnni) -> FailureOr { + auto minorDimPosInCodomain = getPosInCodomain(minorDim, operand, linalgOp); + auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp); + if (!minorDimPosInCodomain || !majorDimPosInCodomain) + return failure(); + auto stridesOnOperand = utils::getStaticStrides(operand->get()); + if (failed(stridesOnOperand)) + return failure(); + auto minorDimLd = (*stridesOnOperand)[*minorDimPosInCodomain]; + auto majorDimLd = (*stridesOnOperand)[*majorDimPosInCodomain]; + if (minorDimLd != 1) { + // VNNI format exists, special treatment to align LD with non-VNNI format + if (!allowVnni || (minorDimLd != 2 && minorDimLd != 4)) + return failure(); + return majorDimLd / minorDimLd; + } + return majorDimLd; + }; + + OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; + OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; + OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; + + // A(m, k) + auto lda = checkStridesAndGetLda(kPos, {mPos}, operandA, false); + if (failed(lda)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on A: OK\n"); + + // B(k, n) + // note: B does not use VNNI 
format K affine + auto ldb = checkStridesAndGetLda({nPos}, {kPos[0]}, operandB, true); + if (failed(ldb)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on B: OK\n"); + + // C(m, n) + auto ldc = checkStridesAndGetLda({nPos}, {mPos}, operandC, false); + if (failed(ldc)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on C: OK\n"); + + int64_t strideA = 1; + int64_t strideB = 1; + auto batchPosCodomainA = getPosInCodomain(batchPos, operandA, linalgOp); + auto stridesOnA = utils::getStaticStrides(operandA->get()); + strideA = (*stridesOnA)[*batchPosCodomainA]; + + auto batchPosCodomainB = getPosInCodomain(batchPos, operandB, linalgOp); + auto stridesOnB = utils::getStaticStrides(operandB->get()); + strideB = (*stridesOnB)[*batchPosCodomainB]; + + auto loops = linalgOp.computeStaticLoopSizes(); + auto kSize = + kPos.size() == 1 ? loops[kPos[0]] : (loops[kPos[0]] * loops[kPos[1]]); + + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" + << loops[mPos] << "), n(" << loops[nPos] << "), k(" + << kSize << "), batch(" << loops[batchPos] + << "), lda(" << *lda << "), ldb(" << *ldb << "), ldc(" + << *ldc << "), strideA(" << strideA << "), strideB(" + << strideB << ")\n"); + BrgemmInfo info{loops[mPos], + loops[nPos], + kSize, + loops[batchPos], + 0 /* addrLen useless under stride mode */, + *lda, + *ldb, + *ldc, + strideA, + strideB, + false, + BrgemmInfo::STRIDE_MODE}; + return info; +} + +static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { + using namespace mlir::structured_match; + auto validBrgemmMatcher = StructuredOpMatcher::make() + .output(MatchAll(), HasStaticShape()) + .input(MatchAll(), HasStaticShape()) + .output(MatchAll(), HasStaticStrides()) + .input(MatchAll(), HasStaticStrides()) + .operation(NumOfLoops(GreaterThanOrEqualTo(3))); + // clang-format on + if (!validBrgemmMatcher.match(linalgOp)) + return failure(); + + auto contractionDims = customInferContractionDims(linalgOp); + 
if (failed(contractionDims)) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Not a valid contraction\n"); + return failure(); + } + if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 || + // batch-reduce dim for BRGEMM should be identified as one of k dim + // including VNNI & non-VNNI cases + (contractionDims->k.size() != 2 && contractionDims->k.size() != 3) || + !contractionDims->batch.empty()) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n"); + LLVM_DEBUG(llvm::dbgs() + << "[checkStructure] " << contractionDims->m.size() << " " + << contractionDims->n.size() << " " << contractionDims->k.size() + << " " << contractionDims->batch.size() << "\n"); + return failure(); + } + unsigned classifiedLoops = + contractionDims->m.size() + contractionDims->n.size() + + contractionDims->k.size() + contractionDims->batch.size(); + if (linalgOp.getNumLoops() != classifiedLoops) { + LLVM_DEBUG(llvm::dbgs() + << "[checkStructure] Not all loops are classified\n"); + return failure(); + } + + return inferBrgemmInfo(linalgOp, *contractionDims); +} + +// Replace linalgOp with a set of microkernel ops +static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, + linalg::LinalgOp linalgOp, + const BrgemmInfo &info) { + assert(linalgOp.getDpsInputs().size() == 2); + OpBuilder::InsertionGuard guard(rewriter); + + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + Location loc = linalgOp.getLoc(); + SmallVector brgemmFlags; + if (info.isInitOutput) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::BETA_0)); + } + if (info.mode == BrgemmInfo::STRIDE_MODE) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::STRIDE)); + } else if (info.mode == BrgemmInfo::LIST_MODE) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::LIST)); + } + + SmallVector 
brgemmDtypes{ + TypeAttr::get(getElementTypeOrSelf(linalgOp.getDpsInputs()[0].getType())), + TypeAttr::get( + getElementTypeOrSelf(linalgOp.getDpsInputs()[1].getType()))}; + + // create dispatch op + auto flags = rewriter.getArrayAttr(brgemmFlags); + auto dtypes = rewriter.getArrayAttr(brgemmDtypes); + DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( + rewriter.getContext(), + ArrayRef{info.m, info.n, info.k, info.lda, info.ldb, info.ldc, + info.strideA, info.strideB}); + Value dispatched = rewriter.create( + loc, integer64, dims, flags, dtypes); + + // create prologue op + rewriter.create(loc, dispatched); + + // create brgemm invoke op + Value batchDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, info.batchSize)); + Value lenDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, info.addrLen)); + SmallVector invokeOperands; + invokeOperands.push_back(dispatched); + invokeOperands.append(linalgOp->getOperands().begin(), + linalgOp->getOperands().end()); + invokeOperands.push_back(batchDim); + invokeOperands.push_back(lenDim); + rewriter.create(loc, invokeOperands); + + // create epilogue op & replace original op + rewriter.replaceOpWithNewOp(linalgOp, + dispatched); +} + +bool isZeroArithConstant(arith::ConstantOp op) { + if (!op) + return false; + + if (auto intAttr = llvm::dyn_cast(op.getValue())) { + if (intAttr.getInt() != 0) + return false; + } else if (auto floatAttr = llvm::dyn_cast(op.getValue())) { + if (!floatAttr.getValue().isZero()) + return false; + } else + return false; + + return true; +} + +template +class ConvertContractionOpToBrgemmRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ContractionOp op, + PatternRewriter &rewriter) const final { + auto brgemmInfo = getBrgemmInfo(op); + if (failed(brgemmInfo)) + return failure(); + // Check for immediately preceding linalg::FillOp + Operation *rawOp = op; + auto block = 
rawOp->getBlock(); + auto opIter = Block::iterator(rawOp); + if (block->begin() != opIter) { + auto prevOp = &(*(--opIter)); + if (auto fillOp = dyn_cast(prevOp)) { + auto inputCst = dyn_cast_or_null( + fillOp.getInputs()[0].getDefiningOp()); + auto fillOperand = fillOp.getOutputs()[0]; + auto contractionOperand = op.getOutputs()[0]; + if (isZeroArithConstant(inputCst) && + contractionOperand == fillOperand) { + brgemmInfo->isInitOutput = true; + rewriter.eraseOp(prevOp); + } + } + } + replaceOpWithMicrokernelOpSet(rewriter, op, *brgemmInfo); + return success(); + } +}; + +class ConvertLinalgToMicrokernel + : public impl::ConvertLinalgToMicrokernelBase { +public: + using impl::ConvertLinalgToMicrokernelBase< + ConvertLinalgToMicrokernel>::ConvertLinalgToMicrokernelBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns + .add>( + &getContext()); + patterns.add< + ConvertContractionOpToBrgemmRewriter>( + &getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 5754f0269..89bacbdd8 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -23,6 +23,7 @@ #ifdef GC_HAS_ONEDNN_DIALECT #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -51,6 +52,7 @@ int main(int argc, char *argv[]) { mlir::gc::registerCPUPipeline(); mlir::gc::registerGraphCompilerPasses(); mlir::cpuruntime::registerCPURuntimePasses(); + mlir::microkernel::registerMicrokernelPasses(); mlir::DialectRegistry registry; #ifdef GC_HAS_ONEDNN_DIALECT diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir 
b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir new file mode 100644 index 000000000..224329f62 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -0,0 +1,173 @@ +// RUN: gc-opt %s -convert-linalg-to-microkernel -split-input-file | FileCheck %s + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @basic_linalg_to_microkernel() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : 
memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: basic_linalg_to_microkernel +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @vnni_linalg_to_microkernel() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, 
%arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: vnni_linalg_to_microkernel +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 
1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @basic_linalg_to_microkernel_fusing_fill() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) + linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + 
%subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: basic_linalg_to_microkernel_fusing_fill +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK-NOT: linalg.fill +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], 
offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @vnni_linalg_to_microkernel_fusing_fill() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) + linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> 
to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: vnni_linalg_to_microkernel_fusing_fill +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CHECK-NOT: linalg.fill +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- From 5165ba81fa7408161c6d05e549bfa2ec4841c7a8 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 19 Aug 2024 21:16:37 -0700 Subject: [PATCH 49/93] fix as per reviews --- .../gc/Transforms/Utils/StructuredOpMatcher.h | 10 --- .../ConvertLinalgToMicrokernel.cpp | 68 +++++++++++-------- 2 files changed, 38 insertions(+), 40 deletions(-) diff --git 
a/include/gc/Transforms/Utils/StructuredOpMatcher.h b/include/gc/Transforms/Utils/StructuredOpMatcher.h index c931472a1..66d398474 100644 --- a/include/gc/Transforms/Utils/StructuredOpMatcher.h +++ b/include/gc/Transforms/Utils/StructuredOpMatcher.h @@ -217,16 +217,6 @@ template struct EqualsTo { }; template EqualsTo(T) -> EqualsTo; -// Callable object to check if the input is less than or equal to specified -// `value`. -struct LessThanOrEqualTo { - LessThanOrEqualTo() = delete; - explicit LessThanOrEqualTo(size_t value) : value(value){}; - const size_t value; - - bool operator()(size_t value) const { return value <= this->value; } -}; - // Callable object to check if the input is greater than or equal to specified // `value`. struct GreaterThanOrEqualTo { diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 45380a19e..d157e9494 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -66,9 +66,10 @@ customInferContractionDims(linalg::LinalgOp linalgOp) { static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, ArrayRef dimPos) { - if (dimPos.size() > 2) { + // Expecting dimPos.size() == 1 for normal dim and == 2 for vnni dim + if (dimPos.size() > 2) return false; - } + auto firstDim = getAffineDimExpr(dimPos[0], linalgOp.getContext()); if (dimPos.size() == 1) return firstDim == expr; @@ -95,10 +96,10 @@ static bool isMatchingAffineResult(linalg::LinalgOp linalgOp, AffineExpr expr, (cst_affine.getValue() == 2 || cst_affine.getValue() == 4); } -// Return the position of `dim` in the codomain of `operand`. -static std::optional getPosInCodomain(ArrayRef dimPos, - OpOperand *operand, - linalg::LinalgOp linalgOp) { +// Return the position of linalg loop `dim` in the domain of `operand`. 
+static std::optional getPosInDomain(ArrayRef dimPos, + OpOperand *operand, + linalg::LinalgOp linalgOp) { assert(operand->getOwner() == linalgOp); auto map = linalgOp.getMatchingIndexingMap(operand); for (unsigned i = 0, numResults = map.getNumResults(); i < numResults; i++) { @@ -118,13 +119,12 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, // 2. dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] unsigned batchPos = dims.k.front(); SmallVector kPos; - if (dims.k.size() == 2) { + if (dims.k.size() == 2) kPos = {dims.k[1]}; - } else if (dims.k.size() == 3) { + else if (dims.k.size() == 3) kPos = {dims.k[1], dims.k[2]}; - } else { + else return failure(); - } LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " << "\n"); @@ -132,25 +132,27 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, << "\n"); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n pos in affine: " << nPos << "\n"); - for (auto kp : kPos) { + for (auto kp : kPos) LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] k pos in affine: " << kp << "\n"); - } LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch pos in affine: " << batchPos << "\n"); auto checkStridesAndGetLda = [&](ArrayRef minorDim, ArrayRef majorDim, OpOperand *operand, bool allowVnni) -> FailureOr { - auto minorDimPosInCodomain = getPosInCodomain(minorDim, operand, linalgOp); - auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp); - if (!minorDimPosInCodomain || !majorDimPosInCodomain) + std::optional minorDimPosInDomain = + getPosInDomain(minorDim, operand, linalgOp); + std::optional majorDimPosInDomain = + getPosInDomain(majorDim, operand, linalgOp); + if (!minorDimPosInDomain || !majorDimPosInDomain) return failure(); - auto stridesOnOperand = utils::getStaticStrides(operand->get()); + FailureOr> stridesOnOperand = + utils::getStaticStrides(operand->get()); if (failed(stridesOnOperand)) return failure(); - auto minorDimLd = (*stridesOnOperand)[*minorDimPosInCodomain]; - auto majorDimLd = 
(*stridesOnOperand)[*majorDimPosInCodomain]; + auto minorDimLd = (*stridesOnOperand)[*minorDimPosInDomain]; + auto majorDimLd = (*stridesOnOperand)[*majorDimPosInDomain]; if (minorDimLd != 1) { // VNNI format exists, special treatment to align LD with non-VNNI format if (!allowVnni || (minorDimLd != 2 && minorDimLd != 4)) @@ -165,35 +167,41 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; // A(m, k) - auto lda = checkStridesAndGetLda(kPos, {mPos}, operandA, false); + FailureOr lda = checkStridesAndGetLda(kPos, {mPos}, operandA, false); if (failed(lda)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on A: OK\n"); // B(k, n) // note: B does not use VNNI format K affine - auto ldb = checkStridesAndGetLda({nPos}, {kPos[0]}, operandB, true); + FailureOr ldb = + checkStridesAndGetLda({nPos}, {kPos[0]}, operandB, true); if (failed(ldb)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on B: OK\n"); // C(m, n) - auto ldc = checkStridesAndGetLda({nPos}, {mPos}, operandC, false); + FailureOr ldc = + checkStridesAndGetLda({nPos}, {mPos}, operandC, false); if (failed(ldc)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on C: OK\n"); int64_t strideA = 1; int64_t strideB = 1; - auto batchPosCodomainA = getPosInCodomain(batchPos, operandA, linalgOp); - auto stridesOnA = utils::getStaticStrides(operandA->get()); - strideA = (*stridesOnA)[*batchPosCodomainA]; - - auto batchPosCodomainB = getPosInCodomain(batchPos, operandB, linalgOp); - auto stridesOnB = utils::getStaticStrides(operandB->get()); - strideB = (*stridesOnB)[*batchPosCodomainB]; - - auto loops = linalgOp.computeStaticLoopSizes(); + std::optional batchPosDomainA = + getPosInDomain(batchPos, operandA, linalgOp); + FailureOr> stridesOnA = + utils::getStaticStrides(operandA->get()); + strideA = (*stridesOnA)[*batchPosDomainA]; + + std::optional batchPosDomainB = + getPosInDomain(batchPos, operandB, 
linalgOp); + FailureOr> stridesOnB = + utils::getStaticStrides(operandB->get()); + strideB = (*stridesOnB)[*batchPosDomainB]; + + SmallVector loops = linalgOp.computeStaticLoopSizes(); auto kSize = kPos.size() == 1 ? loops[kPos[0]] : (loops[kPos[0]] * loops[kPos[1]]); From 59d63c79cea1a1ab633dc13d35fc291ce9c70dec Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 19 Aug 2024 22:19:26 -0700 Subject: [PATCH 50/93] remove unnecessary header --- lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp index 86e75cee6..f2d837318 100644 --- a/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp +++ b/lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmNaive.cpp @@ -7,12 +7,8 @@ //===----------------------------------------------------------------------===// #include -#include #include #include -#include -#include -#include #include #include "oneapi/dnnl/dnnl_types.h" From 87900a99b0fc053aa0f12e5dd2349fa79a78b4fe Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 20 Aug 2024 22:09:48 -0700 Subject: [PATCH 51/93] fix per reviews --- include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h | 2 ++ lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h index adb214e10..0c92458ed 100644 --- a/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h +++ b/include/gc/Transforms/Microkernel/BrgemmRuntimeUtils.h @@ -29,6 +29,8 @@ static inline int64_t getDnnlDataTypeVal(RewriterBase &rewriter, assert(tattr); if (tattr == TypeAttr::get(FloatType::getF32(context))) { return static_cast(dnnl_f32); + } else if (tattr == TypeAttr::get(FloatType::getF64(context))) { + return static_cast(dnnl_f64); } else 
if (tattr == TypeAttr::get(FloatType::getBF16(context))) { return static_cast(dnnl_bf16); } else if (tattr == TypeAttr::get(FloatType::getF16(context))) { diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index d157e9494..24ed43c9b 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -234,7 +234,6 @@ static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { .output(MatchAll(), HasStaticStrides()) .input(MatchAll(), HasStaticStrides()) .operation(NumOfLoops(GreaterThanOrEqualTo(3))); - // clang-format on if (!validBrgemmMatcher.match(linalgOp)) return failure(); From 7f94552b0ff6f72e9a82ad3287aa97e5bf6149f9 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 20 Aug 2024 22:27:25 -0700 Subject: [PATCH 52/93] minor fix --- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 24ed43c9b..d326910c2 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -117,7 +117,6 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, // dims.k could be of 2 cases: // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] // 2. 
dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] - unsigned batchPos = dims.k.front(); SmallVector kPos; if (dims.k.size() == 2) kPos = {dims.k[1]}; @@ -125,6 +124,7 @@ inferBrgemmInfo(linalg::LinalgOp linalgOp, kPos = {dims.k[1], dims.k[2]}; else return failure(); + unsigned batchPos = dims.k.front(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " << "\n"); From 6bbd864d886ca76060cff7ebfd1a11e371b1bd20 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 20 Aug 2024 22:53:13 -0700 Subject: [PATCH 53/93] Revert "Remove pass ConvertMicrokernelToDnnlFunc" This reverts commit 44937e4e76b85977bf3f0753e20894fb4e597365. --- .../Microkernel/MicrokernelPasses.td | 11 + include/gc/Transforms/Utils/ValueUtils.h | 3 +- lib/gc/Transforms/Microkernel/CMakeLists.txt | 1 + .../ConvertMicrokernelToDnnlFunc.cpp | 222 ++++++++++++++++++ lib/gc/Transforms/Utils/ValueUtils.cpp | 29 +++ .../Microkernel/microkernel-to-dnnl-func.mlir | 70 ++++++ .../test/gc/cpu-runner/brgemm-parallel.mlir | 50 ++++ 7 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp create mode 100644 test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir create mode 100644 test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 16e11532b..bf9e3c61d 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -40,4 +40,15 @@ def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::f "microkernel::MicrokernelDialect"]; } +def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::mlir::ModuleOp"> { + let summary = "Lower microkernel dialects to dnnl func call"; + let description = [{ + Convert microkernel dialects to runtime function call to oneDNN library. 
+ }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + #endif // GC_DIALECT_MICROKERNELPASSES diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index acffd5642..07013bde4 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -27,8 +27,7 @@ FailureOr> getStaticStrides(Value val); // Return the offset and ptr for `val`. Assert if `val` // is not a memref. -std::pair getPtrAndOffset(OpBuilder &builder, Value val, - Location loc); +std::pair getPtrAndOffset(OpBuilder &builder, Value operand); } // namespace utils } // namespace mlir diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index 642eaa6ca..e33db7185 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -4,6 +4,7 @@ include(onednn) gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp + ConvertMicrokernelToDnnlFunc.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp new file mode 100644 index 000000000..966f87c06 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -0,0 +1,222 @@ +//===-- ConvertMicrokernelToDnnlFunc.cpp - Lower to dnnl funcs --*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Transforms/Utils/ValueUtils.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_CONVERTMICROKERNELTODNNLFUNC +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "convert-microkernel-to-dnnl-func" + +static func::CallOp createFuncCall(RewriterBase &rewriter, Location loc, + ModuleOp module, const std::string &funcName, + ArrayRef operands, + ArrayRef operandTypes, + ArrayRef resultTypes) { + FlatSymbolRefAttr fnName = SymbolRefAttr::get(module->getContext(), funcName); + auto fnType = rewriter.getFunctionType(operandTypes, resultTypes); + + if (!module.lookupSymbol(fnName.getAttr())) { + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. 
+ rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + func::FuncOp funcOp = + rewriter.create(loc, fnName.getValue(), fnType); + funcOp.setPrivate(); + } + + func::CallOp call = rewriter.create(loc, fnName.getValue(), + resultTypes, operands); + return call; +} + +class ConvertBrgemmDispatchOpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // runtime func for dnnl brgemm dispatch: + // int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, + // int64_t LDB, int64_t LDC, int64_t stride_a, int64_t stride_b, float beta, + // int64_t dtypeA, int64_t dtypeB); + LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + + SmallVector operands; + SmallVector operandTypes; + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + FloatType float32 = FloatType::getF32(rewriter.getContext()); + + // M, N, K, LDA, LDB, LDC, stride_a, stride_b + // they are in the same order with BrgemmDispatchOp inputs + ArrayRef inputs = op.getInputsAttr().asArrayRef(); + for (auto input : inputs) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), input); + operands.push_back( + rewriter.create(loc, integer64, attr)); + operandTypes.push_back(integer64); + } + + // beta + auto flags = op.getFlagsAttr(); + float beta = 1.0f; + for (auto flag : flags) { + auto brgemmFlag = dyn_cast_or_null(flag); + if (!brgemmFlag) + return rewriter.notifyMatchFailure(op, "unknown flag for BRGEMM"); + if (brgemmFlag.getValue() == BrgemmFlags::LIST) + return rewriter.notifyMatchFailure( + op, "addr mode BRGEMM not supported yet"); + if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + beta = 0.0f; + } + auto betaAttr = FloatAttr::get(rewriter.getF32Type(), beta); + operands.push_back( + rewriter.create(loc, float32, betaAttr)); + operandTypes.push_back(float32); + 
+ // dtypeA, dtypeB + auto dtypes = op.getDataType(); + if (dtypes.size() != 2) + return rewriter.notifyMatchFailure( + op, "invalid number of DataType for BRGEMM"); + auto dtypeAAttr = IntegerAttr::get(rewriter.getI64Type(), + getDnnlDataTypeVal(rewriter, dtypes[0])); + auto dtypeBAttr = IntegerAttr::get(rewriter.getI64Type(), + getDnnlDataTypeVal(rewriter, dtypes[1])); + operands.push_back( + rewriter.create(loc, integer64, dtypeAAttr)); + operandTypes.push_back(integer64); + operands.push_back( + rewriter.create(loc, integer64, dtypeBAttr)); + operandTypes.push_back(integer64); + + func::CallOp call = + createFuncCall(rewriter, loc, module, DNNL_BRGEMM_DISPATCH_NAME, + operands, operandTypes, {integer64}); + rewriter.replaceOp(op, call.getResult(0)); + return success(); + } +}; + +class ConvertBrgemmPrologueOpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // dnnl runtime func for brgemm set hw context: + // void dnnl_brgemm_tileconfig(int64_t kernel_idx); + LogicalResult matchAndRewrite(microkernel::BrgemmPrologueOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + func::CallOp call = + createFuncCall(rewriter, loc, module, DNNL_BRGEMM_TILECFG_NAME, + op.getInputs(), {integer64}, {}); + rewriter.replaceOp(op, call); + return success(); + } +}; + +class ConvertBrgemmOpRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // runtime func for stride mode dnnl brgemm execution: + // void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void + // *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) + LogicalResult matchAndRewrite(microkernel::BrgemmOp op, + PatternRewriter &rewriter) const final { + // currently only support stride mode, directly call it + // TODO(haixin): support addr mode execution, through 
detecting dispatch + // target + + auto context = rewriter.getContext(); + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + + SmallVector operands; + SmallVector operandTypes; + + auto raw_operands = op->getOperands(); + size_t raw_op_cnt = 0; + for (Value operand : raw_operands) { + if (raw_op_cnt++ >= 5) { + // drop the last operand for `addr list length` + break; + } + Type operandType = operand.getType(); + if (auto memrefType = dyn_cast(operandType)) { + Type basePtrType = LLVM::LLVMPointerType::get(context); + auto [ptr, offset] = utils::getPtrAndOffset(rewriter, operand); + operands.push_back(ptr); + operands.push_back(offset); + operandTypes.push_back(basePtrType); + operandTypes.push_back(rewriter.getIndexType()); // offset + } else { + operands.push_back(operand); + operandTypes.push_back(operand.getType()); + } + } + + createFuncCall(rewriter, loc, module, DNNL_BRGEMM_EXECUTE_NAME, operands, + operandTypes, {}); + rewriter.eraseOp(op); + return success(); + } +}; + +class ConvertBrgemmEpilogueOpRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + // dnnl runtime func for brgemm release hw context: + // void dnnl_brgemm_tilerelease(); + LogicalResult matchAndRewrite(microkernel::BrgemmEpilogueOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + func::CallOp call = createFuncCall( + rewriter, loc, module, DNNL_BRGEMM_TILERELEASE_NAME, {}, {}, {}); + rewriter.replaceOp(op, call); + return success(); + } +}; + +class ConvertMicrokernelToDnnlFunc + : public impl::ConvertMicrokernelToDnnlFuncBase< + ConvertMicrokernelToDnnlFunc> { +public: + using impl::ConvertMicrokernelToDnnlFuncBase< + ConvertMicrokernelToDnnlFunc>::ConvertMicrokernelToDnnlFuncBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns + .add( + &getContext()); + + FrozenRewritePatternSet 
patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index d565d0cf8..8750042ee 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -116,5 +116,34 @@ FailureOr> getStaticStrides(Value value) { return strides; } +std::pair getPtrAndOffset(OpBuilder &builder, Value operand) { + auto memrefType = dyn_cast(operand.getType()); + assert(memrefType && "Expect a memref value"); + + Location loc = operand.getDefiningOp()->getLoc(); + OpBuilder::InsertionGuard guard(builder); + // Insert right after operand producer for better opt chances. + builder.setInsertionPointAfterValue(operand); + + MemRefType baseMemrefType = MemRefType::get({}, memrefType.getElementType()); + Type basePtrType = builder.getIndexType(); + Type offsetType = builder.getIndexType(); + SmallVector sizesTypes(memrefType.getRank(), offsetType); + SmallVector stridesTypes(memrefType.getRank(), offsetType); + auto meta = builder.create( + loc, baseMemrefType, offsetType, sizesTypes, stridesTypes, operand); + Value alignedPointerAsIndex = + builder.create(loc, basePtrType, + operand); + Value alignedPointerAsI64 = builder.create( + loc, builder.getIntegerType(64), alignedPointerAsIndex); + // TODO: non-POD will require an LLVMTypeConverter. 
+ Value alignedPointer = builder.create( + loc, LLVM::LLVMPointerType::get(builder.getContext()), + alignedPointerAsI64); + Value offset = meta.getOffset(); + return std::make_pair(alignedPointer, offset); +} + } // namespace utils } // namespace mlir diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir new file mode 100644 index 000000000..1520ae069 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -0,0 +1,70 @@ +// RUN: gc-opt %s -convert-microkernel-to-dnnl-func -cse -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @basic_convert() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], 
offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } +} + +// CHECK-LABEL: dnnl_brgemm_execute +// CHECK-LABEL: dnnl_brgemm_dispatch +// CHECK-LABEL: basic_convert +// CHECK: %[[CST3:.+]] = arith.constant 3 : i64 +// CHECK: %[[CST1F:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[CST1024:.+]] = arith.constant 1024 : i64 +// CHECK: %[[CST32:.+]] = arith.constant 32 : i64 +// CHECK: %[[CST0:.+]] = arith.constant 0 : index +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 + +// CHECK: %[[ptrC:.+]] = memref.extract_aligned_pointer_as_index %[[memrefC:.+]] : memref<32x32xf32> -> index +// CHECK-NEXT: %[[idxC:.+]] = arith.index_cast %[[ptrC]] : index to i64 +// CHECK-NEXT: %[[llvmptrC:.+]] = llvm.inttoptr %[[idxC]] : i64 to !llvm.ptr + +// CHECK: %[[bbA:.+]], %[[offA:.+]], %[[szA:.+]]:3, %[[strdA:.+]]:3 = memref.extract_strided_metadata %[[memrefA:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, 
index, index, index, index, index +// CHECK-NEXT: %[[ptrA:.+]] = memref.extract_aligned_pointer_as_index %[[memrefA]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index +// CHECK-NEXT: %[[idxA:.+]] = arith.index_cast %[[ptrA]] : index to i64 +// CHECK-NEXT: %[[llvmptrA:.+]] = llvm.inttoptr %[[idxA]] : i64 to !llvm.ptr + +// CHECK: %[[bbB:.+]], %[[offB:.+]], %[[szB:.+]]:3, %[[strdB:.+]]:3 = memref.extract_strided_metadata %[[memrefB:.+]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index +// CHECK-NEXT: %[[ptrB:.+]] = memref.extract_aligned_pointer_as_index %[[memrefB]] : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -> index +// CHECK-NEXT: %[[idxB:.+]] = arith.index_cast %[[ptrB]] : index to i64 +// CHECK-NEXT: %[[llvmptrB:.+]] = llvm.inttoptr %[[idxB]] : i64 to !llvm.ptr + +// CHECK: %[[KERNEL:.+]] = func.call @dnnl_brgemm_dispatch(%[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST32]], %[[CST1024]], %[[CST1024]], %[[CST1F]], %[[CST3]], %[[CST3]]) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 +// CHECK-NOT: microkernel.brgemm.prologue(%[[TMP:.+]]) : (i64) -> () + +// CHECK: func.call @dnnl_brgemm_execute(%[[KERNEL]], %[[llvmptrA]], %[[offA]], %[[llvmptrB]], %[[offB]], %[[llvmptrC]], %[[CST0]], %[[CST16]]) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () +// CHECK-NOT: microkernel.brgemm.epilogue(%[[KERNEL]]) : (i64) -> () + +// ----- diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir new file mode 100644 index 000000000..ad436da0c --- /dev/null +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -0,0 +1,50 @@ +// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm 
--convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, 
strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + // COM: parallelcpu.printf "BRGEMM DONE\n" + return + } + + // COM: CHECK: BRGEMM DONE +} From 2ac30c6a22b4b216255975e0616107ba4de0f025 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 23 Jul 2024 01:08:50 -0700 Subject: [PATCH 54/93] add basic BrgemmOnTensorOp --- .../gc/Dialect/Microkernel/MicrokernelOps.td | 52 +++ lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 309 ++++++++++++++---- .../Microkernel/linalg-to-microkernel.mlir | 8 +- .../Microkernel/microkernel-to-dnnl-func.mlir | 2 +- .../test/gc/cpu-runner/brgemm-parallel.mlir | 2 +- 5 files changed, 302 insertions(+), 71 deletions(-) diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.td b/include/gc/Dialect/Microkernel/MicrokernelOps.td index 76e5424c6..4d65f5995 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.td +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.td @@ -11,14 +11,66 @@ include "MicrokernelDialect.td" include "gc/Dialect/Microkernel/MicrokernelEnum.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include 
"mlir/Interfaces/SideEffectInterfaces.td" +class StaticTensorRankOf allowedTypes, list ranks> : + Type.predicate, + HasAnyRankOfPred, HasStaticShapePred]>, + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " # + TensorOf.summary, "::mlir::TensorType">; + class StaticMemRefRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred, HasStaticShapePred]>, !interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " # MemRefOf.summary, "::mlir::MemRefType">; +def BrgemmTensor : StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>; + +def BrgemmTensorOrMemRef : AnyTypeOf<[StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, + StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>]>; + +def Microkernel_BrgemmOnTensorOp : Microkernel_Op<"brgemm_on_tensor", + [DeclareOpInterfaceMethods, + DestinationStyleOpInterface]> { + let summary = "Abstract Op that execute brgemm kernel on tensors."; + let description = [{ + The operation has the following arguments: + 1) Tensors or MemRefs of operand A/B; + 2) The batch dims and leading dims of operand A/B; + And has the following outputs: + 1) Tensor of operand C; + }]; + + let arguments = (ins Variadic:$inputs, + BrgemmTensorOrMemRef:$init, + ConfinedAttr]>:$batchDims, + ConfinedAttr]>:$leadingDims, + TypedArrayAttrBase:$flags); + let results = (outs Variadic:$output); + + let extraClassDeclaration = [{ + Value getOperandA() { return getInputs()[0]; } + Value getOperandB() { return getInputs()[1]; } + Value getOperandC() { return getInit(); } + + int64_t getBatchDimA() { return getBatchDims()[0]; } + int64_t getLeadingDimA() { return getLeadingDims()[0]; } + + int64_t getBatchDimB() { return getBatchDims()[1]; } + int64_t getLeadingDimB() { return getLeadingDims()[1]; } + + MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } + }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; + let hasFolder = 1; +} + def Microkernel_BrgemmDispatchOp : 
Microkernel_Op<"brgemm.dispatch", [Pure]> { let summary = "JIT the brgemm microkernel given the parameters"; let description = [{ diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 1cf8c0d05..fc2455882 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -19,20 +19,22 @@ namespace mlir { namespace microkernel { -constexpr std::string_view INPUTS = "inputs"; -constexpr std::string_view DATA_TYPE = "data_type"; -constexpr std::string_view FLAGS_NAME = "flags"; +constexpr std::string_view INPUTS_ASM_NAME = "ins"; +constexpr std::string_view OUTPUTS_ASM_NAME = "outs"; +constexpr std::string_view DATA_TYPE_ASM_NAME = "data_type"; +constexpr std::string_view FLAGS_ASM_NAME = "flags"; +constexpr std::string_view BATCH_DIMS_ASM_NAME = "batch_dims"; +constexpr std::string_view LEADING_DIMS_ASM_NAME = "leading_dims"; -template -static void printInputImpl(OpAsmPrinter &printer, OpTy op) { - printer << " [" << op.getInputs() << ']'; -} +constexpr std::string_view INPUTS_ATTR_NAME = "inputs"; +constexpr std::string_view BATCH_DIMS_ATTR_NAME = "batchDims"; +constexpr std::string_view LEADING_DIMS_ATTR_NAME = "leadingDims"; template static void printFlagsImpl(OpAsmPrinter &printer, const std::function &fn, const std::string_view &flagsName) { - printer << " " << flagsName << " = ("; + printer << " " << flagsName << "("; llvm::interleaveComma(fn(), printer, [&](auto &flag) { printer << stringifyEnum(cast(flag).getValue()); }); @@ -41,7 +43,7 @@ static void printFlagsImpl(OpAsmPrinter &printer, template static void printDataTypeImpl(OpAsmPrinter &printer, OpTy op) { - printer << DATA_TYPE << " = ("; + printer << DATA_TYPE_ASM_NAME << "("; auto dataTypes = op.getDataType(); for (size_t idx = 0; idx < dataTypes.size(); idx++) { printer.printAttribute(dataTypes[idx]); @@ -64,38 +66,28 @@ static ParseResult parseEnum(EnumClass &value, OpAsmParser &parser) { return 
success(); } -static ParseResult parseOperandImpl(OpAsmParser &parser, - OperationState &result) { - DenseI64ArrayAttr kindAttr; - if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, INPUTS, - result.attributes)) { - return failure(); - } +template +static ParseResult parseArrayAttrImpl(OpAsmParser &parser, + OperationState &result, + const std::string_view &attrAsmName, + const std::string_view &attrName) { auto &builder = parser.getBuilder(); - result.addTypes(builder.getIntegerType(64)); - return success(); -} - -static ParseResult parseDataTypeImpl(OpAsmParser &parser, - OperationState &result) { - auto &builder = parser.getBuilder(); - if (parser.parseKeyword(DATA_TYPE) || parser.parseEqual() || - parser.parseLParen()) + if (parser.parseKeyword(attrAsmName) || parser.parseLParen()) return failure(); - SmallVector dataTypes; - auto parseTypeAttr = [&]() -> ParseResult { - Attribute dataType; - if (parser.parseAttribute(dataType)) + SmallVector attrs; + auto parseAttr = [&]() -> ParseResult { + Attribute attr; + if (parser.parseAttribute(attr)) return failure(); - if (!isa(dataType)) + if (!isa(attr)) return failure(); - dataTypes.push_back(dataType); + attrs.push_back(attr); return success(); }; - if (parser.parseCommaSeparatedList(parseTypeAttr) || parser.parseRParen()) + if (parser.parseCommaSeparatedList(parseAttr) || parser.parseRParen()) return failure(); - result.addAttribute(DATA_TYPE, builder.getArrayAttr(dataTypes)); + result.addAttribute(attrName, builder.getArrayAttr(attrs)); return success(); } @@ -103,8 +95,7 @@ template static ParseResult parseFlagsImpl(OpAsmParser &parser, OperationState &result, const std::string_view &flagsName) { auto &builder = parser.getBuilder(); - if (parser.parseKeyword(flagsName) || parser.parseEqual() || - parser.parseLParen()) + if (parser.parseKeyword(flagsName) || parser.parseLParen()) return failure(); SmallVector flags; @@ -121,6 +112,25 @@ static ParseResult parseFlagsImpl(OpAsmParser &parser, 
OperationState &result, return success(); } +static ParseResult parseOperandsImpl(OpAsmParser &parser, + OperationState &result, + const std::string_view &operandsName) { + SMLoc operandsLoc; + SmallVector types; + SmallVector operands; + + if (parser.parseKeyword(operandsName) || parser.parseLParen()) + return failure(); + operandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(operands) || parser.parseColonTypeList(types) || + parser.parseRParen()) + return failure(); + + if (parser.resolveOperands(operands, types, operandsLoc, result.operands)) + return failure(); + return success(); +} + template static LogicalResult verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op, @@ -142,24 +152,215 @@ verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op, return success(); } +static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op, + const std::string_view &flagsName) { + // Verify flags. + if (failed(verifyUniquenessAndConsistency(flags, op, flagsName))) + return failure(); + + bool strideSet = false; + bool listSet = false; + for (auto flag : flags) { + if (cast(flag).getValue() == BrgemmFlags::STRIDE) { + strideSet = true; + } + if (cast(flag).getValue() == BrgemmFlags::LIST) { + listSet = true; + } + } + // VNNI flags must be specified only for bf16 type + if (strideSet && listSet) { + return op->emitOpError() + << "stride and addr flags conflict with each other"; + } + + return success(); +} + +///////////////////////////////////////////////////// +// Start of BrgemmOnTensorOp + +ParseResult BrgemmOnTensorOp::parse(OpAsmParser &parser, + OperationState &result) { + + if (failed(parseOperandsImpl(parser, result, INPUTS_ASM_NAME))) + return failure(); + if (failed(parseOperandsImpl(parser, result, OUTPUTS_ASM_NAME))) + return failure(); + + if (parseArrayAttrImpl(parser, result, BATCH_DIMS_ASM_NAME, + BATCH_DIMS_ATTR_NAME)) + return failure(); + if (parseArrayAttrImpl(parser, result, LEADING_DIMS_ASM_NAME, + 
LEADING_DIMS_ATTR_NAME)) + return failure(); + + if (failed(parseFlagsImpl(parser, result, FLAGS_ASM_NAME))) + return failure(); + + SmallVector resultTypes; + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); + result.addTypes(resultTypes); + return success(); +} + +void BrgemmOnTensorOp::print(OpAsmPrinter &printer) { + BrgemmOnTensorOp op = *this; + printer << " " << INPUTS_ASM_NAME << "(" << op.getInputs() << ")"; + printer << " " << OUTPUTS_ASM_NAME << "(" << op.getInit() << ")"; + printer << " " << BATCH_DIMS_ASM_NAME << "(" << op.getBatchDims() << ")"; + printer << " " << LEADING_DIMS_ASM_NAME << "(" << op.getLeadingDims() << ")"; + + auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; + printFlagsImpl(printer, getOpFlags, FLAGS_ASM_NAME); + + auto resultTypes = op.getResultTypes(); + if (resultTypes.empty()) + return; + printer.printOptionalArrowTypeList(resultTypes); +} + +LogicalResult BrgemmOnTensorOp::fold(FoldAdaptor, + SmallVectorImpl &) { + return memref::foldMemRefCast(*this); +} + +void BrgemmOnTensorOp::getEffects( + SmallVectorImpl> + &effects) { + if (hasPureTensorSemantics()) + return; + + BrgemmOnTensorOp op = *this; + + for (auto [index, operand] : llvm::enumerate(op.getDpsInputs())) { + if (!llvm::isa(operand.getType())) + continue; + effects.emplace_back( + MemoryEffects::Read::get(), &op->getOpOperand(index), /*stage=*/0, + /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get()); + } + + auto flags = op.getFlags(); + bool isInit = false; + for (auto flag : flags) { + if (cast(flag).getValue() == BrgemmFlags::BETA_0) { + isInit = true; + break; + } + } + + assert(op.getDpsInitsMutable().size() == 1 && + "Expecting single DPS init operand"); + OpOperand &operand = op.getDpsInitsMutable()[0]; + if (!llvm::isa(operand.get().getType())) + return; + if (!isInit) { + effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0, + /*effectOnFullRegion=*/true, + 
SideEffects::DefaultResource::get()); + } + effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); +} + +static inline ArrayRef getShapedValueShape(Value val) { + assert((llvm::isa(val.getType()) || + llvm::isa(val.getType())) && + "Expecting shaped value"); + if (auto tensorTy = dyn_cast_or_null(val.getType())) { + return tensorTy.getShape(); + } + auto memrefTy = dyn_cast_or_null(val.getType()); + return memrefTy.getShape(); +} + +LogicalResult BrgemmOnTensorOp::verify() { + BrgemmOnTensorOp op = *this; + + size_t expectedInputSize = 2; + SmallVector ins; + for (auto in : op.getInputs()) + ins.push_back(in); + Value out = op.getInit(); + ArrayRef batchDims = op.getBatchDims(); + ArrayRef leadingDims = op.getLeadingDims(); + if (ins.size() != expectedInputSize && + batchDims.size() != expectedInputSize && + leadingDims.size() != expectedInputSize) + return op.emitOpError() + << "expect inputs and its related info to be size 2\n"; + + ArrayRef dimA = getShapedValueShape(ins[0]); + ArrayRef dimB = getShapedValueShape(ins[1]); + ArrayRef dimC = getShapedValueShape(out); + if (dimA.size() != 3) + return op.emitOpError() << "expect input A to be 3D\n"; + if (dimB.size() != 3 && dimB.size() != 4) + return op.emitOpError() << "expect input B to be 3D or 4D\n"; + if (dimB.size() == 4 && (dimB[3] != 2 && dimB[3] != 4)) + return op.emitOpError() << "expect input B vnni step to be 2 or 4\n"; + if (dimC.size() != 2) + return op.emitOpError() << "expect input C to be 2D\n"; + for (auto dim : batchDims) + if (dim >= 2) + return op.emitOpError() << "batch dim cannot be last dim, as last dim " + "should be contigious\n"; + for (auto dim : leadingDims) + if (dim >= 2) + return op.emitOpError() << "leading dim cannot be last dim, as last dim " + "should be contigious\n"; + + auto batchA = dimA[batchDims[0]]; + auto batchB = dimB[batchDims[1]]; + auto majorDimA = dimA[leadingDims[0]]; + auto 
majorDimB = dimB.size() == 3 ? dimB[leadingDims[1]] + : (dimB[leadingDims[1]] * dimB[3]); + auto minorDimA = dimA[2]; + auto minorDimB = dimB[2]; + auto majorDimC = dimC[0]; + auto minorDimC = dimC[1]; + if (batchA != batchB) + return op.emitOpError() << "unmatched batch dim of A and B\n"; + if (minorDimA != majorDimB) + return op.emitOpError() << "unmatched matmul dim of A and B\n"; + if (majorDimA != majorDimC || minorDimB != minorDimC) + return op.emitOpError() << "unmatched matmul dim of A, B and C\n"; + + return verifyBrgemmFlags(op.getFlags(), op, FLAGS_ASM_NAME); +} + ///////////////////////////////////////////////////// // Start of BrgemmDispatchOp void BrgemmDispatchOp::print(OpAsmPrinter &printer) { - printInputImpl(printer, *this); + BrgemmDispatchOp op = *this; + + printer << " [" << op.getInputs() << ']'; + auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printFlagsImpl(printer, getOpFlags, FLAGS_NAME); + printFlagsImpl(printer, getOpFlags, FLAGS_ASM_NAME); + printDataTypeImpl(printer, *this); } ParseResult BrgemmDispatchOp::parse(OpAsmParser &parser, OperationState &result) { - if (failed(parseOperandImpl(parser, result)) || - failed(parseFlagsImpl(parser, result, FLAGS_NAME))) + auto &builder = parser.getBuilder(); + result.addTypes(builder.getIntegerType(64)); + + DenseI64ArrayAttr inputAttr; + if (parser.parseCustomAttributeWithFallback( + inputAttr, Type{}, INPUTS_ATTR_NAME, result.attributes)) + return failure(); + if (failed(parseFlagsImpl(parser, result, FLAGS_ASM_NAME))) return failure(); - if (failed(parseDataTypeImpl(parser, result))) + if (failed(parseArrayAttrImpl(parser, result, DATA_TYPE_ASM_NAME, + DATA_TYPE_ASM_NAME))) return failure(); - return parser.parseOptionalAttrDict(result.attributes); + return success(); } static LogicalResult verifyBrgemmDataTypes(ArrayAttr dtypes, @@ -190,30 +391,8 @@ static LogicalResult verifyBrgemmDataTypes(ArrayAttr dtypes, return success(); } -static LogicalResult 
verifyBrgemmFlags(ArrayAttr flags, BrgemmDispatchOp op, - const std::string_view &flagsName) { - // Verify flags. - if (failed(verifyUniquenessAndConsistency(flags, op, flagsName))) - return failure(); - - bool strideSet = false; - bool listSet = false; - for (auto flag : flags) { - if (cast(flag).getValue() == BrgemmFlags::STRIDE) - strideSet = true; - if (cast(flag).getValue() == BrgemmFlags::LIST) - listSet = true; - } - // VNNI flags must be specified only for bf16 type - if (strideSet && listSet) - return op->emitOpError() - << "stride and addr flags conflict with each other"; - - return success(); -} - LogicalResult BrgemmDispatchOp::verify() { - BrgemmDispatchOp &op = *this; + BrgemmDispatchOp op = *this; // 'inputs' = [m, n, k, lda, ldb, ldc, stride_a, stride_b] for BRGEMM. size_t expected = 8; size_t numInputs = op.getInputs().size(); @@ -239,7 +418,7 @@ LogicalResult BrgemmDispatchOp::verify() { return op.emitOpError() << "expect ldc to be >= of dimension n\n"; // Verify dispatch flags. 
- return verifyBrgemmFlags(op.getFlags(), op, FLAGS_NAME); + return verifyBrgemmFlags(op.getFlags(), op, FLAGS_ASM_NAME); } ///////////////////////////////////////////////////// diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index 224329f62..161121f0c 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -35,7 +35,7 @@ func.func @basic_linalg_to_microkernel() { // CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> // CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> // CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () // CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () @@ -77,7 +77,7 @@ func.func @vnni_linalg_to_microkernel() { // CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> // CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> // CHECK: %[[B:.+]] = memref.subview 
%[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () // CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () @@ -121,7 +121,7 @@ func.func @basic_linalg_to_microkernel_fusing_fill() { // CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> // CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> // CHECK-NOT: linalg.fill -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (f32, f32) +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(f32, f32) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () // CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () @@ -165,7 +165,7 @@ func.func @vnni_linalg_to_microkernel_fusing_fill() { // 
CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> // CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> // CHECK-NOT: linalg.fill -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(bf16, bf16) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () // CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir index 1520ae069..91a3aefb4 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -15,7 +15,7 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] 
flags(stride) data_type(f32, f32) microkernel.brgemm.prologue(%0) : (i64) -> () microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index ad436da0c..11c6c50ad 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -19,7 +19,7 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) microkernel.brgemm.prologue(%0) : (i64) -> () microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () From 0bdce9ebe7682edbd3dce5688cfb0fad0708c43b Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 23 Jul 2024 03:27:09 -0700 Subject: [PATCH 55/93] add Bufferize support --- .../gc/Dialect/Microkernel/MicrokernelOps.h | 2 + .../gc/Dialect/Microkernel/MicrokernelOps.td | 11 +++ lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 98 ++++++++++++++++++- 3 files changed, 110 insertions(+), 1 deletion(-) diff --git 
a/include/gc/Dialect/Microkernel/MicrokernelOps.h b/include/gc/Dialect/Microkernel/MicrokernelOps.h index a478c1dee..4ff6dd358 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.h +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.h @@ -9,12 +9,14 @@ #ifndef GC_DIALECTS_MICROKERNELOPS_H #define GC_DIALECTS_MICROKERNELOPS_H +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "gc/Dialect/Microkernel/MicrokernelDialect.h" diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.td b/include/gc/Dialect/Microkernel/MicrokernelOps.td index 4d65f5995..41ae47e30 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.td +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.td @@ -11,6 +11,7 @@ include "MicrokernelDialect.td" include "gc/Dialect/Microkernel/MicrokernelEnum.td" +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -33,6 +34,7 @@ def BrgemmTensorOrMemRef : AnyTypeOf<[StaticTensorRankOf<[F32, BF16, SI32, SI8, def Microkernel_BrgemmOnTensorOp : Microkernel_Op<"brgemm_on_tensor", [DeclareOpInterfaceMethods, + BufferizableOpInterface, DestinationStyleOpInterface]> { let summary = "Abstract Op that execute brgemm kernel on tensors."; let description = [{ @@ -64,6 +66,15 @@ def Microkernel_BrgemmOnTensorOp : Microkernel_Op<"brgemm_on_tensor", int64_t getLeadingDimB() { return getLeadingDims()[1]; } MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } + + bool bufferizesToMemoryRead(OpOperand &, + const bufferization::AnalysisState &); + bool 
bufferizesToMemoryWrite(OpOperand &, + const bufferization::AnalysisState &); + bool bufferizesToElementwiseAccess(const bufferization::AnalysisState &, + ArrayRef); + LogicalResult bufferize(RewriterBase &, + const bufferization::BufferizationOptions &); }]; let hasVerifier = 1; diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index fc2455882..6a1ab3b8b 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -6,15 +6,27 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + #include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Dialect/Microkernel/MicrokernelDialect.h" -#include #define GET_OP_CLASSES #include "gc/Dialect/Microkernel/MicrokernelOps.cpp.inc" #include +using namespace mlir::bufferization; + namespace mlir { namespace microkernel { @@ -332,6 +344,90 @@ LogicalResult BrgemmOnTensorOp::verify() { return verifyBrgemmFlags(op.getFlags(), op, FLAGS_ASM_NAME); } +bool BrgemmOnTensorOp::bufferizesToMemoryRead(OpOperand &opOperand, + const AnalysisState &state) { + Operation *op = *this; + auto dpsOp = cast(op); + return !dpsOp.isDpsInit(&opOperand); +} + +bool BrgemmOnTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand, + const AnalysisState &state) { + Operation *op = *this; + auto dpsOp = cast(op); + return dpsOp.isDpsInit(&opOperand); +} + +bool BrgemmOnTensorOp::bufferizesToElementwiseAccess( + const 
AnalysisState &state, ArrayRef opOperands) { + // This op contains non-parallel reduction loops, + // should return `false` per linalg implementation + return false; +} + +LogicalResult BrgemmOnTensorOp::bufferize(RewriterBase &rewriter, + const BufferizationOptions &options) { + // This implementation refers to linalg's + // `bufferizeDestinationStyleOpInterface` + Operation *op = *this; + auto dpsOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(dpsOp); + + // Nothing to do. This op is already bufferized. + if (dpsOp.hasPureBufferSemantics()) + return success(); + + // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need + // basis. + if (!dpsOp.hasPureTensorSemantics()) + return emitError() << "op does not have pure tensor semantics"; + + // New input operands for the cloned op. + SmallVector newInputBuffers; + newInputBuffers.reserve(dpsOp.getNumDpsInputs()); + for (OpOperand *opOperand : dpsOp.getDpsInputOperands()) { + FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); + if (failed(buffer)) + return failure(); + newInputBuffers.push_back(*buffer); + } + + // New output operands for the cloned op. + SmallVector newOutputBuffers; + for (OpResult opResult : dpsOp->getOpResults()) { + OpOperand *opOperand = dpsOp.getDpsInitOperand(opResult.getResultNumber()); + FailureOr resultBuffer = + getBuffer(rewriter, opOperand->get(), options); + if (failed(resultBuffer)) + return failure(); + newOutputBuffers.push_back(*resultBuffer); + } + + // Merge input/output operands. + SmallVector newOperands = newInputBuffers; + newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); + + // Set insertion point now that potential alloc/dealloc are introduced. + rewriter.setInsertionPoint(dpsOp); + // Clone the op, but use the new operands. Since the new op does not have any + // tensor results, it does not return anything. 
+ OperationState state(dpsOp->getLoc(), dpsOp->getName(), newOperands, + TypeRange{}, dpsOp->getAttrs()); + Operation *newOp = Operation::create(state); + + // We don't want the rewriter tracks an incomplete operation, so insert new + // operation after op was fully constructed. + rewriter.insert(newOp); + + // Replace the results of the old op with the new output buffers. + replaceOpWithBufferizedValues(rewriter, dpsOp, newOutputBuffers); + + return success(); +} + ///////////////////////////////////////////////////// // Start of BrgemmDispatchOp From 7d5160375c7c9f5f63352c6dd5765decbfd221c6 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 00:03:58 -0700 Subject: [PATCH 56/93] add ExpandMicrokernel pass --- .../gc/Dialect/Microkernel/MicrokernelOps.td | 4 +- .../Microkernel/MicrokernelPasses.td | 30 +- lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 42 ++- lib/gc/Transforms/Microkernel/CMakeLists.txt | 1 + .../ConvertLinalgToMicrokernel.cpp | 2 +- .../ConvertMicrokernelToDnnlFunc.cpp | 9 +- .../Microkernel/ExpandMicrokernel.cpp | 266 ++++++++++++++++++ 7 files changed, 322 insertions(+), 32 deletions(-) create mode 100644 lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.td b/include/gc/Dialect/Microkernel/MicrokernelOps.td index 41ae47e30..366933e7f 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.td +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.td @@ -32,7 +32,7 @@ def BrgemmTensor : StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>; def BrgemmTensorOrMemRef : AnyTypeOf<[StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>]>; -def Microkernel_BrgemmOnTensorOp : Microkernel_Op<"brgemm_on_tensor", +def Microkernel_BrgemmOp : Microkernel_Op<"brgemm", [DeclareOpInterfaceMethods, BufferizableOpInterface, DestinationStyleOpInterface]> { @@ -143,7 +143,7 @@ def Microkernel_BrgemmEpilogueOp : 
Microkernel_Op<"brgemm.epilogue"> { */ def BrgemmMemRefOrI64 : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>; -def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> { +def Microkernel_BrgemmExecuteOp : Microkernel_Op<"brgemm.execute"> { let summary = "execute the JITed brgemm kernel."; let description = [{ The operation has the following arguments: diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index bf9e3c61d..dc554aeed 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -24,19 +24,43 @@ def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::f ``` Will be changed into ``` + scf.forall { + linalg.fill ins(...) outs(...) -> tensor<...> + microkernel.brgemm ins(...) outs(...) -> tensor<...> + } + ``` + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "linalg::LinalgDialect", + "linalgx::LinalgxDialect", + "microkernel::MicrokernelDialect"]; +} + +def ExpandMicrokernel: Pass<"expand-microkernel", "::mlir::func::FuncOp"> { + let summary = "Expand abstract microkernels into detailed execution phases"; + let description = [{ + Expand abstract microkernels into detailed execution phases + For example: + ``` + scf.forall { + linalg.fill ins(...) outs(...) -> tensor<...> + microkernel.brgemm ins(...) outs(...) -> tensor<...> + } + ``` + Will be changed into + ``` scf.forall { linalg.fill ins(...) outs(...) -> tensor<...> %0 = microkernel.brgemm.dispatch(...) microkernel.brgemm.prologue(%0) - microkernel.brgemm(%0, ...) + microkernel.brgemm.execute(%0, ...) 
microkernel.brgemm.epilogue(%0) } ``` }]; let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect", - "linalg::LinalgDialect", - "linalgx::LinalgxDialect", "microkernel::MicrokernelDialect"]; } diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 6a1ab3b8b..4d569acc7 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -190,10 +190,9 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op, } ///////////////////////////////////////////////////// -// Start of BrgemmOnTensorOp +// Start of BrgemmOp -ParseResult BrgemmOnTensorOp::parse(OpAsmParser &parser, - OperationState &result) { +ParseResult BrgemmOp::parse(OpAsmParser &parser, OperationState &result) { if (failed(parseOperandsImpl(parser, result, INPUTS_ASM_NAME))) return failure(); @@ -217,8 +216,8 @@ ParseResult BrgemmOnTensorOp::parse(OpAsmParser &parser, return success(); } -void BrgemmOnTensorOp::print(OpAsmPrinter &printer) { - BrgemmOnTensorOp op = *this; +void BrgemmOp::print(OpAsmPrinter &printer) { + BrgemmOp op = *this; printer << " " << INPUTS_ASM_NAME << "(" << op.getInputs() << ")"; printer << " " << OUTPUTS_ASM_NAME << "(" << op.getInit() << ")"; printer << " " << BATCH_DIMS_ASM_NAME << "(" << op.getBatchDims() << ")"; @@ -233,18 +232,17 @@ void BrgemmOnTensorOp::print(OpAsmPrinter &printer) { printer.printOptionalArrowTypeList(resultTypes); } -LogicalResult BrgemmOnTensorOp::fold(FoldAdaptor, - SmallVectorImpl &) { +LogicalResult BrgemmOp::fold(FoldAdaptor, SmallVectorImpl &) { return memref::foldMemRefCast(*this); } -void BrgemmOnTensorOp::getEffects( +void BrgemmOp::getEffects( SmallVectorImpl> &effects) { if (hasPureTensorSemantics()) return; - BrgemmOnTensorOp op = *this; + BrgemmOp op = *this; for (auto [index, operand] : llvm::enumerate(op.getDpsInputs())) { if (!llvm::isa(operand.getType())) @@ -289,8 +287,8 @@ static inline ArrayRef 
getShapedValueShape(Value val) { return memrefTy.getShape(); } -LogicalResult BrgemmOnTensorOp::verify() { - BrgemmOnTensorOp op = *this; +LogicalResult BrgemmOp::verify() { + BrgemmOp op = *this; size_t expectedInputSize = 2; SmallVector ins; @@ -344,29 +342,29 @@ LogicalResult BrgemmOnTensorOp::verify() { return verifyBrgemmFlags(op.getFlags(), op, FLAGS_ASM_NAME); } -bool BrgemmOnTensorOp::bufferizesToMemoryRead(OpOperand &opOperand, - const AnalysisState &state) { +bool BrgemmOp::bufferizesToMemoryRead(OpOperand &opOperand, + const AnalysisState &state) { Operation *op = *this; auto dpsOp = cast(op); return !dpsOp.isDpsInit(&opOperand); } -bool BrgemmOnTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand, - const AnalysisState &state) { +bool BrgemmOp::bufferizesToMemoryWrite(OpOperand &opOperand, + const AnalysisState &state) { Operation *op = *this; auto dpsOp = cast(op); return dpsOp.isDpsInit(&opOperand); } -bool BrgemmOnTensorOp::bufferizesToElementwiseAccess( - const AnalysisState &state, ArrayRef opOperands) { +bool BrgemmOp::bufferizesToElementwiseAccess(const AnalysisState &state, + ArrayRef opOperands) { // This op contains non-parallel reduction loops, // should return `false` per linalg implementation return false; } -LogicalResult BrgemmOnTensorOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { +LogicalResult BrgemmOp::bufferize(RewriterBase &rewriter, + const BufferizationOptions &options) { // This implementation refers to linalg's // `bufferizeDestinationStyleOpInterface` Operation *op = *this; @@ -518,7 +516,7 @@ LogicalResult BrgemmDispatchOp::verify() { } ///////////////////////////////////////////////////// -// Start of BrgemmOp +// Start of BrgemmExecuteOp // TODO(haixin): could use compiler-wide VNNI utils? 
static bool isInVnniLayout(MemRefType memref) { @@ -556,8 +554,8 @@ static bool isTypeSupported(Type outType, Type operandAType, return true; } -LogicalResult BrgemmOp::verify() { - BrgemmOp &brgemmOp = *this; +LogicalResult BrgemmExecuteOp::verify() { + BrgemmExecuteOp &brgemmOp = *this; SmallVector inputs = brgemmOp.getInputs(); // inputs for BRGEMM: kernel id, A memref, B memref, C memref, batch_size, diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index e33db7185..9064b70db 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -4,6 +4,7 @@ include(onednn) gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp + ExpandMicrokernel.cpp ConvertMicrokernelToDnnlFunc.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index d326910c2..b559bd56a 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -317,7 +317,7 @@ static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, linalgOp->getOperands().end()); invokeOperands.push_back(batchDim); invokeOperands.push_back(lenDim); - rewriter.create(loc, invokeOperands); + rewriter.create(loc, invokeOperands); // create epilogue op & replace original op rewriter.replaceOpWithNewOp(linalgOp, diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index 966f87c06..34d8c2cbd 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -136,13 +136,14 @@ class ConvertBrgemmPrologueOpRewriter } }; -class ConvertBrgemmOpRewriter : public OpRewritePattern { +class ConvertBrgemmExecuteOpRewriter + : public 
OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; // runtime func for stride mode dnnl brgemm execution: // void dnnl_brgemm_execute(int64_t kernel, void *A, uint64_t A_offset, void // *B, uint64_t B_offset, void *C, uint64_t C_offset, int num) - LogicalResult matchAndRewrite(microkernel::BrgemmOp op, + LogicalResult matchAndRewrite(microkernel::BrgemmExecuteOp op, PatternRewriter &rewriter) const final { // currently only support stride mode, directly call it // TODO(haixin): support addr mode execution, through detecting dispatch @@ -210,7 +211,7 @@ class ConvertMicrokernelToDnnlFunc RewritePatternSet patterns(&getContext()); patterns .add( + ConvertBrgemmExecuteOpRewriter, ConvertBrgemmEpilogueOpRewriter>( &getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); diff --git a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp new file mode 100644 index 000000000..c7984892e --- /dev/null +++ b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp @@ -0,0 +1,266 @@ +//===- ExpandMicrokernel.cpp - Expand Microkernel Ops --------*- C++ -*--===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include + +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/StructuredOpMatcher.h" +#include "gc/Utils/ValueUtils.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "convert-linalg-to-microkernel" + +struct BrgemmInfo { + enum BrgemmMode { STRIDE_MODE, LIST_MODE }; + int64_t m; + int64_t n; + int64_t k; + int64_t batchSize; + int64_t addrLen; + + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t strideA; + int64_t strideB; + + bool isInitOutput; + BrgemmMode mode; +}; + +static FailureOr inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) { + Value operandA = brgemmOp.getOperandA(); + Value operandB = brgemmOp.getOperandB(); + Value operandC = brgemmOp.getOperandC(); + + auto checkTypeAndGetShape = + [&](Value operand) -> FailureOr> { + auto operandTy = operand.getType(); + if (!llvm::isa(operandTy)) + return failure(); + return operandTy.getShape(); + }; + + auto checkAndGetDimSize = + [&](int64_t batchDim, int64_t leadingDim, + Value operand) -> std::tuple, FailureOr, + FailureOr> { + auto operandShape = checkTypeAndGetShape(operand); + if (failed(operandShape)) + return {failure(), failure(), failure()}; + int64_t batchDimSize = 
(*operandShape)[batchDim]; + int64_t leadingDimSize = (*operandShape)[leadingDim]; + // minorDim is always last dim (the 3rd dim in 3D shape) + int64_t minorDimSize = (*operandShape)[2]; + if (operandShape->size() == 4) + // Input B VNNI format exists, special treatment to align with non-VNNI + // format + leadingDimSize *= (*operandShape)[3]; + return {batchDimSize, leadingDimSize, minorDimSize}; + }; + + auto checkAndGetLdStride = [&](int64_t leadingDim, + Value operand) -> FailureOr { + auto operandShape = checkTypeAndGetShape(operand); + if (failed(operandShape)) + return failure(); + auto stridesOnOperand = gcext::utils::getStaticStrides(operand); + if (failed(stridesOnOperand)) + return failure(); + auto leadingDimStride = (*stridesOnOperand)[leadingDim]; + if (operandShape->size() == 4) + // Input B VNNI format exists, special treatment to align with non-VNNI + // format + return leadingDimStride / (*operandShape)[3]; + return leadingDimStride; + }; + + auto checkAndGetBatchStride = [&](int64_t batchDim, + Value operand) -> FailureOr { + auto stridesOnOperand = gcext::utils::getStaticStrides(operand); + if (failed(stridesOnOperand)) + return failure(); + return (*stridesOnOperand)[batchDim]; + }; + + // A(m, k) + auto leadingDimA = brgemmOp.getLeadingDimA(); + auto [batchA, M, KA] = checkAndGetDimSize(leadingDimA, operandA); + auto lda = checkAndGetLdStride(leadingDimA, operandA); + if (failed(batchA) || failed(M) || failed(K) || failed(lda)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] M, K, Lda for A: " << *M << ", " + << *KA << ", " << *lda << "\n"); + + // B(k, n) + auto leadingDimB = brgemmOp.getLeadingDimB(); + auto [batchB, KB, N] = checkAndGetDimSize(leadingDimB, operandB); + auto ldb = checkAndGetLdStride(leadingDimB, operandB); + if (failed(batchB) || failed(KB) || failed(N) || failed(ldb)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] K, N, Ldb for B: " << *KB + << ", " << *N << ", " << *ldb << "\n"); + 
assert(*batchA == *batchB && *KA == *KB && + "Expecting matching shapes of inputs"); + + // C(m, n) + auto ldc = checkAndGetLdStride(0, operandC); + if (failed(ldc)) + return failure(); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Ld stride on C: " << ldc + << "\n"); + + auto strideA = checkAndGetBatchStride(brgemmOp.getBatchDimA(), operandA); + if (failed(strideA)) + return failure(); + + auto strideB = checkAndGetBatchStride(brgemmOp.getBatchDimB(), operandB); + if (failed(strideB)) + return failure(); + + bool isInit = false; + auto flags = brgemmOp.getFlagsAttr(); + for (auto flag : flags) { + auto brgemmFlag = dyn_cast_or_null(flag); + if (!brgemmFlag) + return failure(); + if (brgemmFlag.getValue() == BrgemmFlags::LIST) + return failure(); + if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + isInit = true; + } + + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" << *M + << "), n(" << *N << "), k(" << * << "), batch(" + << *batchA << "), lda(" << *lda << "), ldb(" << *ldb + << "), ldc(" << *ldc << "), strideA(" << *strideA + << "), strideB(" << *strideB << ")\n"); + BrgemmInfo info{*M, + *N, + *KA, + *batchA, + 0 /* addrLen useless under stride mode */, + *lda, + *ldb, + *ldc, + *strideA, + *strideB, + isInit, + BrgemmInfo::STRIDE_MODE}; + return info; +} + +// Replace microkernel.BrgemmOp with a set of microkernel ops +static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, + microkernel::BrgemmOp brgemmOp, + const BrgemmInfo &info) { + assert(brgemmOp.getDpsInputs().size() == 2); + OpBuilder::InsertionGuard guard(rewriter); + + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + Location loc = brgemmOp.getLoc(); + SmallVector brgemmFlags; + if (info.isInitOutput) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::BETA_0)); + } + if (info.mode == BrgemmInfo::STRIDE_MODE) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + 
rewriter.getContext(), microkernel::BrgemmFlags::STRIDE)); + } else if (info.mode == BrgemmInfo::LIST_MODE) { + brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( + rewriter.getContext(), microkernel::BrgemmFlags::LIST)); + } + + SmallVector brgemmDtypes{ + TypeAttr::get(getElementTypeOrSelf(brgemmOp.getDpsInputs()[0].getType())), + TypeAttr::get( + getElementTypeOrSelf(brgemmOp.getDpsInputs()[1].getType()))}; + + // create dispatch op + auto flags = rewriter.getArrayAttr(brgemmFlags); + auto dtypes = rewriter.getArrayAttr(brgemmDtypes); + DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( + rewriter.getContext(), + ArrayRef{info.m, info.n, info.k, info.lda, info.ldb, info.ldc, + info.strideA, info.strideB}); + Value dispatched = rewriter.create( + loc, integer64, dims, flags, dtypes); + + // create prologue op + rewriter.create(loc, dispatched); + + // create brgemm invoke op + Value batchDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, info.batchSize)); + Value lenDim = rewriter.create( + loc, integer64, rewriter.getIntegerAttr(integer64, info.addrLen)); + SmallVector invokeOperands; + invokeOperands.push_back(dispatched); + invokeOperands.append(brgemmOp->getOperands().begin(), + brgemmOp->getOperands().end()); + invokeOperands.push_back(batchDim); + invokeOperands.push_back(lenDim); + rewriter.create(loc, invokeOperands); + + // create epilogue op & replace original op + rewriter.replaceOpWithNewOp(brgemmOp, + dispatched); +} + +class ExpandMicrokernelBrgemmRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(microkernel::BrgemmOp op, + PatternRewriter &rewriter) const final { + if (!op.hasPureBufferSemantics()) + return failure(); + + auto brgemmInfo = inferBrgemmInfo(op); + if (failed(brgemmInfo)) + return failure(); + + replaceOpWithMicrokernelOpSet(rewriter, op, *brgemmInfo); + return success(); + } +}; + +class ExpandMicrokernel + : public 
impl::ExpandMicrokernelBase { +public: + using impl::ExpandMicrokernelBase::ExpandMicrokernelBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel From bc95df33d540101a431bc93ed5f8c66189439a7f Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 25 Jul 2024 01:59:03 -0700 Subject: [PATCH 57/93] add lowering from linalgOp to newly added brgemmOp --- .../ConvertLinalgToMicrokernel.cpp | 365 ++++++++---------- 1 file changed, 161 insertions(+), 204 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index b559bd56a..cdbed727f 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -31,22 +31,19 @@ namespace mlir::microkernel { #define DEBUG_TYPE "convert-linalg-to-microkernel" -struct BrgemmInfo { - enum BrgemmMode { STRIDE_MODE, LIST_MODE }; - int64_t m; - int64_t n; - int64_t k; - int64_t batchSize; - int64_t addrLen; - - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t strideA; - int64_t strideB; - - bool isInitOutput; - BrgemmMode mode; +struct BrgemmDims { + int64_t batchDimA; + int64_t leadingDimA; + int64_t minorDimA; + + int64_t batchDimB; + int64_t leadingDimB; + int64_t minorDimB; + + BrgemmDims(int64_t bdA, int64_t ldA, int64_t mdA, int64_t bdB, int64_t ldB, + int64_t mdB) + : batchDimA(bdA), leadingDimA(lda), minorDimA(mdA), batchDimB(bdB), + leadingDimB(ldB), minorDimB(mdB) {} }; FailureOr @@ -109,124 +106,7 @@ static std::optional getPosInDomain(ArrayRef dimPos, return std::nullopt; } -static FailureOr -inferBrgemmInfo(linalg::LinalgOp linalgOp, - const linalg::ContractionDimensions &dims) { - unsigned mPos 
= dims.m[0]; - unsigned nPos = dims.n[0]; - // dims.k could be of 2 cases: - // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] - // 2. dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] - SmallVector kPos; - if (dims.k.size() == 2) - kPos = {dims.k[1]}; - else if (dims.k.size() == 3) - kPos = {dims.k[1], dims.k[2]}; - else - return failure(); - unsigned batchPos = dims.k.front(); - - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Candidate dims: " - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] m pos in affine: " << mPos - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] n pos in affine: " << nPos - << "\n"); - for (auto kp : kPos) - LLVM_DEBUG(llvm::dbgs() - << "[inferBrgemmInfo] k pos in affine: " << kp << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] batch pos in affine: " - << batchPos << "\n"); - - auto checkStridesAndGetLda = - [&](ArrayRef minorDim, ArrayRef majorDim, - OpOperand *operand, bool allowVnni) -> FailureOr { - std::optional minorDimPosInDomain = - getPosInDomain(minorDim, operand, linalgOp); - std::optional majorDimPosInDomain = - getPosInDomain(majorDim, operand, linalgOp); - if (!minorDimPosInDomain || !majorDimPosInDomain) - return failure(); - FailureOr> stridesOnOperand = - utils::getStaticStrides(operand->get()); - if (failed(stridesOnOperand)) - return failure(); - auto minorDimLd = (*stridesOnOperand)[*minorDimPosInDomain]; - auto majorDimLd = (*stridesOnOperand)[*majorDimPosInDomain]; - if (minorDimLd != 1) { - // VNNI format exists, special treatment to align LD with non-VNNI format - if (!allowVnni || (minorDimLd != 2 && minorDimLd != 4)) - return failure(); - return majorDimLd / minorDimLd; - } - return majorDimLd; - }; - - OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; - OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; - - // A(m, k) - FailureOr lda = checkStridesAndGetLda(kPos, {mPos}, operandA, false); - if (failed(lda)) - 
return failure(); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on A: OK\n"); - - // B(k, n) - // note: B does not use VNNI format K affine - FailureOr ldb = - checkStridesAndGetLda({nPos}, {kPos[0]}, operandB, true); - if (failed(ldb)) - return failure(); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on B: OK\n"); - - // C(m, n) - FailureOr ldc = - checkStridesAndGetLda({nPos}, {mPos}, operandC, false); - if (failed(ldc)) - return failure(); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Strides on C: OK\n"); - - int64_t strideA = 1; - int64_t strideB = 1; - std::optional batchPosDomainA = - getPosInDomain(batchPos, operandA, linalgOp); - FailureOr> stridesOnA = - utils::getStaticStrides(operandA->get()); - strideA = (*stridesOnA)[*batchPosDomainA]; - - std::optional batchPosDomainB = - getPosInDomain(batchPos, operandB, linalgOp); - FailureOr> stridesOnB = - utils::getStaticStrides(operandB->get()); - strideB = (*stridesOnB)[*batchPosDomainB]; - - SmallVector loops = linalgOp.computeStaticLoopSizes(); - auto kSize = - kPos.size() == 1 ? 
loops[kPos[0]] : (loops[kPos[0]] * loops[kPos[1]]); - - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" - << loops[mPos] << "), n(" << loops[nPos] << "), k(" - << kSize << "), batch(" << loops[batchPos] - << "), lda(" << *lda << "), ldb(" << *ldb << "), ldc(" - << *ldc << "), strideA(" << strideA << "), strideB(" - << strideB << ")\n"); - BrgemmInfo info{loops[mPos], - loops[nPos], - kSize, - loops[batchPos], - 0 /* addrLen useless under stride mode */, - *lda, - *ldb, - *ldc, - strideA, - strideB, - false, - BrgemmInfo::STRIDE_MODE}; - return info; -} - -static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { +static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { using namespace mlir::structured_match; auto validBrgemmMatcher = StructuredOpMatcher::make() .output(MatchAll(), HasStaticShape()) @@ -263,68 +143,123 @@ static FailureOr getBrgemmInfo(linalg::LinalgOp linalgOp) { return failure(); } - return inferBrgemmInfo(linalgOp, *contractionDims); + unsigned mAffinePos = contractionDims.m[0]; + unsigned nAffinePos = contractionDims.n[0]; + // contractionDims.k could be of 2 cases: + // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] + // 2. 
dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] + unsigned batchAffinePos = contractionDims.k.front(); + SmallVector kAffinePos; + if (contractionDims.k.size() == 2) + kPos = {contractionDims.k[1]}; + else if (contractionDims.k.size() == 3) + kPos = {contractionDims.k[1], contractionDims.k[2]}; + else + return failure(); + + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: " + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mPos + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] n pos in affine: " << nPos + << "\n"); + for (auto kp : kPos) + LLVM_DEBUG(llvm::dbgs() + << "[inferBrgemmDims] k pos in affine: " << kp << "\n"); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: " + << batchPos << "\n"); + + OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; + OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; + OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; + + BrgemmDims brgemmDims; + // A(batch, m, k) + brgemmDims.batchDimA = getPosInCodomain(batchAffinePos, operandA, linalgOp); + brgemmDims.leadingDimA = getPosInCodomain({mAffinePos}, operandA, linalgOp); + brgemmDims.minorDimA = getPosInCodomain(kAffinePos, operandA, linalgOp); + // B(batch, k, n) or B(batch, k/vnni_step, n, vnni_step) + // note: B does not use VNNI format K affine + brgemmDims.batchDimB = getPosInCodomain(batchAffinePos, operandB, linalgOp); + brgemmDims.leadingDimB = + getPosInCodomain({kAffinePos[0]}, operandB, linalgOp); + brgemmDims.minorDimB = getPosInCodomain({nAffinePos}, operandB, linalgOp); + // C(m, n) + brgemmDims.leadingDimC = getPosInCodomain({mAffinePos}, operandC, linalgOp); + brgemmDims.minorDimC = getPosInCodomain(kAffinePos, operandC, linalgOp); + + return brgemmDims; } -// Replace linalgOp with a set of microkernel ops -static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, - linalg::LinalgOp linalgOp, - const BrgemmInfo &info) { - assert(linalgOp.getDpsInputs().size() 
== 2); +template +static FailureOr +getFusibleTranspose(linalg::LinalgOp linalgOp, Value buffer) { + auto brmmOp = dyn_cast_or_null(linalgOp); + if (!brmmOp) + return failure(); + auto defOp = buffer.getDefiningOp(); + auto transOp = dyn_cast_or_null(defOp); + if (!transOp) + return failure(); + + using one_t = std::integral_constant; + using two_t = std::integral_constant; + constexpr size_t lastDimOffsetA = 1; + constexpr size_t lastDimOffsetB = std::conditional< + std::is_same::value, one_t, + two_t>::type::value; + size_t lastDimOffset = + buffer == brmmOp.getInputs()[0] ? lastDimOffsetA : lastDimOffsetB; + + ArrayRef permutation = transOp.getPermutation(); + bool lastDimContigious = true; + // Last dim cannot be permuted if we want to incorporate the + // transpose, because BRGEMM requires last dim to be contiguous. + // For VNNI, it requires the last two dims to be non-permuted + for (size_t idx = permutation.size() - lastDimOffset; + idx < permutation.size(); idx++) + lastDimContigious = lastDimContigious && (permutation[idx] == idx); + + if (lastDimContigious) + return transOp; + return failure(); +} + +// Replace linalgOp with corresponding microkernel Op +static void +replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, + linalg::LinalgOp linalgOp, const BrgemmDims &dims, + const DenseMap &replaceMap, + bool isInitOutput) { OpBuilder::InsertionGuard guard(rewriter); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - Location loc = linalgOp.getLoc(); + DenseI64ArrayAttr batchDims = DenseI64ArrayAttr::get( + rewriter.getContext(), ArrayRef{dims.batchDimA, dims.batchDimB}); + DenseI64ArrayAttr leadingDims = DenseI64ArrayAttr::get( + rewriter.getContext(), + ArrayRef{dims.leadingDimA, dims.leadingDimB}); + SmallVector brgemmFlags; - if (isInitOutput) { + if (isInitOutput) { brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( rewriter.getContext(), microkernel::BrgemmFlags::BETA_0)); } - if (info.mode == 
BrgemmInfo::STRIDE_MODE) { - brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( - rewriter.getContext(), microkernel::BrgemmFlags::STRIDE)); - } else if (info.mode == BrgemmInfo::LIST_MODE) { - brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( - rewriter.getContext(), microkernel::BrgemmFlags::LIST)); - } - - SmallVector brgemmDtypes{ - TypeAttr::get(getElementTypeOrSelf(linalgOp.getDpsInputs()[0].getType())), - TypeAttr::get( - getElementTypeOrSelf(linalgOp.getDpsInputs()[1].getType()))}; - // create dispatch op - auto flags = rewriter.getArrayAttr(brgemmFlags); - auto dtypes = rewriter.getArrayAttr(brgemmDtypes); - DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( - rewriter.getContext(), - ArrayRef{info.m, info.n, info.k, info.lda, info.ldb, info.ldc, - info.strideA, info.strideB}); - Value dispatched = rewriter.create( - loc, integer64, dims, flags, dtypes); - - // create prologue op - rewriter.create(loc, dispatched); - - // create brgemm invoke op - Value batchDim = rewriter.create( - loc, integer64, rewriter.getIntegerAttr(integer64, info.batchSize)); - Value lenDim = rewriter.create( - loc, integer64, rewriter.getIntegerAttr(integer64, info.addrLen)); - SmallVector invokeOperands; - invokeOperands.push_back(dispatched); - invokeOperands.append(linalgOp->getOperands().begin(), - linalgOp->getOperands().end()); - invokeOperands.push_back(batchDim); - invokeOperands.push_back(lenDim); - rewriter.create(loc, invokeOperands); - - // create epilogue op & replace original op - rewriter.replaceOpWithNewOp(linalgOp, - dispatched); + Value operandA = linalgOp.getDpsInputOperands()[0]->get(); + Value operandB = linalgOp.getDpsInputOperands()[1]->get(); + Value operandC = linalgOp.getDpsInitsMutable()[0].get(); + + auto brgemmOp = rewriter.replaceOpWithNewOp( + linalgOp, operandC.getType(), {operandA, operandB}, operandC, batchDims, + leadingDims, brgemmFlags); + // Replace operands according to fusion + rewriter.modifyOpInPlace(brgemmOp, [&]() { + for 
(const auto &pair : replaceMap) + brgemmOp->replaceUsesOfWith(pair.first, pair.second); + }); } -bool isZeroArithConstant(arith::ConstantOp op) { +static bool isZeroArithConstant(arith::ConstantOp op) { if (!op) return false; @@ -347,28 +282,50 @@ class ConvertContractionOpToBrgemmRewriter using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ContractionOp op, PatternRewriter &rewriter) const final { - auto brgemmInfo = getBrgemmInfo(op); - if (failed(brgemmInfo)) + if (!op.hasPureTensorSemantics()) + return failure(); + + auto brgemmDims = inferBrgemmDims(op); + if (failed(brgemmDims)) return failure(); - // Check for immediately preceding linalg::FillOp - Operation *rawOp = op; - auto block = rawOp->getBlock(); - auto opIter = Block::iterator(rawOp); - if (block->begin() != opIter) { - auto prevOp = &(*(--opIter)); - if (auto fillOp = dyn_cast(prevOp)) { - auto inputCst = dyn_cast_or_null( - fillOp.getInputs()[0].getDefiningOp()); - auto fillOperand = fillOp.getOutputs()[0]; - auto contractionOperand = op.getOutputs()[0]; - if (isZeroArithConstant(inputCst) && - contractionOperand == fillOperand) { - brgemmInfo->isInitOutput = true; - rewriter.eraseOp(prevOp); - } + + DenseMap replaceMap; + // Check for fusible linalg::TransposeOp on operand A & B + Value operandA = op.getDpsInputOperands()[0]->get(); + Value operandB = op.getDpsInputOperands()[1]->get(); + auto fusibleTransA = getFusibleTranspose(operandA); + auto fusibleTransB = getFusibleTranspose(operandB); + // Presumably minorDims are last dims and not permutated, so no need to + // transform them + if (fusibleTransA) { + ArrayRef permutation = fusibleTransA->getPermutation(); + brgemmDims.batchDimA = permutation[brgemmDims.batchDimA]; + brgemmDims.leadingDimA = permutation[brgemmDims.leadingDimA]; + replaceMap[fusibleTransA->getResult()[0]] = fusibleTransA->getInput(); + } + if (fusibleTransB) { + ArrayRef permutation = fusibleTransB->getPermutation(); + brgemmDims.batchDimB = 
permutation[brgemmDims.batchDimB]; + brgemmDims.leadingDimB = permutation[brgemmDims.leadingDimB]; + replaceMap[fusibleTransB->getResult()[0]] = fusibleTransB->getInput(); + } + + // Check for fusible linalg::FillOp on operand C + bool isInitOutput = false; + Value operandC = op.getDpsInitsMutable()[0].get(); + auto defOp = operandC.getDefiningOp(); + if (llvm::isa(defOp)) { + auto fillOp = dyn_cast_or_null(defOp); + auto inputCst = dyn_cast_or_null( + fillOp.getInputs()[0].getDefiningOp()); + if (isZeroArithConstant(inputCst)) { + replaceMap[fillOp.getResultTensors()[0]] = fillOp.getOutputs()[0]; + isInitOutput = true; } } - replaceOpWithMicrokernelOpSet(rewriter, op, *brgemmInfo); + + replaceOpWithMicrokernelOp(rewriter, op, brgemmDims, replaceMap, + isInitOutput); return success(); } }; From 039f24d4d199e4fba379d176b7e913edf18d1694 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 25 Jul 2024 02:00:37 -0700 Subject: [PATCH 58/93] fix header --- lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp index c7984892e..5b17afb02 100644 --- a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp @@ -24,8 +24,8 @@ #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/StructuredOpMatcher.h" -#include "gc/Utils/ValueUtils.h" +#include "gc/Transforms/Utils/StructuredOpMatcher.h" +#include "gc/Transforms/Utils/ValueUtils.h" namespace mlir::microkernel { #define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL From ccd31ccd54af23180738e461f74efc316ea57be0 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 26 Jul 2024 02:10:23 -0700 Subject: [PATCH 59/93] fix compile issue --- .../ConvertLinalgToMicrokernel.cpp | 95 ++++++++++--------- .../Microkernel/ExpandMicrokernel.cpp | 22 
+++-- 2 files changed, 63 insertions(+), 54 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index cdbed727f..2121e8515 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -40,9 +40,10 @@ struct BrgemmDims { int64_t leadingDimB; int64_t minorDimB; + BrgemmDims() {} BrgemmDims(int64_t bdA, int64_t ldA, int64_t mdA, int64_t bdB, int64_t ldB, int64_t mdB) - : batchDimA(bdA), leadingDimA(lda), minorDimA(mdA), batchDimB(bdB), + : batchDimA(bdA), leadingDimA(ldA), minorDimA(mdA), batchDimB(bdB), leadingDimB(ldB), minorDimB(mdB) {} }; @@ -106,7 +107,7 @@ static std::optional getPosInDomain(ArrayRef dimPos, return std::nullopt; } -static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { +static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { using namespace mlir::structured_match; auto validBrgemmMatcher = StructuredOpMatcher::make() .output(MatchAll(), HasStaticShape()) @@ -143,60 +144,64 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { return failure(); } - unsigned mAffinePos = contractionDims.m[0]; - unsigned nAffinePos = contractionDims.n[0]; + unsigned mAffinePos = contractionDims->m[0]; + unsigned nAffinePos = contractionDims->n[0]; // contractionDims.k could be of 2 cases: // 1. dims.k.size() == 2: non-VNNI, K = dims.k[1] // 2. 
dims.k.size() == 3: VNNI, K = dims.k[1] * dims.k[2] - unsigned batchAffinePos = contractionDims.k.front(); + unsigned batchAffinePos = contractionDims->k.front(); SmallVector kAffinePos; - if (contractionDims.k.size() == 2) - kPos = {contractionDims.k[1]}; - else if (contractionDims.k.size() == 3) - kPos = {contractionDims.k[1], contractionDims.k[2]}; + if (contractionDims->k.size() == 2) + kAffinePos = {contractionDims->k[1]}; + else if (contractionDims->k.size() == 3) + kAffinePos = {contractionDims->k[1], contractionDims->k[2]}; else return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: " << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mPos + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mAffinePos << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] n pos in affine: " << nPos + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] n pos in affine: " << nAffinePos << "\n"); - for (auto kp : kPos) + for (auto kp : kAffinePos) LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] k pos in affine: " << kp << "\n"); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: " - << batchPos << "\n"); + << batchAffinePos << "\n"); OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0]; BrgemmDims brgemmDims; + + auto checkAndGetPosInCodomain = [&](int64_t &dim, ArrayRef dimPos, + OpOperand *operand) { + auto pos = getPosInCodomain(batchAffinePos, operand, linalgOp); + assert(pos && "Cannot find position in codomain"); + dim = *pos; + }; + // A(batch, m, k) - brgemmDims.batchDimA = getPosInCodomain(batchAffinePos, operandA, linalgOp); - brgemmDims.leadingDimA = getPosInCodomain({mAffinePos}, operandA, linalgOp); - brgemmDims.minorDimA = getPosInCodomain(kAffinePos, operandA, linalgOp); + checkAndGetPosInCodomain(brgemmDims.batchDimA, batchAffinePos, operandA); + 
checkAndGetPosInCodomain(brgemmDims.leadingDimA, {mAffinePos}, operandA); + checkAndGetPosInCodomain(brgemmDims.minorDimA, kAffinePos, operandA); // B(batch, k, n) or B(batch, k/vnni_step, n, vnni_step) // note: B does not use VNNI format K affine - brgemmDims.batchDimB = getPosInCodomain(batchAffinePos, operandB, linalgOp); - brgemmDims.leadingDimB = - getPosInCodomain({kAffinePos[0]}, operandB, linalgOp); - brgemmDims.minorDimB = getPosInCodomain({nAffinePos}, operandB, linalgOp); + checkAndGetPosInCodomain(brgemmDims.batchDimB, batchAffinePos, operandB); + checkAndGetPosInCodomain(brgemmDims.leadingDimB, {kAffinePos[0]}, operandB); + checkAndGetPosInCodomain(brgemmDims.minorDimB, {nAffinePos}, operandB); // C(m, n) - brgemmDims.leadingDimC = getPosInCodomain({mAffinePos}, operandC, linalgOp); - brgemmDims.minorDimC = getPosInCodomain(kAffinePos, operandC, linalgOp); + // Currently useless, no need to set + // checkAndGetPosInCodomain(brgemmDims.leadingDimC, {mAffinePos}, operandC); + // checkAndGetPosInCodomain(brgemmDims.minorDimC, kAffinePos, operandC); return brgemmDims; } template -static FailureOr -getFusibleTranspose(linalg::LinalgOp linalgOp, Value buffer) { - auto brmmOp = dyn_cast_or_null(linalgOp); - if (!brmmOp) - return failure(); +static FailureOr getFusibleTranspose(SrcBrmmOpTy brmmOp, + Value buffer) { auto defOp = buffer.getDefiningOp(); auto transOp = dyn_cast_or_null(defOp); if (!transOp) @@ -226,11 +231,11 @@ getFusibleTranspose(linalg::LinalgOp linalgOp, Value buffer) { } // Replace linalgOp with corresponding microkernel Op -static void -replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, - linalg::LinalgOp linalgOp, const BrgemmDims &dims, - const DenseMap &replaceMap, - bool isInitOutput) { +static void replaceOpWithMicrokernelOp(PatternRewriter &rewriter, + linalg::LinalgOp linalgOp, + const BrgemmDims &dims, + const DenseMap &replaceMap, + bool isInitOutput) { OpBuilder::InsertionGuard guard(rewriter); DenseI64ArrayAttr batchDims = 
DenseI64ArrayAttr::get( @@ -244,14 +249,16 @@ replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, brgemmFlags.push_back(microkernel::BrgemmFlagsAttr::get( rewriter.getContext(), microkernel::BrgemmFlags::BETA_0)); } + auto flags = rewriter.getArrayAttr(brgemmFlags); Value operandA = linalgOp.getDpsInputOperands()[0]->get(); Value operandB = linalgOp.getDpsInputOperands()[1]->get(); Value operandC = linalgOp.getDpsInitsMutable()[0].get(); + SmallVector inputs{operandA, operandB}; auto brgemmOp = rewriter.replaceOpWithNewOp( - linalgOp, operandC.getType(), {operandA, operandB}, operandC, batchDims, - leadingDims, brgemmFlags); + linalgOp, operandC.getType(), inputs, operandC, batchDims, leadingDims, + flags); // Replace operands according to fusion rewriter.modifyOpInPlace(brgemmOp, [&]() { for (const auto &pair : replaceMap) @@ -293,20 +300,20 @@ class ConvertContractionOpToBrgemmRewriter // Check for fusible linalg::TransposeOp on operand A & B Value operandA = op.getDpsInputOperands()[0]->get(); Value operandB = op.getDpsInputOperands()[1]->get(); - auto fusibleTransA = getFusibleTranspose(operandA); - auto fusibleTransB = getFusibleTranspose(operandB); + auto fusibleTransA = getFusibleTranspose(op, operandA); + auto fusibleTransB = getFusibleTranspose(op, operandB); // Presumably minorDims are last dims and not permutated, so no need to // transform them - if (fusibleTransA) { + if (!failed(fusibleTransA)) { ArrayRef permutation = fusibleTransA->getPermutation(); - brgemmDims.batchDimA = permutation[brgemmDims.batchDimA]; - brgemmDims.leadingDimA = permutation[brgemmDims.leadingDimA]; + brgemmDims->batchDimA = permutation[brgemmDims->batchDimA]; + brgemmDims->leadingDimA = permutation[brgemmDims->leadingDimA]; replaceMap[fusibleTransA->getResult()[0]] = fusibleTransA->getInput(); } - if (fusibleTransB) { + if (!failed(fusibleTransB)) { ArrayRef permutation = fusibleTransB->getPermutation(); - brgemmDims.batchDimB = permutation[brgemmDims.batchDimB]; - 
brgemmDims.leadingDimB = permutation[brgemmDims.leadingDimB]; + brgemmDims->batchDimB = permutation[brgemmDims->batchDimB]; + brgemmDims->leadingDimB = permutation[brgemmDims->leadingDimB]; replaceMap[fusibleTransB->getResult()[0]] = fusibleTransB->getInput(); } @@ -324,7 +331,7 @@ class ConvertContractionOpToBrgemmRewriter } } - replaceOpWithMicrokernelOp(rewriter, op, brgemmDims, replaceMap, + replaceOpWithMicrokernelOp(rewriter, op, *brgemmDims, replaceMap, isInitOutput); return success(); } diff --git a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp index 5b17afb02..20c4443ca 100644 --- a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp @@ -28,10 +28,10 @@ #include "gc/Transforms/Utils/ValueUtils.h" namespace mlir::microkernel { -#define GEN_PASS_DEF_CONVERTLINALGTOMICROKERNEL +#define GEN_PASS_DEF_EXPANDMICROKERNEL #include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" -#define DEBUG_TYPE "convert-linalg-to-microkernel" +#define DEBUG_TYPE "expand-microkernel" struct BrgemmInfo { enum BrgemmMode { STRIDE_MODE, LIST_MODE }; @@ -61,7 +61,7 @@ static FailureOr inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) { auto operandTy = operand.getType(); if (!llvm::isa(operandTy)) return failure(); - return operandTy.getShape(); + return dyn_cast(operandTy).getShape(); }; auto checkAndGetDimSize = @@ -87,7 +87,7 @@ static FailureOr inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) { auto operandShape = checkTypeAndGetShape(operand); if (failed(operandShape)) return failure(); - auto stridesOnOperand = gcext::utils::getStaticStrides(operand); + auto stridesOnOperand = utils::getStaticStrides(operand); if (failed(stridesOnOperand)) return failure(); auto leadingDimStride = (*stridesOnOperand)[leadingDim]; @@ -100,24 +100,26 @@ static FailureOr inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) { auto checkAndGetBatchStride = [&](int64_t batchDim, Value 
operand) -> FailureOr { - auto stridesOnOperand = gcext::utils::getStaticStrides(operand); + auto stridesOnOperand = utils::getStaticStrides(operand); if (failed(stridesOnOperand)) return failure(); return (*stridesOnOperand)[batchDim]; }; // A(m, k) + auto batchDimA = brgemmOp.getBatchDimA(); auto leadingDimA = brgemmOp.getLeadingDimA(); - auto [batchA, M, KA] = checkAndGetDimSize(leadingDimA, operandA); + auto [batchA, M, KA] = checkAndGetDimSize(batchDimA, leadingDimA, operandA); auto lda = checkAndGetLdStride(leadingDimA, operandA); - if (failed(batchA) || failed(M) || failed(K) || failed(lda)) + if (failed(batchA) || failed(M) || failed(KA) || failed(lda)) return failure(); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] M, K, Lda for A: " << *M << ", " << *KA << ", " << *lda << "\n"); // B(k, n) + auto batchDimB = brgemmOp.getBatchDimB(); auto leadingDimB = brgemmOp.getLeadingDimB(); - auto [batchB, KB, N] = checkAndGetDimSize(leadingDimB, operandB); + auto [batchB, KB, N] = checkAndGetDimSize(batchDimB, leadingDimB, operandB); auto ldb = checkAndGetLdStride(leadingDimB, operandB); if (failed(batchB) || failed(KB) || failed(N) || failed(ldb)) return failure(); @@ -154,7 +156,7 @@ static FailureOr inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) { } LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" << *M - << "), n(" << *N << "), k(" << * << "), batch(" + << "), n(" << *N << "), k(" << *KB << "), batch(" << *batchA << "), lda(" << *lda << "), ldb(" << *ldb << "), ldc(" << *ldc << "), strideA(" << *strideA << "), strideB(" << *strideB << ")\n"); @@ -234,7 +236,7 @@ static void replaceOpWithMicrokernelOpSet(PatternRewriter &rewriter, class ExpandMicrokernelBrgemmRewriter : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(microkernel::BrgemmOp op, PatternRewriter &rewriter) const final { if (!op.hasPureBufferSemantics()) From 
15b09c534d8851cdd8da49a7730690f0cb1885e7 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 26 Jul 2024 02:26:12 -0700 Subject: [PATCH 60/93] minor fix on test cases --- .../gc/Dialect/Microkernel/linalg-to-microkernel.mlir | 8 ++++---- .../gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir | 2 +- test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index 161121f0c..e9d40eedb 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -37,7 +37,7 @@ func.func @basic_linalg_to_microkernel() { // CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> // CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () // ----- @@ -79,7 +79,7 @@ func.func @vnni_linalg_to_microkernel() { // CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: 
?>> // CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () // ----- @@ -123,7 +123,7 @@ func.func @basic_linalg_to_microkernel_fusing_fill() { // CHECK-NOT: linalg.fill // CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(f32, f32) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () // ----- @@ -167,7 +167,7 @@ func.func @vnni_linalg_to_microkernel_fusing_fill() { // CHECK-NOT: linalg.fill // CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(bf16, bf16) // CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: 
microkernel.brgemm(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () // ----- diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir index 91a3aefb4..ad5c3b84c 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-to-dnnl-func.mlir @@ -17,7 +17,7 @@ module { %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map, #map], 
iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index 11c6c50ad..a76f1c617 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -21,7 +21,7 @@ module { %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { From faa74db8ef22f7287d12d5edf420db0a64887279 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 29 Jul 2024 01:55:17 -0700 Subject: [PATCH 61/93] add test & bug fixes --- lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 48 +++-- .../Microkernel/expand-microkernel.mlir | 169 ++++++++++++++++++ 2 files 
changed, 205 insertions(+), 12 deletions(-) create mode 100644 test/mlir/test/gc/Dialect/Microkernel/expand-microkernel.mlir diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 4d569acc7..143e31705 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -25,6 +25,8 @@ #include +#define DEBUG_TYPE "microkernel-ops" + using namespace mlir::bufferization; namespace mlir { @@ -78,6 +80,29 @@ static ParseResult parseEnum(EnumClass &value, OpAsmParser &parser) { return success(); } +static ParseResult +parseDenseI64ArrayAttrImpl(OpAsmParser &parser, OperationState &result, + const std::string_view &attrAsmName, + const std::string_view &attrName) { + auto &builder = parser.getBuilder(); + if (parser.parseKeyword(attrAsmName) || parser.parseLParen()) + return failure(); + SmallVector vals; + auto parseVal = [&]() -> ParseResult { + int64_t val; + if (parser.parseInteger(val)) + return failure(); + vals.push_back(val); + return success(); + }; + if (parser.parseCommaSeparatedList(parseVal) || parser.parseRParen()) + return failure(); + + auto valAttr = builder.getDenseI64ArrayAttr(vals); + result.addAttribute(attrName, valAttr); + return success(); +} + template static ParseResult parseArrayAttrImpl(OpAsmParser &parser, OperationState &result, @@ -88,18 +113,17 @@ static ParseResult parseArrayAttrImpl(OpAsmParser &parser, return failure(); SmallVector attrs; auto parseAttr = [&]() -> ParseResult { - Attribute attr; + AttrType attr; if (parser.parseAttribute(attr)) return failure(); - if (!isa(attr)) - return failure(); attrs.push_back(attr); return success(); }; if (parser.parseCommaSeparatedList(parseAttr) || parser.parseRParen()) return failure(); - result.addAttribute(attrName, builder.getArrayAttr(attrs)); + auto arrayAttr = builder.getArrayAttr(attrs); + result.addAttribute(attrName, arrayAttr); return success(); } @@ -107,7 +131,7 @@ template static 
ParseResult parseFlagsImpl(OpAsmParser &parser, OperationState &result, const std::string_view &flagsName) { auto &builder = parser.getBuilder(); - if (parser.parseKeyword(flagsName) || parser.parseLParen()) + if (parser.parseKeyword(flagsName)) return failure(); SmallVector flags; @@ -118,7 +142,7 @@ static ParseResult parseFlagsImpl(OpAsmParser &parser, OperationState &result, flags.push_back(builder.getI64IntegerAttr(static_cast(flag))); return success(); }; - if (parser.parseCommaSeparatedList(parseFlags) || parser.parseRParen()) + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseFlags)) return failure(); result.addAttribute(flagsName, builder.getArrayAttr(flags)); return success(); @@ -193,26 +217,26 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op, // Start of BrgemmOp ParseResult BrgemmOp::parse(OpAsmParser &parser, OperationState &result) { - if (failed(parseOperandsImpl(parser, result, INPUTS_ASM_NAME))) return failure(); if (failed(parseOperandsImpl(parser, result, OUTPUTS_ASM_NAME))) return failure(); - if (parseArrayAttrImpl(parser, result, BATCH_DIMS_ASM_NAME, - BATCH_DIMS_ATTR_NAME)) + if (failed(parseDenseI64ArrayAttrImpl(parser, result, BATCH_DIMS_ASM_NAME, + BATCH_DIMS_ATTR_NAME))) return failure(); - if (parseArrayAttrImpl(parser, result, LEADING_DIMS_ASM_NAME, - LEADING_DIMS_ATTR_NAME)) + if (failed(parseDenseI64ArrayAttrImpl(parser, result, LEADING_DIMS_ASM_NAME, + LEADING_DIMS_ATTR_NAME))) return failure(); if (failed(parseFlagsImpl(parser, result, FLAGS_ASM_NAME))) return failure(); SmallVector resultTypes; - if (parser.parseOptionalArrowTypeList(resultTypes)) + if (failed(parser.parseOptionalArrowTypeList(resultTypes))) return failure(); result.addTypes(resultTypes); + return success(); } diff --git a/test/mlir/test/gc/Dialect/Microkernel/expand-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/expand-microkernel.mlir new file mode 100644 index 000000000..c99642198 --- /dev/null +++ 
b/test/mlir/test/gc/Dialect/Microkernel/expand-microkernel.mlir @@ -0,0 +1,169 @@ +// RUN: gc-opt %s -expand-microkernel -split-input-file | FileCheck %s + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @basic_expand_microkernel_non_init() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + microkernel.brgemm ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) batch_dims(0, 0) leading_dims(1, 1) flags() + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], 
iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: basic_expand_microkernel_non_init +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @basic_expand_microkernel_init() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = 
memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + microkernel.brgemm ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: basic_expand_microkernel_init +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> 
+// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @transpose_expand_microkernel_init() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x32x16x32xf32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x32x16x32xf32> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<4x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<8x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> + microkernel.brgemm ins(%subview, %subview_11 : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>, memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) batch_dims(1, 1) leading_dims(0, 0) flags(beta_0) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], 
offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: transpose_expand_microkernel_init +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<4x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<8x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 512, 512, 32, 32, 32] flags(beta_0, stride) data_type(f32, f32) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>, memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- + 
+#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @transpose_expand_microkernel_init_vnni() { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x32x16x32xbf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg7, %arg8) in (4, 8) { + %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<4x32x16x32xbf16> to memref<32x16x32xbf16, strided<[512, 32, 1], offset: ?>> + %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + microkernel.brgemm ins(%subview, %subview_11 : memref<32x16x32xbf16, strided<[512, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) batch_dims(1, 1) leading_dims(0, 0) flags(beta_0) + %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_14: f32, %out: f32): + %0 = arith.addf %in, %in_14 : f32 + linalg.yield %0 : f32 + } + %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, 
strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %0 = arith.maximumf %in, %cst : f32 + linalg.yield %0 : f32 + } + memref.dealloc %alloc_10 : memref<32x32xf32> + } + return +} + +// CHECK-LABEL: transpose_expand_microkernel_init_vnni +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<4x32x16x32xbf16> to memref<32x16x32xbf16, strided<[512, 32, 1], offset: ?>> +// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 512, 512, 32, 32, 64] flags(beta_0, stride) data_type(bf16, bf16) +// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () +// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<32x16x32xbf16, strided<[512, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () + +// ----- From 9463006755c485e58836da4f7008524c9885aa47 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 29 Jul 2024 23:40:05 -0700 Subject: [PATCH 62/93] minor fixes & add tests --- .../ConvertLinalgToMicrokernel.cpp | 9 +- .../Microkernel/linalg-to-microkernel.mlir | 274 +++++++++--------- 2 files changed, 146 insertions(+), 137 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 2121e8515..5239a3452 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ 
b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -177,7 +177,7 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { auto checkAndGetPosInCodomain = [&](int64_t &dim, ArrayRef dimPos, OpOperand *operand) { - auto pos = getPosInCodomain(batchAffinePos, operand, linalgOp); + auto pos = getPosInCodomain(dimPos, operand, linalgOp); assert(pos && "Cannot find position in codomain"); dim = *pos; }; @@ -196,6 +196,13 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { // checkAndGetPosInCodomain(brgemmDims.leadingDimC, {mAffinePos}, operandC); // checkAndGetPosInCodomain(brgemmDims.minorDimC, kAffinePos, operandC); + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] A batch dim: " + << brgemmDims.batchDimA + << ", A leading dim: " << brgemmDims.leadingDimA + << ", A minor dim: " << brgemmDims.minorDimA + << "; B batch dim: " << brgemmDims.batchDimB + << ", B leading dim: " << brgemmDims.leadingDimB + << ", B minor dim: " << brgemmDims.minorDimB << "\n"); return brgemmDims; } diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index e9d40eedb..35baface9 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -1,173 +1,175 @@ // RUN: gc-opt %s -convert-linalg-to-microkernel -split-input-file | FileCheck %s #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @basic_linalg_to_microkernel() { - %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = 
memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> +func.func @basic_linalg_to_microkernel(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %alloc_1 = tensor.empty() : tensor<4x16x32x32xf32> + %alloc_4 = tensor.empty() : tensor<8x16x32x32xf32> + %ret = scf.forall (%arg7, %arg8) in (4, 8) shared_outs(%argp = %arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = 
tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to tensor<16x32x32xf32> + %res = linalg.batch_reduce_matmul ins(%subview, %subview_11 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } } - return + return %ret : tensor<4x8x32x32xf32> } // CHECK-LABEL: basic_linalg_to_microkernel -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// CHECK: %[[A:.+]] = tensor.extract_slice 
%[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to tensor<16x32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> +// CHECK-NEXT: scf.forall.in_parallel // ----- #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @vnni_linalg_to_microkernel() { - %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: 
f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> +func.func @vnni_linalg_to_microkernel(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %alloc_1 = tensor.empty() : tensor<4x16x32x32xbf16> + %alloc_4 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %ret = scf.forall (%arg7, %arg8) in (4, 8) shared_outs(%argp = %arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> + %res = linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : tensor<16x32x32xbf16>, tensor<16x16x32x2xbf16>) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } } - return + return %ret : tensor<4x8x32x32xf32> } // CHECK-LABEL: vnni_linalg_to_microkernel -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 
0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> +// CHECK-NEXT: scf.forall.in_parallel // ----- #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @basic_linalg_to_microkernel_fusing_fill() { +func.func @basic_linalg_to_microkernel_fusing_fill(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - %alloc_5 = memref.alloc() {alignment = 64 : 
i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) - linalg.batch_reduce_matmul ins(%subview, %subview_11 : memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> + %alloc_1 = tensor.empty() : tensor<4x16x32x32xf32> + %alloc_4 = tensor.empty() : tensor<8x16x32x32xf32> + %ret = scf.forall (%arg7, %arg8) in (4, 8) shared_outs(%argp = 
%arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to tensor<16x32x32xf32> + %11 = linalg.fill ins(%cst : f32) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + %res = linalg.batch_reduce_matmul ins(%subview, %subview_11 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%11 : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } } - return + return %ret : tensor<4x8x32x32xf32> } // CHECK-LABEL: basic_linalg_to_microkernel_fusing_fill -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to 
tensor<16x32x32xf32> // CHECK-NOT: linalg.fill -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(f32, f32) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) -> tensor<32x32xf32> +// CHECK-NEXT: scf.forall.in_parallel // ----- #map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @vnni_linalg_to_microkernel_fusing_fill() { +func.func @vnni_linalg_to_microkernel_fusing_fill(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { %cst = arith.constant 0.000000e+00 : f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.forall (%arg7, %arg8) in (4, 8) { - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %subview = memref.subview %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_11 = memref.subview %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%alloc_10 : memref<32x32xf32>) - linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, 
strided<[1024, 64, 2, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) - %subview_12 = memref.subview %alloc_5[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10, %subview_12 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_10 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_14: f32, %out: f32): - %0 = arith.addf %in, %in_14 : f32 - linalg.yield %0 : f32 - } - %subview_13 = memref.subview %alloc_6[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%alloc_10 : memref<32x32xf32>) outs(%subview_13 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - memref.dealloc %alloc_10 : memref<32x32xf32> + %alloc_1 = tensor.empty() : tensor<4x16x32x32xbf16> + %alloc_4 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %ret = scf.forall (%arg7, %arg8) in (4, 8) shared_outs(%argp = %arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> + %11 = linalg.fill ins(%cst : f32) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + %res = linalgx.batch_reduce_matmul_vnni ins(%subview, %subview_11 : tensor<16x32x32xbf16>, tensor<16x16x32x2xbf16>) outs(%11 : tensor<32x32xf32>) -> tensor<32x32xf32> + 
scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } } - return + return %ret : tensor<4x8x32x32xf32> } // CHECK-LABEL: vnni_linalg_to_microkernel_fusing_fill -// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 -// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 -// CHECK: %[[C:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[A:.+]] = memref.subview %[[TMP1:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -// CHECK: %[[B:.+]] = memref.subview %[[TMP2:.+]][%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> +// CHECK-NOT: linalg.fill +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) -> tensor<32x32xf32> +// CHECK-NEXT: scf.forall.in_parallel + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @basic_linalg_to_microkernel_fusing_transpose(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = tensor.empty() : tensor<4x16x32x32xf32> + %alloc_4 = tensor.empty() : tensor<8x32x16x32xf32> + %trans_base = tensor.empty() : tensor<16x32x32xf32> + %ret = scf.forall (%arg7, %arg8) in (4, 8) 
shared_outs(%argp = %arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : tensor<8x32x16x32xf32> to tensor<32x16x32xf32> + %11 = linalg.fill ins(%cst : f32) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + %transposed = linalg.transpose ins(%subview_11 : tensor<32x16x32xf32>) outs(%trans_base : tensor<16x32x32xf32>) permutation = [1, 0, 2] + %res = linalg.batch_reduce_matmul ins(%subview, %transposed : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%11 : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } + } + return %ret : tensor<4x8x32x32xf32> +} + +// CHECK-LABEL: basic_linalg_to_microkernel_fusing_transpose +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : tensor<8x32x16x32xf32> to tensor<32x16x32xf32> +// CHECK-NOT: linalg.fill +// CHECK-NOT: linalg.transpose +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> +// CHECK-NEXT: scf.forall.in_parallel + +// ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @vnni_linalg_to_microkernel_fusing_transpose(%arg0: 
tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %alloc_1 = tensor.empty() : tensor<4x16x32x32xbf16> + %alloc_4 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %trans_base = tensor.empty() : tensor<16x16x32x2xbf16> + %ret = scf.forall (%arg7, %arg8) in (4, 8) shared_outs(%argp = %arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> + %11 = linalg.fill ins(%cst : f32) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + %transposed = linalg.transpose ins(%subview_11 : tensor<16x16x32x2xbf16>) outs(%trans_base : tensor<16x16x32x2xbf16>) permutation = [1, 0, 2, 3] + %res = linalgx.batch_reduce_matmul_vnni ins(%subview, %transposed : tensor<16x32x32xbf16>, tensor<16x16x32x2xbf16>) outs(%11 : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } + } + return %ret : tensor<4x8x32x32xf32> +} + +// CHECK-LABEL: vnni_linalg_to_microkernel_fusing_transpose +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> // 
CHECK-NOT: linalg.fill -// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(beta_0, stride) data_type(bf16, bf16) -// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> () -// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]], %[[A]], %[[B]], %[[C]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () -// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> () +// CHECK-NOT: linalg.transpose +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> +// CHECK-NEXT: scf.forall.in_parallel // ----- From a5b99a42281f72481812f4d324af1cb145a7bb6d Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 14 Aug 2024 00:01:56 -0700 Subject: [PATCH 63/93] fix bufferizableOpInterface --- include/gc/Dialect/Microkernel/MicrokernelOps.td | 2 ++ lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/include/gc/Dialect/Microkernel/MicrokernelOps.td b/include/gc/Dialect/Microkernel/MicrokernelOps.td index 366933e7f..fdd4b7347 100644 --- a/include/gc/Dialect/Microkernel/MicrokernelOps.td +++ b/include/gc/Dialect/Microkernel/MicrokernelOps.td @@ -73,6 +73,8 @@ def Microkernel_BrgemmOp : Microkernel_Op<"brgemm", const bufferization::AnalysisState &); bool bufferizesToElementwiseAccess(const bufferization::AnalysisState &, ArrayRef); + bufferization::AliasingValueList getAliasingValues(OpOperand &opOperand, + const bufferization::AnalysisState &state); LogicalResult bufferize(RewriterBase &, const bufferization::BufferizationOptions &); }]; diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 143e31705..98e20e76f 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ 
b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -387,6 +387,18 @@ bool BrgemmOp::bufferizesToElementwiseAccess(const AnalysisState &state, return false; } +AliasingValueList BrgemmOp::getAliasingValues(OpOperand &opOperand, + const AnalysisState &state) { + // This implementation refers to linalg's usage of + // ` DstBufferizableOpInterfaceExternalModel` + Operation *op = *this; + // Output operands alias with their respective tied OpResults. + auto dstOp = cast(op); + if (dstOp.isDpsInit(&opOperand)) + return {{dstOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}; + return {}; +} + LogicalResult BrgemmOp::bufferize(RewriterBase &rewriter, const BufferizationOptions &options) { // This implementation refers to linalg's From 75e926b15f0ded0841e8d99b718a09c79d5e66fe Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 14 Aug 2024 01:52:56 -0700 Subject: [PATCH 64/93] fix BrgemmOp asm & add bufferization tests --- lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 8 ++- .../Microkernel/linalg-to-microkernel.mlir | 12 ++--- .../Microkernel/microkernel-bufferize.mlir | 53 +++++++++++++++++++ 3 files changed, 65 insertions(+), 8 deletions(-) create mode 100644 test/mlir/test/gc/Dialect/Microkernel/microkernel-bufferize.mlir diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 98e20e76f..06ae0e237 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -242,8 +242,12 @@ ParseResult BrgemmOp::parse(OpAsmParser &parser, OperationState &result) { void BrgemmOp::print(OpAsmPrinter &printer) { BrgemmOp op = *this; - printer << " " << INPUTS_ASM_NAME << "(" << op.getInputs() << ")"; - printer << " " << OUTPUTS_ASM_NAME << "(" << op.getInit() << ")"; + ValueRange inputs = op.getInputs(); + Value init = op.getInit(); + printer << " " << INPUTS_ASM_NAME << "(" << inputs << " : " + << inputs.getTypes() << ")"; + printer << " " << OUTPUTS_ASM_NAME << 
"(" << init << " : " << init.getType() + << ")"; printer << " " << BATCH_DIMS_ASM_NAME << "(" << op.getBatchDims() << ")"; printer << " " << LEADING_DIMS_ASM_NAME << "(" << op.getLeadingDims() << ")"; diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index 35baface9..2a4890c8b 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -21,7 +21,7 @@ func.func @basic_linalg_to_microkernel(%arg0: tensor<4x8x32x32xf32>) -> tensor<4 // CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> // CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> // CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to tensor<16x32x32xf32> -// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPEC:.+]]) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> // CHECK-NEXT: scf.forall.in_parallel // ----- @@ -47,7 +47,7 @@ func.func @vnni_linalg_to_microkernel(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x // CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> // CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> // CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : 
tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> -// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPEC:.+]]) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> // CHECK-NEXT: scf.forall.in_parallel // ----- @@ -76,7 +76,7 @@ func.func @basic_linalg_to_microkernel_fusing_fill(%arg0: tensor<4x8x32x32xf32>) // CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> // CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to tensor<16x32x32xf32> // CHECK-NOT: linalg.fill -// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) -> tensor<32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) -> tensor<32x32xf32> // CHECK-NEXT: scf.forall.in_parallel // ----- @@ -105,7 +105,7 @@ func.func @vnni_linalg_to_microkernel_fusing_fill(%arg0: tensor<4x8x32x32xf32>) // CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> // CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> // CHECK-NOT: linalg.fill -// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) -> tensor<32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0) -> tensor<32x32xf32> // 
CHECK-NEXT: scf.forall.in_parallel // ----- @@ -137,7 +137,7 @@ func.func @basic_linalg_to_microkernel_fusing_transpose(%arg0: tensor<4x8x32x32x // CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : tensor<8x32x16x32xf32> to tensor<32x16x32xf32> // CHECK-NOT: linalg.fill // CHECK-NOT: linalg.transpose -// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> // CHECK-NEXT: scf.forall.in_parallel // ----- @@ -169,7 +169,7 @@ func.func @vnni_linalg_to_microkernel_fusing_transpose(%arg0: tensor<4x8x32x32xf // CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> // CHECK-NOT: linalg.fill // CHECK-NOT: linalg.transpose -// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]]) outs(%[[C]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> // CHECK-NEXT: scf.forall.in_parallel // ----- diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-bufferize.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-bufferize.mlir new file mode 100644 index 000000000..bc561790c --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-bufferize.mlir @@ -0,0 +1,53 @@ +// RUN: gc-opt %s -one-shot-bufferize -split-input-file | FileCheck %s + +func.func @basic_microkernel_bufferize(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %0 = tensor.empty() : tensor<4x16x32x32xf32> + %1 = tensor.empty() : tensor<8x16x32x32xf32> + %2 = 
scf.forall (%arg1, %arg2) in (4, 8) shared_outs(%arg3 = %arg0) -> (tensor<4x8x32x32xf32>) { + %extracted_slice = tensor.extract_slice %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> + %extracted_slice_1 = tensor.extract_slice %1[%arg2, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<8x16x32x32xf32> to tensor<16x32x32xf32> + %3 = microkernel.brgemm ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } + } + return %2 : tensor<4x8x32x32xf32> +} + +// CHECK-LABEL: basic_microkernel_bufferize +// CHECK: %[[Abuf:.+]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> +// CHECK: %[[Bbuf:.+]] = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> +// CHECK: %[[Cbuf:.+]] = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> +// CHECK: scf.forall +// CHECK: %[[C:.+]] = memref.subview %[[Cbuf]] +// CHECK: %[[A:.+]] = memref.subview %[[Abuf]] +// CHECK: %[[B:.+]] = memref.subview %[[Bbuf]] +// CHECK: microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 0) leading_dims(1, 1) flags() + +// ----- + +func.func @vnni_microkernel_bufferize(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %0 = tensor.empty() : tensor<4x16x32x32xbf16> + %1 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %2 = scf.forall (%arg1, %arg2) in (4, 8) shared_outs(%arg3 = %arg0) -> (tensor<4x8x32x32xf32>) { + %extracted_slice = tensor.extract_slice %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to 
tensor<32x32xf32> + %extracted_slice_0 = tensor.extract_slice %0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xbf16> to tensor<16x32x32xbf16> + %extracted_slice_1 = tensor.extract_slice %1[%arg2, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<8x16x16x32x2xbf16> to tensor<16x16x32x2xbf16> + %3 = microkernel.brgemm ins(%extracted_slice_0, %extracted_slice_1 : tensor<16x32x32xbf16>, tensor<16x16x32x2xbf16>) outs(%extracted_slice : tensor<32x32xf32>) batch_dims(0, 0) leading_dims(1, 1) flags() -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } + } + return %2 : tensor<4x8x32x32xf32> +} + +// CHECK-LABEL: vnni_microkernel_bufferize +// CHECK: %[[Abuf:.+]] = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> +// CHECK: %[[Bbuf:.+]] = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> +// CHECK: %[[Cbuf:.+]] = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> +// CHECK: scf.forall +// CHECK: %[[C:.+]] = memref.subview %[[Cbuf]] +// CHECK: %[[A:.+]] = memref.subview %[[Abuf]] +// CHECK: %[[B:.+]] = memref.subview %[[Bbuf]] +// CHECK: microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 0) leading_dims(1, 1) flags() From 6d445cbb09b419ec3b194559ce9bcacb462c2ad1 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 20 Aug 2024 23:05:42 -0700 Subject: [PATCH 65/93] fix merge issue --- .../ConvertLinalgToMicrokernel.cpp | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 5239a3452..bcba8403a 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -175,26 +175,26 @@ static 
FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { BrgemmDims brgemmDims; - auto checkAndGetPosInCodomain = [&](int64_t &dim, ArrayRef dimPos, - OpOperand *operand) { - auto pos = getPosInCodomain(dimPos, operand, linalgOp); + auto checkAndGetPosInDomain = [&](int64_t &dim, ArrayRef dimPos, + OpOperand *operand) { + auto pos = getPosInDomain(dimPos, operand, linalgOp); assert(pos && "Cannot find position in codomain"); dim = *pos; }; // A(batch, m, k) - checkAndGetPosInCodomain(brgemmDims.batchDimA, batchAffinePos, operandA); - checkAndGetPosInCodomain(brgemmDims.leadingDimA, {mAffinePos}, operandA); - checkAndGetPosInCodomain(brgemmDims.minorDimA, kAffinePos, operandA); + checkAndGetPosInDomain(brgemmDims.batchDimA, batchAffinePos, operandA); + checkAndGetPosInDomain(brgemmDims.leadingDimA, {mAffinePos}, operandA); + checkAndGetPosInDomain(brgemmDims.minorDimA, kAffinePos, operandA); // B(batch, k, n) or B(batch, k/vnni_step, n, vnni_step) // note: B does not use VNNI format K affine - checkAndGetPosInCodomain(brgemmDims.batchDimB, batchAffinePos, operandB); - checkAndGetPosInCodomain(brgemmDims.leadingDimB, {kAffinePos[0]}, operandB); - checkAndGetPosInCodomain(brgemmDims.minorDimB, {nAffinePos}, operandB); + checkAndGetPosInDomain(brgemmDims.batchDimB, batchAffinePos, operandB); + checkAndGetPosInDomain(brgemmDims.leadingDimB, {kAffinePos[0]}, operandB); + checkAndGetPosInDomain(brgemmDims.minorDimB, {nAffinePos}, operandB); // C(m, n) // Currently useless, no need to set - // checkAndGetPosInCodomain(brgemmDims.leadingDimC, {mAffinePos}, operandC); - // checkAndGetPosInCodomain(brgemmDims.minorDimC, kAffinePos, operandC); + // checkAndGetPosInDomain(brgemmDims.leadingDimC, {mAffinePos}, operandC); + // checkAndGetPosInDomain(brgemmDims.minorDimC, kAffinePos, operandC); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] A batch dim: " << brgemmDims.batchDimA From e2858414b279e482201d0a2d4c69809f45ecbcd3 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 20 
Aug 2024 23:21:34 -0700 Subject: [PATCH 66/93] fix format --- lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 2 +- lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 06ae0e237..20f736fc7 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -17,8 +17,8 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Dialect/Microkernel/MicrokernelDialect.h" +#include "gc/Dialect/Microkernel/MicrokernelOps.h" #define GET_OP_CLASSES #include "gc/Dialect/Microkernel/MicrokernelOps.cpp.inc" diff --git a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp index 20c4443ca..5830b9ac4 100644 --- a/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp @@ -1,10 +1,10 @@ -//===- ConvertLinalgToMicrokernel.cpp - Linalg To Microkernel -*- C++ -*--===// +//===-- ExpandMicrokernel.cpp - Expand frontend microkernel Op --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #include From 4e2c05393d2faffa4cae2427d2778bf6fbf8c155 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 20 Aug 2024 23:32:52 -0700 Subject: [PATCH 67/93] fix clang --- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index bcba8403a..ccb3171b1 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -40,7 +40,7 @@ struct BrgemmDims { int64_t leadingDimB; int64_t minorDimB; - BrgemmDims() {} + BrgemmDims() = default; BrgemmDims(int64_t bdA, int64_t ldA, int64_t mdA, int64_t bdB, int64_t ldB, int64_t mdB) : batchDimA(bdA), leadingDimA(ldA), minorDimA(mdA), batchDimB(bdB), @@ -230,7 +230,7 @@ static FailureOr getFusibleTranspose(SrcBrmmOpTy brmmOp, // For VNNI, it requires the last two dims to be non-permutedi for (size_t idx = permutation.size() - lastDimOffset; idx < permutation.size(); idx++) - lastDimContigious = lastDimContigious && (permutation[idx] == idx); + lastDimContigious = lastDimContigious && (permutation[idx] == long(idx)); if (lastDimContigious) return transOp; From c5d935bea6f21a3799ae752798067bc32f8a397d Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 21 Aug 2024 19:46:40 -0700 Subject: [PATCH 68/93] fix cpu-runner test --- test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index a76f1c617..cd927f175 100644 --- 
a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { @@ -42,9 +42,9 @@ module { func.func @main() { call @simple_brgemm() : ()->() - // COM: parallelcpu.printf "BRGEMM DONE\n" + cpuruntime.printf "BRGEMM DONE\n" return } - // COM: CHECK: BRGEMM DONE + // CHECK: BRGEMM DONE } From 705658868d860bb35d63a87159d2f4b770bd2ed7 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 23 Aug 2024 00:09:02 -0700 Subject: [PATCH 69/93] fix per review --- .../Microkernel/MicrokernelPasses.td | 1 + .../ConvertLinalgToMicrokernel.cpp | 34 +++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index dc554aeed..4af89f26a 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -31,6 +31,7 @@ def ConvertLinalgToMicrokernel: Pass<"convert-linalg-to-microkernel", "::mlir::f ``` }]; let dependentDialects = ["func::FuncDialect", + "tensor::TensorDialect", "memref::MemRefDialect", 
"linalg::LinalgDialect", "linalgx::LinalgxDialect", diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index ccb3171b1..59e807973 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Support/LogicalResult.h" @@ -175,26 +176,29 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { BrgemmDims brgemmDims; - auto checkAndGetPosInDomain = [&](int64_t &dim, ArrayRef dimPos, - OpOperand *operand) { - auto pos = getPosInDomain(dimPos, operand, linalgOp); - assert(pos && "Cannot find position in codomain"); - dim = *pos; - }; +#define CHECK_GET_POS_IN_DOMAIN(dim, dimPos, operand) \ + pos = getPosInDomain(dimPos, operand, linalgOp); \ + if (!pos) { \ + LLVM_DEBUG(llvm::dbgs() << "Cannot find position in domain for operand: " \ + << operand << "\n"); \ + return failure(); \ + } \ + dim = *pos; + std::optional pos = std::nullopt; // A(batch, m, k) - checkAndGetPosInDomain(brgemmDims.batchDimA, batchAffinePos, operandA); - checkAndGetPosInDomain(brgemmDims.leadingDimA, {mAffinePos}, operandA); - checkAndGetPosInDomain(brgemmDims.minorDimA, kAffinePos, operandA); + CHECK_GET_POS_IN_DOMAIN(brgemmDims.batchDimA, batchAffinePos, operandA); + CHECK_GET_POS_IN_DOMAIN(brgemmDims.leadingDimA, {mAffinePos}, operandA); + CHECK_GET_POS_IN_DOMAIN(brgemmDims.minorDimA, kAffinePos, operandA); // B(batch, k, n) or B(batch, k/vnni_step, n, vnni_step) // note: B does not use VNNI format K affine - checkAndGetPosInDomain(brgemmDims.batchDimB, batchAffinePos, operandB); - checkAndGetPosInDomain(brgemmDims.leadingDimB, 
{kAffinePos[0]}, operandB); - checkAndGetPosInDomain(brgemmDims.minorDimB, {nAffinePos}, operandB); + CHECK_GET_POS_IN_DOMAIN(brgemmDims.batchDimB, batchAffinePos, operandB); + CHECK_GET_POS_IN_DOMAIN(brgemmDims.leadingDimB, {kAffinePos[0]}, operandB); + CHECK_GET_POS_IN_DOMAIN(brgemmDims.minorDimB, {nAffinePos}, operandB); // C(m, n) - // Currently useless, no need to set - // checkAndGetPosInDomain(brgemmDims.leadingDimC, {mAffinePos}, operandC); - // checkAndGetPosInDomain(brgemmDims.minorDimC, kAffinePos, operandC); + // Currently C dims are useless, no need to set + +#undef CHECK_GET_POS_IN_DOMAIN LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] A batch dim: " << brgemmDims.batchDimA From ef883731c4f0c1d4b94179f496f2233b6d8e0ec6 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Sun, 25 Aug 2024 23:28:40 -0700 Subject: [PATCH 70/93] fix per comments --- .../ConvertLinalgToMicrokernel.cpp | 80 +++++++++++-------- .../Microkernel/linalg-to-microkernel.mlir | 44 ++++++++++ 2 files changed, 90 insertions(+), 34 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 59e807973..1233b504c 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -113,8 +113,6 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { auto validBrgemmMatcher = StructuredOpMatcher::make() .output(MatchAll(), HasStaticShape()) .input(MatchAll(), HasStaticShape()) - .output(MatchAll(), HasStaticStrides()) - .input(MatchAll(), HasStaticStrides()) .operation(NumOfLoops(GreaterThanOrEqualTo(3))); if (!validBrgemmMatcher.match(linalgOp)) return failure(); @@ -293,6 +291,50 @@ static bool isZeroArithConstant(arith::ConstantOp op) { return true; } +template +static std::pair +checkFusibleTransposeOp(BrgemmDims &brgemmDims, + DenseMap &replaceMap, ContractionOp op) { + // Check for fusible 
linalg::TransposeOp on operand A & B + Value operandA = op.getDpsInputOperands()[0]->get(); + Value operandB = op.getDpsInputOperands()[1]->get(); + auto fusibleTransA = getFusibleTranspose(op, operandA); + auto fusibleTransB = getFusibleTranspose(op, operandB); + // Presumably minorDims are last dims and not permutated, so no need to + // transform them + if (!failed(fusibleTransA)) { + ArrayRef permutation = fusibleTransA->getPermutation(); + brgemmDims.batchDimA = permutation[brgemmDims.batchDimA]; + brgemmDims.leadingDimA = permutation[brgemmDims.leadingDimA]; + replaceMap[fusibleTransA->getResult()[0]] = fusibleTransA->getInput(); + } + if (!failed(fusibleTransB)) { + ArrayRef permutation = fusibleTransB->getPermutation(); + brgemmDims.batchDimB = permutation[brgemmDims.batchDimB]; + brgemmDims.leadingDimB = permutation[brgemmDims.leadingDimB]; + replaceMap[fusibleTransB->getResult()[0]] = fusibleTransB->getInput(); + } + return {!failed(fusibleTransA), !failed(fusibleTransB)}; +} + +template +static bool checkFusibleFillOp(DenseMap &replaceMap, + ContractionOp op) { + // Check for fusible linalg::FillOp on operand C + bool fuseFill = false; + Value operandC = op.getDpsInitsMutable()[0].get(); + auto defOp = operandC.getDefiningOp(); + if (auto fillOp = dyn_cast(defOp)) { + auto inputCst = dyn_cast_or_null( + fillOp.getInputs()[0].getDefiningOp()); + if (isZeroArithConstant(inputCst)) { + replaceMap[fillOp.getResultTensors()[0]] = fillOp.getOutputs()[0]; + fuseFill = true; + } + } + return fuseFill; +} + template class ConvertContractionOpToBrgemmRewriter : public OpRewritePattern { @@ -308,39 +350,9 @@ class ConvertContractionOpToBrgemmRewriter return failure(); DenseMap replaceMap; - // Check for fusible linalg::TransposeOp on operand A & B - Value operandA = op.getDpsInputOperands()[0]->get(); - Value operandB = op.getDpsInputOperands()[1]->get(); - auto fusibleTransA = getFusibleTranspose(op, operandA); - auto fusibleTransB = getFusibleTranspose(op, 
operandB); - // Presumably minorDims are last dims and not permutated, so no need to - // transform them - if (!failed(fusibleTransA)) { - ArrayRef permutation = fusibleTransA->getPermutation(); - brgemmDims->batchDimA = permutation[brgemmDims->batchDimA]; - brgemmDims->leadingDimA = permutation[brgemmDims->leadingDimA]; - replaceMap[fusibleTransA->getResult()[0]] = fusibleTransA->getInput(); - } - if (!failed(fusibleTransB)) { - ArrayRef permutation = fusibleTransB->getPermutation(); - brgemmDims->batchDimB = permutation[brgemmDims->batchDimB]; - brgemmDims->leadingDimB = permutation[brgemmDims->leadingDimB]; - replaceMap[fusibleTransB->getResult()[0]] = fusibleTransB->getInput(); - } - // Check for fusible linalg::FillOp on operand C - bool isInitOutput = false; - Value operandC = op.getDpsInitsMutable()[0].get(); - auto defOp = operandC.getDefiningOp(); - if (llvm::isa(defOp)) { - auto fillOp = dyn_cast_or_null(defOp); - auto inputCst = dyn_cast_or_null( - fillOp.getInputs()[0].getDefiningOp()); - if (isZeroArithConstant(inputCst)) { - replaceMap[fillOp.getResultTensors()[0]] = fillOp.getOutputs()[0]; - isInitOutput = true; - } - } + checkFusibleTransposeOp(*brgemmDims, replaceMap, op); + bool isInitOutput = checkFusibleFillOp(replaceMap, op); replaceOpWithMicrokernelOp(rewriter, op, *brgemmDims, replaceMap, isInitOutput); diff --git a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir index 2a4890c8b..4db711acb 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/linalg-to-microkernel.mlir @@ -173,3 +173,47 @@ func.func @vnni_linalg_to_microkernel_fusing_transpose(%arg0: tensor<4x8x32x32xf // CHECK-NEXT: scf.forall.in_parallel // ----- + +#map1 = affine_map<(d0, d1) -> (d0, d1)> +func.func @basic_linalg_to_microkernel_fusing_with_branch(%arg0: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> { + %cst = 
arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0 : index + %cst_1 = arith.constant 1 : index + %alloc_1 = tensor.empty() : tensor<4x16x32x32xf32> + %alloc_4 = tensor.empty() : tensor<8x32x16x32xf32> + %trans_base = tensor.empty() : tensor<16x32x32xf32> + %ret = scf.forall (%arg7, %arg8) in (4, 8) shared_outs(%argp = %arg0) -> (tensor<4x8x32x32xf32>) { + %alloc_10 = tensor.extract_slice %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> + %subview = tensor.extract_slice %alloc_1[%arg7, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> + %subview_11 = tensor.extract_slice %alloc_4[%arg8, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : tensor<8x32x16x32xf32> to tensor<32x16x32xf32> + %transposed = linalg.transpose ins(%subview_11 : tensor<32x16x32xf32>) outs(%trans_base : tensor<16x32x32xf32>) permutation = [1, 0, 2] + %6 = arith.cmpi eq, %cst_0, %cst_1 : index + %branch_res = scf.if %6 -> (tensor<32x32xf32>) { + %11 = linalg.fill ins(%cst : f32) outs(%alloc_10 : tensor<32x32xf32>) -> tensor<32x32xf32> + %res = linalg.batch_reduce_matmul ins(%subview, %transposed : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%11 : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.yield %res : tensor<32x32xf32> + } else { + %res = linalg.batch_reduce_matmul ins(%subview, %transposed : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%alloc_10: tensor<32x32xf32>) -> tensor<32x32xf32> + scf.yield %res : tensor<32x32xf32> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %branch_res into %argp[%arg7, %arg8, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<4x8x32x32xf32> + } + } + return %ret : tensor<4x8x32x32xf32> +} + +// CHECK-LABEL: basic_linalg_to_microkernel_fusing_with_branch +// CHECK: scf.forall +// CHECK: %[[C:.+]] = tensor.extract_slice %[[Csrc:.+]][%[[arg1:.+]], %[[arg2:.+]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<32x32xf32> +// 
CHECK: %[[A:.+]] = tensor.extract_slice %[[Asrc:.+]][%[[arg1]], 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : tensor<4x16x32x32xf32> to tensor<16x32x32xf32> +// CHECK: %[[B:.+]] = tensor.extract_slice %[[Bsrc:.+]][%[[arg2]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : tensor<8x32x16x32xf32> to tensor<32x16x32xf32> +// CHECK-NOT: linalg.transpose +// CHECK: scf.if +// CHECK-NOT: linalg.fill +// CHECK: %[[RES:.+]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 1) leading_dims(1, 0) flags(beta_0) -> tensor<32x32xf32> +// CHECK: else +// CHECK: %[[RES]] = microkernel.brgemm ins(%[[A]], %[[B]] : [[TYPE:.+]]) outs(%[[C]] : [[TYPE2:.+]]) batch_dims(0, 1) leading_dims(1, 0) flags() -> tensor<32x32xf32> +// CHECK: scf.forall.in_parallel + +// ----- From c86c71eec06226162959d2a32fad4720bb03dd1c Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Sun, 25 Aug 2024 23:35:23 -0700 Subject: [PATCH 71/93] fix clang-tidy --- lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 1233b504c..980fe8288 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -178,10 +178,10 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { pos = getPosInDomain(dimPos, operand, linalgOp); \ if (!pos) { \ LLVM_DEBUG(llvm::dbgs() << "Cannot find position in domain for operand: " \ - << operand << "\n"); \ + << (operand) << "\n"); \ return failure(); \ } \ - dim = *pos; + (dim) = *pos; std::optional pos = std::nullopt; // A(batch, m, k) From 5831fefcd79c86581c43082c0f6d7dbdb87ae78f Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 26 Aug 2024 06:35:03 -0700 Subject: [PATCH 72/93] replace some with real types --- 
.../Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp index 34d8c2cbd..647d8f784 100644 --- a/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertMicrokernelToDnnlFunc.cpp @@ -76,9 +76,9 @@ class ConvertBrgemmDispatchOpRewriter } // beta - auto flags = op.getFlagsAttr(); + ArrayAttr flags = op.getFlagsAttr(); float beta = 1.0f; - for (auto flag : flags) { + for (Attribute flag : flags) { auto brgemmFlag = dyn_cast_or_null(flag); if (!brgemmFlag) return rewriter.notifyMatchFailure(op, "unknown flag for BRGEMM"); From c00fcf45555b778e299ae82d520ec88e83806297 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Thu, 30 May 2024 00:45:17 -0700 Subject: [PATCH 73/93] add optimization pass --- .../Microkernel/MicrokernelPasses.td | 22 + lib/gc/Transforms/Microkernel/CMakeLists.txt | 2 + .../Microkernel/EarlyDispatchMicrokernel.cpp | 194 ++++++++ .../MicrokernelInvariantCodeMotion.cpp | 437 ++++++++++++++++++ .../Microkernel/brgemm-multilevel-for.mlir | 61 +++ .../Microkernel/brgemm-parallel.mlir | 50 ++ .../Microkernel/brgemm-simple-for.mlir | 56 +++ .../early-dispatch-microkernel.mlir | 62 +++ .../microkernel-invariant-code-motion.mlir | 191 ++++++++ .../test/gc/cpu-runner/brgemm-parallel.mlir | 2 +- 10 files changed, 1076 insertions(+), 1 deletion(-) create mode 100644 lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp create mode 100644 lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp create mode 100644 test/gc/cpu-runner/Microkernel/brgemm-multilevel-for.mlir create mode 100644 test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir create mode 100644 test/gc/cpu-runner/Microkernel/brgemm-simple-for.mlir create mode 100644 test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir 
create mode 100644 test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 4af89f26a..021dd96a7 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -76,4 +76,26 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml "microkernel::MicrokernelDialect"]; } +def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> { + let summary = "Early dispatch microkernel during compile time"; + let description = [{ + Early dispatch microkernel during compile time. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + +def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> { + let summary = "Hoist invariant microkernel code to avoid redundant execution"; + let description = [{ + Hoist invariant microkernel code to avoid redundant execution. 
+ }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect", + "LLVM::LLVMDialect", + "microkernel::MicrokernelDialect"]; +} + #endif // GC_DIALECT_MICROKERNELPASSES diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index 9064b70db..c7bdc4617 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -6,6 +6,8 @@ gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertLinalgToMicrokernel.cpp ExpandMicrokernel.cpp ConvertMicrokernelToDnnlFunc.cpp + EarlyDispatchMicrokernel.cpp + MicrokernelInvariantCodeMotion.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp new file mode 100644 index 000000000..a3843e01d --- /dev/null +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -0,0 +1,194 @@ +//===- EarlyDispatchMicrokernel.cpp ----------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_EARLYDISPATCHMICROKERNEL +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "early-dispatch-microkernel" + +static constexpr StringRef getGlobalCtorsVarName() { + return "llvm.global_ctors"; +} + +static FailureOr +createGlobalKernelHandleName(RewriterBase &rewriter, + microkernel::BrgemmDispatchOp op) { + // TODO(haixin): Add runtime backend type to global name + std::stringstream ss; + ss << "g_dispatched_microkernel_brgemm"; + + auto flags = op.getFlagsAttr(); + for (auto flag : flags) { + auto brgemmFlag = dyn_cast_or_null(flag); + if (!brgemmFlag) + return failure("unknown flag for BRGEMM"); + if (brgemmFlag.getValue() == BrgemmFlags::LIST) + return failure("addr mode BRGEMM not supported yet"); + if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + ss << "_init"; + } + + // M, N, K, LDA, LDB, LDC, stride_a, stride_b + // they are in the same order with BrgemmDispatchOp inputs + ArrayRef inputs = op.getInputsAttr().asArrayRef(); + for (auto input : inputs) { + ss << "_" << input; + } + + // dtypeA, dtypeB + auto dtypes = op.getDataType(); + if (dtypes.size() != 2) + return failure("invalid number of DataType for BRGEMM"); + ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]); + ss << "_" << getDnnlDataTypeVal(rewriter, 
dtypes[1]); + + return ss.str(); +} + +// get or create global kernel handle with initializer, identified by +// `kernelName` +static FailureOr +getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, + const std::string &kernelName, + microkernel::BrgemmDispatchOp op) { + // Create the global at the entry of the module + LLVM::GlobalOp global; + if (!(global = module.lookupSymbol(kernelName))) { + auto global_type = op.getResults().getType(); + FlatSymbolRefAttr ctorName = + SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); + if (module.lookupSymbol(ctorName.getAttr())) { + return failure("Existing ctor for new global kernel handle"); + } + + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + global = rewriter.create( + module.getLoc(), global_type, /*isConstant=*/false, + LLVM::Linkage::Internal, kernelName, Attribute(), + /*alignment=*/0); + + // create ctor for this global, which needs to be LLVMFuncOp + LLVM::LLVMFuncOp ctorFunc = rewriter.create( + module.getLoc(), ctorName.getValue(), + LLVM::LLVMFunctionType::get(global_type, {}, false)); + + Location loc = ctorFunc.getLoc(); + Block *entryBlock = ctorFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToEnd(entryBlock); + + auto dispatch = op.clone(); + rewriter.insert(dispatch); + Value globalPtr = rewriter.create(loc, global); + rewriter.create(loc, dispatch.getResults(), globalPtr); + rewriter.create(loc, dispatch.getResults()); + + // initialize the gloabl with global_ctors, as the initializer of global + // does not allow side effect + rewriter.setInsertionPointToStart(module.getBody()); + LLVM::GlobalCtorsOp global_ctors = + module.lookupSymbol(getGlobalCtorsVarName()); + SmallVector ctorRefs; + SmallVector priorities; + if (global_ctors) { + auto ctorRefsAttr = global_ctors.getCtors(); + auto prioritiesAttr = global_ctors.getPriorities(); + for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) { 
+ ctorRefs.push_back(ctor); + priorities.push_back(prior); + } + } + ctorRefs.push_back(ctorName); + // Set new ctor's priority to lowest + priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); + if (global_ctors) { + // If there's existing ctors + rewriter.replaceOpWithNewOp( + global_ctors, rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); + } else { + rewriter.create(module.getLoc(), + rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); + } + } + return global; +} + +class EarlyDispatchBrgemmRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ModuleOp module = op->template getParentOfType(); + func::FuncOp func = op->template getParentOfType(); + + auto globalKernelName = createGlobalKernelHandleName(rewriter, op); + if (failed(globalKernelName)) { + return rewriter.notifyMatchFailure( + op, "Failed to create global kernel handle name"); + } + + // Generate kernel handle global name + auto globalKernel = + getOrCreateGlobalKernelHandle(rewriter, module, *globalKernelName, op); + if (failed(globalKernel)) { + return rewriter.notifyMatchFailure( + op, "Failed to create global kernel handle"); + } + + // Inject global val loading into start of function + auto funcBlock = &func.getBody().front(); + rewriter.setInsertionPointToStart(funcBlock); + Value globalPtr = rewriter.create(loc, *globalKernel); + Value globalVal = rewriter.create( + loc, op.getResults().getType(), globalPtr); + rewriter.moveOpAfter(op, funcBlock, funcBlock->begin()); + rewriter.replaceOp(op, globalVal); + return success(); + } +}; + +class EarlyDispatchMicrokernel + : public impl::EarlyDispatchMicrokernelBase { +public: + using impl::EarlyDispatchMicrokernelBase< + EarlyDispatchMicrokernel>::EarlyDispatchMicrokernelBase; + void runOnOperation() final { + 
RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + // Ignore newly created Ops + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) + signalPassFailure(); + } +}; + +} // namespace mlir::microkernel diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp new file mode 100644 index 000000000..8604e9c59 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -0,0 +1,437 @@ +//===- MicrokernelInvariantCodeMotion.cpp ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +#include + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_MICROKERNELINVARIANTCODEMOTION +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "microkernel-invariant-code-motion" + +enum BrgemmCallType { INAPPLICABLE = -1, DISPATCH, TILECFG, TILERELEASE }; + +static bool isParallelLoop(Operation *op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || 
llvm::isa(op); +} + +static bool isConcernedCF(Operation *op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa(op); +} + +static BrgemmCallType getBrgemmCallType(Operation *op) { + if (!llvm::isa(op)) { + return BrgemmCallType::INAPPLICABLE; + } + auto callOp = dyn_cast(op); + StringAttr callee = callOp.getCalleeAttr().getAttr(); + + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) { + return BrgemmCallType::DISPATCH; + } + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) { + return BrgemmCallType::TILECFG; + } + if (callee == + StringAttr::get(op->getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) { + return BrgemmCallType::TILERELEASE; + } + return BrgemmCallType::INAPPLICABLE; +} + +// Tree node of structure info tree, each node represents an Op +// This tree contains only concerned Ops +struct BrgemmContextStructInfo { + // Basic structure info retrieved by first walk + Operation *contextRoot; // Could be parallel loop or func + Operation *self, *parent; + DenseSet child; + SmallVector containBrgemmCallType; + // Rewrite-time info retrieved by analysing basic structure info + union { + Operation *maxInvariantScope; // Used by BrgemmCallOps for hoisting + bool hasTilereleased; // Used by other Ops as hoisting scopes to + // dedup tilerelease injection + }; + BrgemmContextStructInfo() { + contextRoot = nullptr; + self = nullptr; + parent = nullptr; + containBrgemmCallType = {false, false, false}; + maxInvariantScope = nullptr; + } +}; + +typedef DenseMap OpStructInfoMap; + +class BrgemmTilecfgRewriter : public OpRewritePattern { +private: + OpStructInfoMap &structInfo; + +public: + using OpRewritePattern::OpRewritePattern; + + BrgemmTilecfgRewriter(MLIRContext *context, OpStructInfoMap &si) + : OpRewritePattern(context), structInfo{si} {} + + LogicalResult matchAndRewrite(func::CallOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + + 
StringAttr callee = op.getCalleeAttr().getAttr(); + if (!module.lookupSymbol(callee)) + return rewriter.notifyMatchFailure(op, + "Invalid CallOp to unknown callee"); + if (callee != + StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILECFG_NAME)) + return rewriter.notifyMatchFailure(op, "Not call to BRGEMM tilecfg"); + auto opInfoIter = structInfo.find(op); + if (opInfoIter == structInfo.end()) { + return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); + } + auto &opStructInfo = opInfoIter->second; + + // Don't hoist if max invariant scope is itself to reduce + // unnecessary movement + if (opStructInfo.maxInvariantScope == op) { + return rewriter.notifyMatchFailure(op, "No need to hoist"); + } + rewriter.moveOpBefore(op, opStructInfo.maxInvariantScope); + // Avoid being hoisted again + opStructInfo.maxInvariantScope = op; + return success(); + } +}; + +static void markScopeAsReleased(OpStructInfoMap &structInfo, Operation *op) { + auto iter = structInfo.find(op); + assert(iter != structInfo.end()); + // Don't mark BrgemmCallOps + if (getBrgemmCallType(op) != BrgemmCallType::INAPPLICABLE) + return; + iter->second.hasTilereleased = true; + + for (auto ch : iter->second.child) { + markScopeAsReleased(structInfo, ch); + } +} + +class BrgemmTilereleaseRewriter : public OpRewritePattern { +private: + OpStructInfoMap &structInfo; + +public: + using OpRewritePattern::OpRewritePattern; + + BrgemmTilereleaseRewriter(MLIRContext *context, OpStructInfoMap &si) + : OpRewritePattern(context), structInfo{si} {} + LogicalResult matchAndRewrite(func::CallOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + + StringAttr callee = op.getCalleeAttr().getAttr(); + if (!module.lookupSymbol(callee)) + return rewriter.notifyMatchFailure(op, + "Invalid CallOp to unknown callee"); + if (callee != + StringAttr::get(rewriter.getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) + return rewriter.notifyMatchFailure(op, "Not call to 
BRGEMM tilerelease"); + + auto opInfoIter = structInfo.find(op); + if (opInfoIter == structInfo.end()) { + return rewriter.notifyMatchFailure(op, "Cannot find structInfo for Op"); + } + auto &opStructInfo = opInfoIter->second; + auto targetInfoIter = structInfo.find(opStructInfo.maxInvariantScope); + assert(opStructInfo.maxInvariantScope); + // Don't hoist if max invariant scope is itself to reduce + // unnecessary movement + if (opStructInfo.maxInvariantScope == op) { + return rewriter.notifyMatchFailure(op, "No need to hoist"); + } + assert(targetInfoIter != structInfo.end()); + // move last tilerelease to end of contextRoot, and remove all + // others + if (targetInfoIter->second.hasTilereleased) { + rewriter.eraseOp(op); + } else { + // auto region = opStructInfo.maxInvariantScope->getRegion(0); + // auto block = ®ion.getBlocks().front(); + // auto enditer = block->end(); + // rewriter.moveOpBefore(op, block, enditer); + rewriter.moveOpAfter(op, opStructInfo.maxInvariantScope); + // Mark all sub scope as released to avoid duplicate Tilerelease + markScopeAsReleased(structInfo, opStructInfo.maxInvariantScope); + // Avoid being hoisted again + opStructInfo.maxInvariantScope = op; + } + return success(); + } +}; + +class MicrokernelInvariantCodeMotion + : public impl::MicrokernelInvariantCodeMotionBase< + MicrokernelInvariantCodeMotion> { +private: + // This helper create structInfo tree node along the path from input Op(as + // leaf) to contextRoot(FuncOp or any parallel Op) on demand; + // This tree only contains concerned Ops, including BrgemmCall Ops, parallel + // ops and related SCF Ops etc.; + // Input Op should be a BrgemmCall Op or the op + // depent by BrgemmTilecfg/BrgemmRelease + BrgemmContextStructInfo getOrCreateBrgemmContext(OpStructInfoMap &structInfo, + Operation *op) { + auto resIter = structInfo.find(op); + if (resIter != structInfo.end()) { + return resIter->second; + } + + SmallVector createdInfo; + Operation *contextRootForCreatedInfo = 
nullptr; + + auto doCreateStructInfo = [&](Operation *child, Operation *op) { + BrgemmContextStructInfo info; + info.self = op; + if (child) { + auto iter = structInfo.find(child); + assert(iter != structInfo.end()); + iter->second.parent = op; + info.child.insert(child); + } + structInfo.insert(std::make_pair(op, std::move(info))); + auto iter = structInfo.find(op); + createdInfo.push_back(&iter->second); + return &iter->second; + }; + // Create info for input Op as leaf + auto brgemmInfo = doCreateStructInfo(nullptr, op); + auto callType = getBrgemmCallType(op); + if (callType != BrgemmCallType::INAPPLICABLE) { + brgemmInfo->containBrgemmCallType[callType] = true; + } + + auto last = op; + auto current = op->getParentOp(); + // Traverse up the IR tree, creating structInfo for each concerned Op + while (current) { + bool isParaLoop = isParallelLoop(current); + bool isCCF = isConcernedCF(current); + if (!llvm::isa(current) && !isParaLoop && !isCCF) { + // Only care about selected Ops + current = current->getParentOp(); + continue; + } + + auto iter = structInfo.find(current); + if (iter != structInfo.end()) { + // StructInfo exists for current Op, then we don't need to create info + // anymore as all ancestors have been created + // But we still need to propagate containBrgemmCallType if we are + // dealing with BrgemmCall Ops + if (last) { + auto lastIter = structInfo.find(last); + assert(lastIter != structInfo.end()); + lastIter->second.parent = current; + iter->second.child.insert(last); + // Invalidate last as we don't create new info anymore + last = nullptr; + } + if (callType != BrgemmCallType::INAPPLICABLE) { + // Propagate containCallType if needed + iter->second.containBrgemmCallType[callType] = true; + } else + break; + } else { + // StructInfo not exist, then create one for current Op and keep + // Traversing up + auto created = doCreateStructInfo(last, current); + if (callType != BrgemmCallType::INAPPLICABLE) { + 
created->containBrgemmCallType[callType] = true; + } + last = current; + } + if (llvm::isa(current) || isParaLoop) { + // Encounter `contextRoot`, then record and terminate traversing + contextRootForCreatedInfo = current; + break; + } + current = current->getParentOp(); + } + + // Assign `contextRoot` for newly created structInfo + if (contextRootForCreatedInfo) { + for (auto info : createdInfo) + info->contextRoot = contextRootForCreatedInfo; + } + + resIter = structInfo.find(op); + assert(resIter != structInfo.end()); + return resIter->second; + } + + // This helper expand invariant scope according to two function: + // 1. controlFlowAllow: Whether we can hoist the BrgemmCallOp out of the scope + // of current Op; For example, we won't move TILECFG out of an IfOp as it + // contains underministic control flow. + // 2. peerAllow: Whether we can hoist the BrgemmCallOp out of the scope of + // current Op without violating other peer BrgemmCallOp in the same level; For + // example, one scf.ForOp contains two TILECFG in the same level, then we + // cannot hoist any of them. 
+ void expandInvariantScopeWithCond( + OpStructInfoMap &structInfo, Operation *op, + std::function controlFlowAllow, + std::function &)> + peerAllow) { + auto opIter = structInfo.find(op); + assert(opIter != structInfo.end()); + auto contextRoot = opIter->second.contextRoot; + auto current = op; + auto currIter = opIter; + auto parent = opIter->second.parent; + while (parent != contextRoot) { + auto parentIter = structInfo.find(parent); + assert(parentIter != structInfo.end()); + // Verify whether we can expand the scope to direct parent + bool isControlFlowAllow = controlFlowAllow(parent); + bool isPeerAllow = + peerAllow(op, structInfo, current, parentIter->second.child); + if (!isControlFlowAllow || !isPeerAllow) { + break; + } + current = parent; + currIter = parentIter; + parent = parentIter->second.parent; + } + + opIter->second.maxInvariantScope = current; + } + + void expandInvariantScope(OpStructInfoMap &structInfo, Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + assert(brgemmCallType == BrgemmCallType::TILECFG || + brgemmCallType == BrgemmCallType::TILERELEASE); + + if (brgemmCallType == BrgemmCallType::TILECFG) { + expandInvariantScopeWithCond( + structInfo, op, + [](Operation *op) -> bool { + return !llvm::isa(op) && + !llvm::isa(op); + }, + [](Operation *self, const OpStructInfoMap &structInfo, + Operation *current, const DenseSet &peers) -> bool { + for (auto peer : peers) { + if (peer == current) + continue; + if (peer == self->getOperand(0).getDefiningOp()) { + // Don't break operand domination + return false; + } + const auto iter = structInfo.find(peer); + assert(iter != structInfo.end()); + const auto &containType = iter->second.containBrgemmCallType; + if (containType[BrgemmCallType::DISPATCH] || + containType[BrgemmCallType::TILECFG]) { + return false; + } + } + return true; + }); + } else { // brgemmCallType == BrgemmCallType::TILERELEASE + expandInvariantScopeWithCond( + structInfo, op, + [](Operation *op) -> bool 
{ + return !llvm::isa(op) && + !llvm::isa(op); + }, + [](Operation *self, const OpStructInfoMap &structInfo, + Operation *current, + const DenseSet &peers) -> bool { return true; }); + } + } + + LogicalResult collectBrgemmContextStructInfo(OpStructInfoMap &structInfo) { + // First walk to collect basic structure + getOperation()->walk( + [this, &structInfo](Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + if (brgemmCallType == BrgemmCallType::INAPPLICABLE) { + return; + } + + // Construct the structInfo tree lazily upon encountering BrgemmCall + // Op + auto info = getOrCreateBrgemmContext(structInfo, op); + structInfo.insert(std::make_pair(op, std::move(info))); + if (brgemmCallType == BrgemmCallType::TILECFG) { + // Also contruct tree node for the input of BrgemmTilecfg for + // dependency check in `expandInvariantScope` + auto dependOp = op->getOperand(0).getDefiningOp(); + auto dependInfo = getOrCreateBrgemmContext(structInfo, dependOp); + structInfo.insert(std::make_pair(dependOp, std::move(dependInfo))); + } + }); + + // Second walk to analyse hoist related info + getOperation()->walk( + [this, &structInfo](Operation *op) { + BrgemmCallType brgemmCallType = getBrgemmCallType(op); + if (brgemmCallType != BrgemmCallType::TILECFG && + brgemmCallType != BrgemmCallType::TILERELEASE) { + return; + } + + // find the maximal invariant scope for hoisting + expandInvariantScope(structInfo, op); + }); + + return success(); + } + +public: + using impl::MicrokernelInvariantCodeMotionBase< + MicrokernelInvariantCodeMotion>::MicrokernelInvariantCodeMotionBase; + void runOnOperation() final { + OpStructInfoMap structInfo; + + if (failed(collectBrgemmContextStructInfo(structInfo))) { + signalPassFailure(); + } + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), structInfo); + patterns.add(&getContext(), structInfo); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + // Ignore newly created Ops + 
GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(getOperation(), patternSet, config))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::microkernel diff --git a/test/gc/cpu-runner/Microkernel/brgemm-multilevel-for.mlir b/test/gc/cpu-runner/Microkernel/brgemm-multilevel-for.mlir new file mode 100644 index 000000000..6943c8f4e --- /dev/null +++ b/test/gc/cpu-runner/Microkernel/brgemm-multilevel-for.mlir @@ -0,0 +1,61 @@ +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : 
memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<32x32xf32>) + %subview_7 = memref.subview %alloc[%c0_index, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 
32, 1], offset: ?>> + %subview_8 = memref.subview %alloc_0[%c0_index, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%10) : (i64) -> () + microkernel.brgemm(%10, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%10) : (i64) -> () + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + // COM: parallelcpu.printf "BRGEMM DONE\n" + return + } + + // COM: CHECK: BRGEMM DONE +} diff --git a/test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir b/test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir new file mode 100644 index 000000000..29903c6dc --- /dev/null +++ b/test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir @@ -0,0 +1,50 @@ +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_0 : 
memref<8x16x32x32xf32>) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", 
"parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + // COM: parallelcpu.printf "BRGEMM DONE\n" + return + } + + // COM: CHECK: BRGEMM DONE +} diff --git a/test/gc/cpu-runner/Microkernel/brgemm-simple-for.mlir b/test/gc/cpu-runner/Microkernel/brgemm-simple-for.mlir new file mode 100644 index 000000000..1016ee7d8 --- /dev/null +++ b/test/gc/cpu-runner/Microkernel/brgemm-simple-for.mlir @@ -0,0 +1,56 @@ +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill 
ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + 
linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } + + func.func @main() { + call @simple_brgemm() : ()->() + // COM: parallelcpu.printf "BRGEMM DONE\n" + return + } + + // COM: CHECK: BRGEMM DONE +} diff --git a/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir new file mode 100644 index 000000000..60b982d5c --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir @@ -0,0 +1,62 @@ +// RUN: gc-opt %s -early-dispatch-microkernel -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, 
memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %1 = arith.addf %in, %in_7 : f32 + linalg.yield %1 : f32 + } + %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.maximumf %in, %cst : f32 + linalg.yield %1 : f32 + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } +} + +// CHECK: llvm.mlir.global_ctors {ctors = [@[[G_CTOR_NAME:.+]]], priorities = [[[G_CTOR_PRIOR:.+]] : i32]} +// CHECK: llvm.mlir.global internal @[[G_NAME:.+]]() {addr_space = 0 : i32} : i64 + +// CHECK: llvm.func @[[G_CTOR_NAME]]() -> i64 { +// CHECK-DAG: %[[G_PTR:.+]] = llvm.mlir.addressof @[[G_NAME]] : !llvm.ptr +// CHECK-DAG: %[[KERNEL:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) +// CHECK: llvm.store %[[KERNEL]], %[[G_PTR]] : i64, !llvm.ptr + +// CHECK-LABEL: simple_brgemm +// CHECK: %[[CST16:.+]] = arith.constant 16 : i64 +// CHECK: %[[CST0:.+]] = arith.constant 0 : i64 + +// CHECK: %[[G_PTR_2:.+]] = llvm.mlir.addressof @[[G_NAME]] : !llvm.ptr +// CHECK-NEXT: %[[KERNEL2:.+]] = llvm.load %[[G_PTR_2]] : !llvm.ptr -> i64 + +// CHECK: %[[memrefC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + +// CHECK: 
%[[subviewA:.+]] = memref.subview %[[memrefA:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> +// CHECK: %[[subviewB:.+]] = memref.subview %[[memrefB:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> + +// CHECK: microkernel.brgemm(%[[KERNEL2]], %[[subviewA]], %[[subviewB]], %[[memrefC]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + +// ----- diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir new file mode 100644 index 000000000..824875886 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir @@ -0,0 +1,191 @@ +// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -microkernel-invariant-code-motion -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.forall (%arg0, %arg1) in (4, 8) { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : 
memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + memref.dealloc %alloc_3 : memref<32x32xf32> + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.forall (%arg0, %arg1) in (4, 8) +// CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () +// CHECK-NEXT: call @dnnl_brgemm_tilerelease() : () -> () + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 
32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + + %subview_7 = memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_8 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%10) : (i64) -> () + microkernel.brgemm(%10, %subview_7, %subview_8, %alloc_3, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%10) : (i64) -> () + + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () +// CHECK-NOT: call @dnnl_brgemm_tilerelease() : () -> () + +// CHECK: call @dnnl_brgemm_tileconfig(%[[D:.+]]) : (i64) -> () +// CHECK: call @dnnl_brgemm_execute([[E:.+]]) : ([[F:.+]]) -> () +// CHECK-NOT: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: call 
@dnnl_brgemm_tilerelease() : () -> () +// CHECK-NEXT: return + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c2_i64 = arith.constant 2 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + memref.dealloc %alloc_3 : memref<32x32xf32> + } + scf.for %arg2 = %c0_index to %c4_index step %c1_index { + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<32x32xf32>) + 
%subview_7 = memref.subview %alloc[%arg2, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_8 = memref.subview %alloc_0[%arg2, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK: scf.for %arg1 = %c0 to %c8 step %c1 +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () + +// CHECK: } +// CHECK-NEXT: call @dnnl_brgemm_tileconfig(%[[D:.+]]) : (i64) -> () +// CHECK-NEXT: scf.for %arg1 = %c0 to %c4 step %c1 +// CHECK: call @dnnl_brgemm_execute([[E:.+]]) : ([[F:.+]]) -> () + +// CHECK: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK-NEXT: return + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to 
%c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () +// CHECK-NEXT: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: call @dnnl_brgemm_execute([[B:.+]]) : ([[C:.+]]) -> () + +// CHECK: call @dnnl_brgemm_tilerelease() : () -> () +// CHECK-NEXT: return + +// ----- diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index cd927f175..c43187e52 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --convert-microkernel-to-dnnl-func --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm 
--convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { From 2cd8df00f5f6961c1efa32af48fba0be96e8bd6b Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Mon, 17 Jun 2024 00:54:37 -0700 Subject: [PATCH 74/93] move test mlir to right place --- .../Microkernel/brgemm-parallel.mlir | 50 ------------------- .../gc/cpu-runner}/brgemm-multilevel-for.mlir | 0 .../gc/cpu-runner}/brgemm-simple-for.mlir | 0 3 files changed, 50 deletions(-) delete mode 100644 test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir rename test/{gc/cpu-runner/Microkernel => mlir/test/gc/cpu-runner}/brgemm-multilevel-for.mlir (100%) rename test/{gc/cpu-runner/Microkernel => mlir/test/gc/cpu-runner}/brgemm-simple-for.mlir (100%) diff --git a/test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir b/test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir deleted file mode 100644 index 29903c6dc..000000000 --- a/test/gc/cpu-runner/Microkernel/brgemm-parallel.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void - 
-#map = affine_map<(d0, d1) -> (d0, d1)> -module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c16_i64 = arith.constant 16 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) - %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) - scf.forall (%arg0, %arg1) in (4, 8) { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, 
%subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { - ^bb0(%in: f32, %in_7: f32, %out: f32): - %1 = arith.addf %in, %in_7 : f32 - linalg.yield %1 : f32 - } - %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %out: f32): - %1 = arith.maximumf %in, %cst : f32 - linalg.yield %1 : f32 - } - memref.dealloc %alloc_3 : memref<32x32xf32> - } - return - } - - func.func @main() { - call @simple_brgemm() : ()->() - // COM: parallelcpu.printf "BRGEMM DONE\n" - return - } - - // COM: CHECK: BRGEMM DONE -} diff --git a/test/gc/cpu-runner/Microkernel/brgemm-multilevel-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir similarity index 100% rename from test/gc/cpu-runner/Microkernel/brgemm-multilevel-for.mlir rename to test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir diff --git a/test/gc/cpu-runner/Microkernel/brgemm-simple-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir similarity index 100% rename from test/gc/cpu-runner/Microkernel/brgemm-simple-for.mlir rename to test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir From b3e60afe9cfb54addcaef9a7d850d0a743a86512 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 18 Jun 2024 02:38:38 -0700 Subject: [PATCH 75/93] [To be tested] add pass --- .../Microkernel/MicrokernelPasses.td | 9 + lib/gc/Transforms/Microkernel/CMakeLists.txt | 1 + .../MergeBranchMicrokernelContext.cpp | 301 ++++++++++++++++++ 3 files changed, 311 insertions(+) create mode 100644 lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td 
b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 021dd96a7..93f23f28a 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -87,6 +87,15 @@ def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::Module "microkernel::MicrokernelDialect"]; } +def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::func::FuncOp"> { + let summary = "Find and merge identical microkernel context operations in branches into one"; + let description = [{ + Find and merge identical microkernel context operations in branches into one. + }]; + let dependentDialects = ["func::FuncDialect", + "memref::MemRefDialect"]; +} + def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> { let summary = "Hoist invariant microkernel code to avoid redundant execution"; let description = [{ diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index c7bdc4617..a555a752a 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -8,6 +8,7 @@ gc_add_mlir_dialect_library(MLIRMicrokernelTransforms ConvertMicrokernelToDnnlFunc.cpp EarlyDispatchMicrokernel.cpp MicrokernelInvariantCodeMotion.cpp + MergeBranchMicrokernelContext.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ diff --git a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp new file mode 100644 index 000000000..170860cc0 --- /dev/null +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -0,0 +1,301 @@ +//===- MergeBranchMicrokernelContext.cpp -----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" +#include "gc/Utils/ValueUtils.h" +#include "oneapi/dnnl/dnnl_types.h" + +namespace mlir::microkernel { +#define GEN_PASS_DEF_MERGEBRANCHMICROKERNELCONTEXT +#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc" + +#define DEBUG_TYPE "merge-branch-microkernel-context" + +// enum BrgemmCallType { INAPPLICABLE = -1, DISPATCH, TILECFG, TILERELEASE }; + +class BrgemmDispatchAnalysis { +private: + // A map for tile_config -> tile_dispatch + DenseMap brgemmDispatches; + + Operation *traceKernelDispatch(Operation *op); + Operation *traceDispatchInGlobalCtor(ModuleOp module, + llvm::StringRef global_name); + +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BrgemmDispatchAnalysis) + explicit BrgemmDispatchAnalysis(Operation *); + void setKernelDispatch(Operation *tilecfg, Operation *dispatch) { + brgemmDispatches[tilecfg] = dispatch; + }; + Operation *getKernelDispatch(Operation *tilecfg) const { + auto iter = brgemmDispatches.find(tilecfg); + if (iter == brgemmDispatches.end()) { + return nullptr; + } + return iter->second; + }; +}; + +BrgemmDispatchAnalysis::BrgemmDispatchAnalysis(Operation *root) { + ModuleOp module = dyn_cast_or_null(root); + if (!module) + return; + + module->walk([this](Operation *op) { + auto callOp = dyn_cast_or_null(op); + if (!callOp) + return; + StringAttr callee = callOp.getCalleeAttr().getAttr(); + if (callee != StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) + return; + + auto dispatch = traceKernelDispatch(callOp); + assert(dispatch 
&& "No dispatch found for tilecfg Op"); + setKernelDispatch(callOp, dispatch); + }); +} + +Operation *BrgemmDispatchAnalysis::traceKernelDispatch(Operation *op) { + ModuleOp module = op->template getParentOfType(); + auto callOp = dyn_cast_or_null(op); + assert(callOp); + StringAttr callee = callOp.getCalleeAttr().getAttr(); + assert(callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)); + auto kernelProducer = callOp.getOperand(0).getDefiningOp(); + // Direct producer is supposed to be either `brgemm.dispatch` or LLVM::load + // global Any other cases are extremely rare (mostly invalid MLIR), so + // considered as not found + if (auto tryCallOp = dyn_cast_or_null(kernelProducer)) { + callee = tryCallOp.getCalleeAttr().getAttr(); + if (callee != StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) + return nullptr; + return tryCallOp; + } else if (auto tryLoadOp = dyn_cast_or_null(kernelProducer)) { + auto tryAddrOfOp = dyn_cast_or_null( + tryLoadOp.getOperand().getDefiningOp()); + if (!tryAddrOfOp) + return nullptr; + return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName()); + } + return nullptr; +} + +Operation * +BrgemmDispatchAnalysis::traceDispatchInGlobalCtor(ModuleOp module, + llvm::StringRef global_name) { + std::string gctor_name = std::string(global_name) + "_ctor"; + FlatSymbolRefAttr ctorName = + SymbolRefAttr::get(module->getContext(), gctor_name); + auto ctor = module.lookupSymbol(ctorName); + if (!ctor) + return nullptr; + + // ctor should contain only one call for kernel dispatch + auto &body = ctor.getBody(); + for (auto &opRef : body.getOps()) { + auto *op = &opRef; + auto tryCallOp = dyn_cast_or_null(op); + auto callee = tryCallOp.getCalleeAttr().getAttr(); + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) + return op; + } + return nullptr; +} + +// return a pair of extracted from given region, +// only check direct descendants +static std::pair +extractTileOpsFromRegion(Region 
®ion) { + std::pair ret{nullptr, nullptr}; + + for (auto &opRef : region.getOps()) { + auto *op = &opRef; + auto tryCallOp = dyn_cast_or_null(op); + auto callee = tryCallOp.getCalleeAttr().getAttr(); + if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) + ret.first = op; + else if (callee == + StringAttr::get(op->getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) + ret.second = op; + } + + return ret; +} + +static bool dispatchHasSameContext(Operation *lhs, Operation *rhs) { + auto lhsDispatch = dyn_cast_or_null(lhs); + auto rhsDispatch = dyn_cast_or_null(rhs); + if (!lhsDispatch || !rhsDispatch) + return false; + auto dispatchNameAttr = + StringAttr::get(lhs->getContext(), DNNL_BRGEMM_DISPATCH_NAME); + if (lhsDispatch.getCalleeAttr().getAttr() != dispatchNameAttr || + rhsDispatch.getCalleeAttr().getAttr() != dispatchNameAttr) + return false; + + auto lhsOperands = lhsDispatch.getOperands(); + auto rhsOperands = rhsDispatch.getOperands(); + assert(lhsOperands.size() == rhsOperands.size() && + "Inconsistent operand size"); + for (size_t idx = 0; idx < lhsOperands.size(); idx++) { + if (idx == 8) { + // skip `beta` operand in index no.8 + // since per dnnl design, it does not affect BRGEMM blocking & palettes + continue; + } + auto lhsCstOp = + dyn_cast_or_null(lhsOperands[idx].getDefiningOp()); + auto rhsCstOp = + dyn_cast_or_null(rhsOperands[idx].getDefiningOp()); + if (!lhsCstOp || !rhsCstOp) + return false; + if (lhsCstOp.getValue() != rhsCstOp.getValue()) + return false; + } + return true; +} + +class ScfIfRewriter : public OpRewritePattern { +private: + BrgemmDispatchAnalysis &analysis; + +public: + using OpRewritePattern::OpRewritePattern; + + ScfIfRewriter(MLIRContext *context, BrgemmDispatchAnalysis &ana) + : OpRewritePattern(context), analysis{ana} {} + + LogicalResult matchAndRewrite(scf::IfOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + auto &ifRegion = op.getThenRegion(); + auto 
&elseRegion = op.getElseRegion(); + if (!ifRegion.hasOneBlock() || !elseRegion.hasOneBlock()) + return rewriter.notifyMatchFailure(op, + "Cannot merge for non-full branch"); + auto ifTileOps = extractTileOpsFromRegion(ifRegion); + auto elseTileOps = extractTileOpsFromRegion(elseRegion); + if (!ifTileOps.first || !ifTileOps.second || !elseTileOps.first || + !elseTileOps.second) + return rewriter.notifyMatchFailure( + op, "Cannot merge for inconsistent branch"); + + auto ifTileDispatch = analysis.getKernelDispatch(ifTileOps.first); + auto elseTileDispatch = analysis.getKernelDispatch(elseTileOps.first); + if (!ifTileDispatch || !elseTileDispatch) + return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); + + if (!dispatchHasSameContext(ifTileDispatch, elseTileDispatch)) + return rewriter.notifyMatchFailure( + op, "Kernels in branch has different context"); + + // Avoid breaking dominance of dispatch + if (ifTileDispatch->getParentRegion() == &ifRegion) + return rewriter.notifyMatchFailure(op, + "Dispatch dominance prevents merging"); + + // Whole branch reuses internal context of kernel in `if` region + rewriter.eraseOp(elseTileOps.first); + rewriter.eraseOp(elseTileOps.second); + rewriter.moveOpBefore(ifTileOps.first, op); + rewriter.moveOpAfter(ifTileOps.second, op); + + return success(); + } +}; + +class ScfIndexSwitchRewriter : public OpRewritePattern { +private: + BrgemmDispatchAnalysis &analysis; + +public: + using OpRewritePattern::OpRewritePattern; + + ScfIndexSwitchRewriter(MLIRContext *context, BrgemmDispatchAnalysis &ana) + : OpRewritePattern(context), analysis{ana} {} + + LogicalResult matchAndRewrite(scf::IndexSwitchOp op, + PatternRewriter &rewriter) const final { + ModuleOp module = op->template getParentOfType(); + auto &defaultRegion = op.getDefaultRegion(); + auto caseRegions = op.getCaseRegions(); + + auto defaultTileOps = extractTileOpsFromRegion(defaultRegion); + if (!defaultTileOps.first || !defaultTileOps.second) + return 
rewriter.notifyMatchFailure( + op, "Cannot merge for inconsistent branch"); + SmallVector, 5> caseTilesOps; + for (auto &caseRegion : caseRegions) { + auto caseTileOps = extractTileOpsFromRegion(caseRegion); + if (!caseTileOps.first || !caseTileOps.second) + return rewriter.notifyMatchFailure( + op, "Cannot merge for inconsistent branch"); + caseTilesOps.push_back(caseTileOps); + } + + auto defaultTileDispatch = analysis.getKernelDispatch(defaultTileOps.first); + if (!defaultTileDispatch) + return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); + + for (size_t idx = 0; idx < caseRegions.size(); idx++) { + auto &caseRegion = caseRegions[idx]; + auto caseTileDispatch = + analysis.getKernelDispatch(caseTilesOps[idx].first); + if (!defaultTileDispatch) + return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); + if (!dispatchHasSameContext(defaultTileDispatch, caseTileDispatch)) + return rewriter.notifyMatchFailure( + op, "Kernels in branch has different context"); + } + + // Avoid breaking dominance of dispatch + if (defaultTileDispatch->getParentRegion() == &defaultRegion) + return rewriter.notifyMatchFailure(op, + "Dispatch dominance prevents merging"); + + // Whole branch reuses internal context of kernel in `default` region + for (auto &ops : caseTilesOps) { + rewriter.eraseOp(ops.first); + rewriter.eraseOp(ops.second); + } + rewriter.moveOpBefore(defaultTileOps.first, op); + rewriter.moveOpAfter(defaultTileOps.second, op); + + return success(); + } +}; + +class MergeBranchMicrokernelContext + : public impl::MergeBranchMicrokernelContextBase< + MergeBranchMicrokernelContext> { +public: + using impl::MergeBranchMicrokernelContextBase< + MergeBranchMicrokernelContext>::MergeBranchMicrokernelContextBase; + void runOnOperation() final { + auto &dispatchAnalysis = getAnalysis(); + + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), dispatchAnalysis); + patterns.add(&getContext(), dispatchAnalysis); + 
FrozenRewritePatternSet patternSet(std::move(patterns)); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::microkernel From fa44d9ada8174911dc8ecf1ad054b75a1fda2e09 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 18 Jun 2024 20:55:45 -0700 Subject: [PATCH 76/93] add test & bugfix for new pass --- .../Microkernel/MicrokernelPasses.td | 2 +- .../MergeBranchMicrokernelContext.cpp | 9 + .../merge-branch-microkernel-context.mlir | 278 ++++++++++++++++++ 3 files changed, 288 insertions(+), 1 deletion(-) create mode 100644 test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir diff --git a/include/gc/Transforms/Microkernel/MicrokernelPasses.td b/include/gc/Transforms/Microkernel/MicrokernelPasses.td index 93f23f28a..41a3fd070 100644 --- a/include/gc/Transforms/Microkernel/MicrokernelPasses.td +++ b/include/gc/Transforms/Microkernel/MicrokernelPasses.td @@ -87,7 +87,7 @@ def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::Module "microkernel::MicrokernelDialect"]; } -def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::func::FuncOp"> { +def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> { let summary = "Find and merge identical microkernel context operations in branches into one"; let description = [{ Find and merge identical microkernel context operations in branches into one. 
diff --git a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp index 170860cc0..b0dbaeb91 100644 --- a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -38,6 +38,8 @@ class BrgemmDispatchAnalysis { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BrgemmDispatchAnalysis) explicit BrgemmDispatchAnalysis(Operation *); void setKernelDispatch(Operation *tilecfg, Operation *dispatch) { + LLVM_DEBUG(llvm::dbgs() << "* setKernelDispatch: " << tilecfg << "; " + << dispatch << "\n"); brgemmDispatches[tilecfg] = dispatch; }; Operation *getKernelDispatch(Operation *tilecfg) const { @@ -50,6 +52,8 @@ class BrgemmDispatchAnalysis { }; BrgemmDispatchAnalysis::BrgemmDispatchAnalysis(Operation *root) { + LLVM_DEBUG(llvm::dbgs() << "* construct BrgemmDispatchAnalysis: " << *root + << "\n"); ModuleOp module = dyn_cast_or_null(root); if (!module) return; @@ -108,6 +112,8 @@ BrgemmDispatchAnalysis::traceDispatchInGlobalCtor(ModuleOp module, for (auto &opRef : body.getOps()) { auto *op = &opRef; auto tryCallOp = dyn_cast_or_null(op); + if (!tryCallOp) + continue; auto callee = tryCallOp.getCalleeAttr().getAttr(); if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) return op; @@ -122,8 +128,11 @@ extractTileOpsFromRegion(Region ®ion) { std::pair ret{nullptr, nullptr}; for (auto &opRef : region.getOps()) { + LLVM_DEBUG(llvm::dbgs() << ">>> " << opRef << "\n"); auto *op = &opRef; auto tryCallOp = dyn_cast_or_null(op); + if (!tryCallOp) + continue; auto callee = tryCallOp.getCalleeAttr().getAttr(); if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) ret.first = op; diff --git a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir new file mode 100644 index 000000000..240788962 --- 
/dev/null +++ b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir @@ -0,0 +1,278 @@ +// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -merge-branch-microkernel-context -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %cmp = arith.cmpi eq, %arg0, %c0_index : index + scf.if %cmp { + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + } else { + %1 = 
microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%1) : (i64) -> () + microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%1) : (i64) -> () + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK-NEXT: scf.if +// CHECK: } else { +// CHECK: } +// CHECK-NEXT: func.call @dnnl_brgemm_tilerelease() : () -> () + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 
2, 1], offset: ?>> + %cmp = arith.cmpi eq, %arg0, %c0_index : index + scf.if %cmp { + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: scf.if +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : 
memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %cmp = arith.cmpi eq, %arg0, %c0_index : index + scf.if %cmp { + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + } else { + %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 512, 512] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%1) : (i64) -> () + microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%1) : (i64) -> () + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: scf.if +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } else { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = 
memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + scf.index_switch %arg0 + case 0 { + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + scf.yield + } + case 1 { + %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%1) : (i64) -> () + microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%1) : (i64) -> () + scf.yield + } + default { + %2 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%2) : (i64) -> () + microkernel.brgemm(%2, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, 
memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%2) : (i64) -> () + scf.yield + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK-NEXT: scf.index_switch +// CHECK: case 0 { +// CHECK: case 1 { +// CHECK: default { +// CHECK: } +// CHECK-NEXT: func.call @dnnl_brgemm_tilerelease() : () -> () + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @simple_brgemm() { + %c0_i64 = arith.constant 0 : i64 + %c0_index = arith.constant 0 : index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c16_i64 = arith.constant 16 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + scf.index_switch %arg0 + case 0 { + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) 
+ microkernel.brgemm.prologue(%0) : (i64) -> () + microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%0) : (i64) -> () + scf.yield + } + case 1 { + %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%1) : (i64) -> () + microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%1) : (i64) -> () + scf.yield + } + default { + %2 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 512, 512] flags = (stride) data_type = (bf16, bf16) + microkernel.brgemm.prologue(%2) : (i64) -> () + microkernel.brgemm(%2, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.epilogue(%2) : (i64) -> () + scf.yield + } + memref.dealloc %alloc_3 : memref<32x32xf32> + } + } + return + } +} + +// CHECK-LABEL: simple_brgemm + +// CHECK: scf.for %arg0 = %c0 to %c4 step %c1 +// CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 + +// CHECK: scf.index_switch +// CHECK: case 0 { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: case 1 { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: default { +// CHECK: func.call @dnnl_brgemm_tileconfig +// CHECK: func.call @dnnl_brgemm_tilerelease() : () -> () +// CHECK: } From 994c8b0bc5d0d2db2b28067535331900793222f8 Mon Sep 17 00:00:00 2001 
From: "Huang, Haixin" Date: Wed, 19 Jun 2024 00:39:29 -0700 Subject: [PATCH 77/93] fix global_ctors lookup --- .../Microkernel/EarlyDispatchMicrokernel.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp index a3843e01d..ef3ab9c54 100644 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -106,8 +106,15 @@ getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, // initialize the gloabl with global_ctors, as the initializer of global // does not allow side effect rewriter.setInsertionPointToStart(module.getBody()); - LLVM::GlobalCtorsOp global_ctors = - module.lookupSymbol(getGlobalCtorsVarName()); + LLVM::GlobalCtorsOp global_ctors = nullptr; + for (auto &op : module->getRegion(0).front()) { + auto ctorOp = dyn_cast(op); + if (ctorOp) { + global_ctors = ctorOp; + break; + } + } + SmallVector ctorRefs; SmallVector priorities; if (global_ctors) { @@ -117,16 +124,20 @@ getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, ctorRefs.push_back(ctor); priorities.push_back(prior); } + LLVM_DEBUG(llvm::dbgs() + << "After append ctors: " << ctorRefs.size() << "\n"); } ctorRefs.push_back(ctorName); // Set new ctor's priority to lowest priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); if (global_ctors) { + LLVM_DEBUG(llvm::dbgs() << "Replace existing ctors\n"); // If there's existing ctors rewriter.replaceOpWithNewOp( global_ctors, rewriter.getArrayAttr(ctorRefs), rewriter.getArrayAttr(priorities)); } else { + LLVM_DEBUG(llvm::dbgs() << "Create new ctor\n"); rewriter.create(module.getLoc(), rewriter.getArrayAttr(ctorRefs), rewriter.getArrayAttr(priorities)); From b13b5cd77129c55f3fbe61146af0601c74241a97 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 26 Jun 2024 02:21:57 -0700 
Subject: [PATCH 78/93] refine test cases --- .../Dialect/Microkernel/merge-branch-microkernel-context.mlir | 2 +- .../Dialect/Microkernel/microkernel-invariant-code-motion.mlir | 2 +- test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir | 2 +- test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir | 2 +- test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir index 240788962..4eb0f4e74 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -merge-branch-microkernel-context -split-input-file | FileCheck %s +// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -cse -merge-branch-microkernel-context -split-input-file | FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> module { diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir index 824875886..0996be846 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -microkernel-invariant-code-motion -split-input-file | FileCheck %s +// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -cse -microkernel-invariant-code-motion -split-input-file | FileCheck %s #map = affine_map<(d0, d1) -> (d0, d1)> module { diff --git a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir 
b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir index 6943c8f4e..fe5c833a4 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index c43187e52..12f63edb9 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse 
--microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { diff --git a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir index 1016ee7d8..9fb762425 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { From af45c6c14050af242263a13d6eeea367f5466a93 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 02:07:41 -0700 Subject: [PATCH 79/93] fix util headers --- lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp | 2 +- lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp | 2 +- 
.../Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp index ef3ab9c54..38314bfec 100644 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -17,7 +17,7 @@ #include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/ValueUtils.h" +#include "gc/Transforms/Utils/ValueUtils.h" #include "oneapi/dnnl/dnnl_types.h" namespace mlir::microkernel { diff --git a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp index b0dbaeb91..782db63ef 100644 --- a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -14,7 +14,7 @@ #include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/ValueUtils.h" +#include "gc/Transforms/Utils/ValueUtils.h" #include "oneapi/dnnl/dnnl_types.h" namespace mlir::microkernel { diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp index 8604e9c59..1bf7d075a 100644 --- a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -18,7 +18,7 @@ #include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" -#include "gc/Utils/ValueUtils.h" +#include "gc/Transforms/Utils/ValueUtils.h" #include "oneapi/dnnl/dnnl_types.h" namespace mlir::microkernel { From b42b5b0425cffc36e3cd041907d2efcb0a9a4440 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 
2024 02:20:41 -0700 Subject: [PATCH 80/93] fix license and tidy --- .../Microkernel/EarlyDispatchMicrokernel.cpp | 16 ++++++---------- .../MergeBranchMicrokernelContext.cpp | 4 ++-- .../MicrokernelInvariantCodeMotion.cpp | 8 +++++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp index 38314bfec..972c1e631 100644 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -1,10 +1,10 @@ -//===- EarlyDispatchMicrokernel.cpp ----------------------------*- C++ -*-===// +//===-- EarlyDispatchMicrokernel.cpp - Dispatch before runtime --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -26,10 +26,6 @@ namespace mlir::microkernel { #define DEBUG_TYPE "early-dispatch-microkernel" -static constexpr StringRef getGlobalCtorsVarName() { - return "llvm.global_ctors"; -} - static FailureOr createGlobalKernelHandleName(RewriterBase &rewriter, microkernel::BrgemmDispatchOp op) { @@ -41,9 +37,9 @@ createGlobalKernelHandleName(RewriterBase &rewriter, for (auto flag : flags) { auto brgemmFlag = dyn_cast_or_null(flag); if (!brgemmFlag) - return failure("unknown flag for BRGEMM"); + return failure(); if (brgemmFlag.getValue() == BrgemmFlags::LIST) - return failure("addr mode BRGEMM not supported yet"); + return failure(); if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) ss << "_init"; } @@ -58,7 +54,7 @@ createGlobalKernelHandleName(RewriterBase &rewriter, // dtypeA, 
dtypeB auto dtypes = op.getDataType(); if (dtypes.size() != 2) - return failure("invalid number of DataType for BRGEMM"); + return failure(); ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]); ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[1]); @@ -78,7 +74,7 @@ getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, FlatSymbolRefAttr ctorName = SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); if (module.lookupSymbol(ctorName.getAttr())) { - return failure("Existing ctor for new global kernel handle"); + return failure(); } OpBuilder::InsertionGuard insertGuard(rewriter); diff --git a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp index 782db63ef..0bc98be44 100644 --- a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -1,10 +1,10 @@ -//===- MergeBranchMicrokernelContext.cpp -----------------------*- C++ -*-===// +//===-- MergeBranchMicrokernelContext.cpp - Merge same context --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp index 1bf7d075a..8014777ad 100644 --- a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -1,10 +1,10 @@ -//===- MicrokernelInvariantCodeMotion.cpp ----------------------*- C++ -*-===// +//===-- MicrokernelInvariantCodeMotion.cpp - Hoist invariance ---*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===---------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -82,7 +82,7 @@ struct BrgemmContextStructInfo { } }; -typedef DenseMap OpStructInfoMap; +using OpStructInfoMap = DenseMap; class BrgemmTilecfgRewriter : public OpRewritePattern { private: @@ -297,12 +297,14 @@ class MicrokernelInvariantCodeMotion // current Op without violating other peer BrgemmCallOp in the same level; For // example, one scf.ForOp contains two TILECFG in the same level, then we // cannot hoist any of them. 
+ // NOLINTBEGIN(performance-unnecessary-value-param) void expandInvariantScopeWithCond( OpStructInfoMap &structInfo, Operation *op, std::function controlFlowAllow, std::function &)> peerAllow) { + // NOLINTEND(performance-unnecessary-value-param) auto opIter = structInfo.find(op); assert(opIter != structInfo.end()); auto contextRoot = opIter->second.contextRoot; From f195953fd60a4080594b718899a88afe24474bc4 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 19:32:23 -0700 Subject: [PATCH 81/93] fix clang-tidy --- .../Transforms/Microkernel/MergeBranchMicrokernelContext.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp index 0bc98be44..a1a09ea21 100644 --- a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -189,7 +189,6 @@ class ScfIfRewriter : public OpRewritePattern { LogicalResult matchAndRewrite(scf::IfOp op, PatternRewriter &rewriter) const final { - ModuleOp module = op->template getParentOfType(); auto &ifRegion = op.getThenRegion(); auto &elseRegion = op.getElseRegion(); if (!ifRegion.hasOneBlock() || !elseRegion.hasOneBlock()) @@ -238,7 +237,6 @@ class ScfIndexSwitchRewriter : public OpRewritePattern { LogicalResult matchAndRewrite(scf::IndexSwitchOp op, PatternRewriter &rewriter) const final { - ModuleOp module = op->template getParentOfType(); auto &defaultRegion = op.getDefaultRegion(); auto caseRegions = op.getCaseRegions(); @@ -260,7 +258,6 @@ class ScfIndexSwitchRewriter : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); for (size_t idx = 0; idx < caseRegions.size(); idx++) { - auto &caseRegion = caseRegions[idx]; auto caseTileDispatch = analysis.getKernelDispatch(caseTilesOps[idx].first); if (!defaultTileDispatch) From dc119edcd2b5fab517594629398d03797a4344e2 Mon 
Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 20:10:30 -0700 Subject: [PATCH 82/93] Revert "change clang-tidy-version" This reverts commit fccd6834c66a268654f57613246a93246874faa4. --- .github/workflows/clang-tidy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/clang-tidy.yml b/.github/workflows/clang-tidy.yml index f65296c92..a3034d380 100644 --- a/.github/workflows/clang-tidy.yml +++ b/.github/workflows/clang-tidy.yml @@ -103,4 +103,4 @@ jobs: shell: bash run: | cd build - python3 ../llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py -warnings-as-errors=* -p ./ -config-file ../llvm-project/mlir/.clang-tidy -clang-tidy-binary $(which clang-tidy-15) ${{ env.CHANGED_FILES }} + python3 ../llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py -warnings-as-errors=* -p ./ -config-file ../llvm-project/mlir/.clang-tidy -clang-tidy-binary $(which clang-tidy) ${{ env.CHANGED_FILES }} From e48273555e91cd8cab3a88e5cf110e61a3e9761e Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 24 Jul 2024 21:13:19 -0700 Subject: [PATCH 83/93] Revert "Revert "change clang-tidy-version"" This reverts commit 2bcaa279de6436672d9eb0145f527435eee16362. 
--- .github/workflows/clang-tidy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/clang-tidy.yml b/.github/workflows/clang-tidy.yml index a3034d380..f65296c92 100644 --- a/.github/workflows/clang-tidy.yml +++ b/.github/workflows/clang-tidy.yml @@ -103,4 +103,4 @@ jobs: shell: bash run: | cd build - python3 ../llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py -warnings-as-errors=* -p ./ -config-file ../llvm-project/mlir/.clang-tidy -clang-tidy-binary $(which clang-tidy) ${{ env.CHANGED_FILES }} + python3 ../llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py -warnings-as-errors=* -p ./ -config-file ../llvm-project/mlir/.clang-tidy -clang-tidy-binary $(which clang-tidy-15) ${{ env.CHANGED_FILES }} From c606f6d85a4efe7c964a95848a97345cb0a85a69 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 30 Jul 2024 02:04:00 -0700 Subject: [PATCH 84/93] refactor per reviews --- .../Microkernel/EarlyDispatchMicrokernel.cpp | 132 ++--- .../MicrokernelInvariantCodeMotion.cpp | 15 +- .../merge-branch-microkernel-context.mlir | 504 ++++++++++++------ 3 files changed, 414 insertions(+), 237 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp index 972c1e631..4b4b169ef 100644 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -68,76 +68,76 @@ getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module, const std::string &kernelName, microkernel::BrgemmDispatchOp op) { // Create the global at the entry of the module - LLVM::GlobalOp global; - if (!(global = module.lookupSymbol(kernelName))) { - auto global_type = op.getResults().getType(); - FlatSymbolRefAttr ctorName = - SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); - if (module.lookupSymbol(ctorName.getAttr())) { - return failure(); - } + LLVM::GlobalOp global = 
module.lookupSymbol(kernelName); + if (global) + return global; + + auto global_type = op.getResults().getType(); + FlatSymbolRefAttr ctorName = + SymbolRefAttr::get(module->getContext(), kernelName + "_ctor"); + if (module.lookupSymbol(ctorName.getAttr())) + return failure(); - OpBuilder::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - global = rewriter.create( - module.getLoc(), global_type, /*isConstant=*/false, - LLVM::Linkage::Internal, kernelName, Attribute(), - /*alignment=*/0); - - // create ctor for this global, which needs to be LLVMFuncOp - LLVM::LLVMFuncOp ctorFunc = rewriter.create( - module.getLoc(), ctorName.getValue(), - LLVM::LLVMFunctionType::get(global_type, {}, false)); - - Location loc = ctorFunc.getLoc(); - Block *entryBlock = ctorFunc.addEntryBlock(rewriter); - rewriter.setInsertionPointToEnd(entryBlock); - - auto dispatch = op.clone(); - rewriter.insert(dispatch); - Value globalPtr = rewriter.create(loc, global); - rewriter.create(loc, dispatch.getResults(), globalPtr); - rewriter.create(loc, dispatch.getResults()); - - // initialize the gloabl with global_ctors, as the initializer of global - // does not allow side effect - rewriter.setInsertionPointToStart(module.getBody()); - LLVM::GlobalCtorsOp global_ctors = nullptr; - for (auto &op : module->getRegion(0).front()) { - auto ctorOp = dyn_cast(op); - if (ctorOp) { - global_ctors = ctorOp; - break; - } + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + global = rewriter.create( + module.getLoc(), global_type, /*isConstant=*/false, + LLVM::Linkage::Internal, kernelName, Attribute(), + /*alignment=*/0); + + // create ctor for this global, which needs to be LLVMFuncOp + LLVM::LLVMFuncOp ctorFunc = rewriter.create( + module.getLoc(), ctorName.getValue(), + LLVM::LLVMFunctionType::get(global_type, {}, false)); + + Location loc = ctorFunc.getLoc(); + Block *entryBlock = 
ctorFunc.addEntryBlock(rewriter); + rewriter.setInsertionPointToEnd(entryBlock); + + auto dispatch = op.clone(); + rewriter.insert(dispatch); + Value globalPtr = rewriter.create(loc, global); + rewriter.create(loc, dispatch.getResults(), globalPtr); + rewriter.create(loc, dispatch.getResults()); + + // initialize the gloabl with global_ctors, as the initializer of global + // does not allow side effect + rewriter.setInsertionPointToStart(module.getBody()); + LLVM::GlobalCtorsOp global_ctors = nullptr; + for (auto &op : module->getRegion(0).front()) { + auto ctorOp = dyn_cast(op); + if (ctorOp) { + global_ctors = ctorOp; + break; } + } - SmallVector ctorRefs; - SmallVector priorities; - if (global_ctors) { - auto ctorRefsAttr = global_ctors.getCtors(); - auto prioritiesAttr = global_ctors.getPriorities(); - for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) { - ctorRefs.push_back(ctor); - priorities.push_back(prior); - } - LLVM_DEBUG(llvm::dbgs() - << "After append ctors: " << ctorRefs.size() << "\n"); - } - ctorRefs.push_back(ctorName); - // Set new ctor's priority to lowest - priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); - if (global_ctors) { - LLVM_DEBUG(llvm::dbgs() << "Replace existing ctors\n"); - // If there's existing ctors - rewriter.replaceOpWithNewOp( - global_ctors, rewriter.getArrayAttr(ctorRefs), - rewriter.getArrayAttr(priorities)); - } else { - LLVM_DEBUG(llvm::dbgs() << "Create new ctor\n"); - rewriter.create(module.getLoc(), - rewriter.getArrayAttr(ctorRefs), - rewriter.getArrayAttr(priorities)); + SmallVector ctorRefs; + SmallVector priorities; + if (global_ctors) { + auto ctorRefsAttr = global_ctors.getCtors(); + auto prioritiesAttr = global_ctors.getPriorities(); + for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) { + ctorRefs.push_back(ctor); + priorities.push_back(prior); } + LLVM_DEBUG(llvm::dbgs() + << "After append ctors: " << ctorRefs.size() << "\n"); + } + 
ctorRefs.push_back(ctorName); + // Set new ctor's priority to lowest + priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX)); + if (global_ctors) { + LLVM_DEBUG(llvm::dbgs() << "Replace existing ctors\n"); + // If there's existing ctors + rewriter.replaceOpWithNewOp( + global_ctors, rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); + } else { + LLVM_DEBUG(llvm::dbgs() << "Create new ctor\n"); + rewriter.create(module.getLoc(), + rewriter.getArrayAttr(ctorRefs), + rewriter.getArrayAttr(priorities)); } return global; } diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp index 8014777ad..1eaa18980 100644 --- a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -44,18 +44,14 @@ static BrgemmCallType getBrgemmCallType(Operation *op) { return BrgemmCallType::INAPPLICABLE; } auto callOp = dyn_cast(op); - StringAttr callee = callOp.getCalleeAttr().getAttr(); + auto calleeName = callOp.getCalleeAttr().getAttr().getValue(); - if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) { + if (calleeName == DNNL_BRGEMM_DISPATCH_NAME) return BrgemmCallType::DISPATCH; - } - if (callee == StringAttr::get(op->getContext(), DNNL_BRGEMM_TILECFG_NAME)) { + if (calleeName == DNNL_BRGEMM_TILECFG_NAME) return BrgemmCallType::TILECFG; - } - if (callee == - StringAttr::get(op->getContext(), DNNL_BRGEMM_TILERELEASE_NAME)) { + if (calleeName == DNNL_BRGEMM_TILERELEASE_NAME) return BrgemmCallType::TILERELEASE; - } return BrgemmCallType::INAPPLICABLE; } @@ -364,8 +360,7 @@ class MicrokernelInvariantCodeMotion expandInvariantScopeWithCond( structInfo, op, [](Operation *op) -> bool { - return !llvm::isa(op) && - !llvm::isa(op); + return !llvm::isa(op); }, [](Operation *self, const OpStructInfoMap &structInfo, Operation *current, diff --git 
a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir index 4eb0f4e74..f2dece5d4 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir @@ -1,44 +1,81 @@ -// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -cse -merge-branch-microkernel-context -split-input-file | FileCheck %s +// RUN: gc-opt %s -merge-branch-microkernel-context -split-input-file | FileCheck %s -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal 
@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @if_branch_context_merge() { %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + %3 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill 
ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %cmp = arith.cmpi eq, %arg0, %c0_index : index - scf.if %cmp { - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - } else { - %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%1) : (i64) -> () - microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%1) : (i64) -> () - } - memref.dealloc %alloc_3 : memref<32x32xf32> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %4 = arith.index_cast %intptr : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> 
+ %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_2 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_8 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %10 = arith.cmpi eq, %arg0, %c0 : index + scf.if %10 { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } else { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () } + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: if_branch_context_merge // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 @@ -51,40 +88,65 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> module { - 
func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @if_only_branch_context_merge() { %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : 
i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %cmp = arith.cmpi eq, %arg0, %c0_index : index - scf.if %cmp { - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - } - memref.dealloc %alloc_3 : memref<32x32xf32> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> 
memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %8 = arith.cmpi eq, %arg0, %c0 : index + scf.if %8 { + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () } + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: if_only_branch_context_merge // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 @@ -96,45 +158,82 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor, 
@g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 512 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c512_i64, %c512_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @if_branch_context_no_merge() { + %cst = 
arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + %3 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %cmp = arith.cmpi eq, %arg0, %c0_index : index - scf.if %cmp { - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - } else { - %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 512, 512] flags = (stride) data_type = (bf16, bf16) - 
microkernel.brgemm.prologue(%1) : (i64) -> () - microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%1) : (i64) -> () - } - memref.dealloc %alloc_3 : memref<32x32xf32> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %4 = arith.index_cast %intptr : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_2 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_8 : index to i64 + %9 = 
llvm.inttoptr %8 : i64 to !llvm.ptr + %10 = arith.cmpi eq, %arg0, %c0 : index + scf.if %10 { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } else { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () } + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: if_branch_context_no_merge // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 @@ -149,55 +248,90 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, 
%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @switch_branch_context_merge() { + %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> i64 + %3 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = 
memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - scf.index_switch %arg0 - case 0 { - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - scf.yield - } - case 1 { - %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%1) : (i64) -> () - microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%1) : (i64) -> () - scf.yield - } - default { - %2 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%2) : (i64) -> () - microkernel.brgemm(%2, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, 
strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%2) : (i64) -> () - scf.yield - } - memref.dealloc %alloc_3 : memref<32x32xf32> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %4 = arith.index_cast %intptr : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_2 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_8 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + scf.index_switch %arg0 + case 0 { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %7, %offset, %9, %offset_5, %5, %c0, 
%c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield } + case 1 { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + default { + func.call @dnnl_brgemm_tileconfig(%2) : (i64) -> () + func.call @dnnl_brgemm_execute(%2, %7, %offset, %9, %offset_5, %5, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: switch_branch_context_merge // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 @@ -212,55 +346,103 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c512_i64 = arith.constant 
512 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c512_i64, %c512_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 0.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, 
!llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @switch_branch_context_no_merge() { %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %3 = llvm.load %2 : !llvm.ptr -> i64 + %4 = llvm.load %1 : !llvm.ptr -> i64 + %5 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - scf.index_switch %arg0 - case 0 { - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0, stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : 
(i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - scf.yield - } - case 1 { - %1 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%1) : (i64) -> () - microkernel.brgemm(%1, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%1) : (i64) -> () - scf.yield - } - default { - %2 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 512, 512] flags = (stride) data_type = (bf16, bf16) - microkernel.brgemm.prologue(%2) : (i64) -> () - microkernel.brgemm(%2, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%2) : (i64) -> () - scf.yield - } - memref.dealloc %alloc_3 : memref<32x32xf32> + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %6 = arith.index_cast %intptr : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + 
%intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_2 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %10 = arith.index_cast %intptr_8 : index to i64 + %11 = llvm.inttoptr %10 : i64 to !llvm.ptr + scf.index_switch %arg0 + case 0 { + func.call @dnnl_brgemm_tileconfig(%4) : (i64) -> () + func.call @dnnl_brgemm_execute(%4, %9, %offset, %11, %offset_5, %7, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + case 1 { + func.call @dnnl_brgemm_tileconfig(%5) : (i64) -> () + func.call @dnnl_brgemm_execute(%5, %9, %offset, %11, %offset_5, %7, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + scf.yield + } + default { + func.call @dnnl_brgemm_tileconfig(%3) : (i64) -> () + func.call @dnnl_brgemm_execute(%3, %9, %offset, %11, %offset_5, %7, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () } + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: switch_branch_context_no_merge // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 From 
24e8681d80ea3e70ce918dd23c96b9d3ca8b64b6 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 21 Aug 2024 19:26:58 -0700 Subject: [PATCH 85/93] fix mlir test --- .../early-dispatch-microkernel.mlir | 8 +++---- .../microkernel-invariant-code-motion.mlir | 24 +++++++++---------- .../gc/cpu-runner/brgemm-multilevel-for.mlir | 8 +++---- .../test/gc/cpu-runner/brgemm-simple-for.mlir | 4 ++-- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir b/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir index 60b982d5c..364b88dd1 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/early-dispatch-microkernel.mlir @@ -15,9 +15,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () %subview_5 
= memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { @@ -42,7 +42,7 @@ module { // CHECK: llvm.func @[[G_CTOR_NAME]]() -> i64 { // CHECK-DAG: %[[G_PTR:.+]] = llvm.mlir.addressof @[[G_NAME]] : !llvm.ptr -// CHECK-DAG: %[[KERNEL:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) +// CHECK-DAG: %[[KERNEL:.+]] = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) // CHECK: llvm.store %[[KERNEL]], %[[G_PTR]] : i64, !llvm.ptr // CHECK-LABEL: simple_brgemm @@ -57,6 +57,6 @@ module { // CHECK: %[[subviewA:.+]] = memref.subview %[[memrefA:.+]][%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> // CHECK: %[[subviewB:.+]] = memref.subview %[[memrefB:.+]][%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> -// CHECK: microkernel.brgemm(%[[KERNEL2]], %[[subviewA]], %[[subviewB]], %[[memrefC]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () +// CHECK: microkernel.brgemm.execute(%[[KERNEL2]], %[[subviewA]], %[[subviewB]], %[[memrefC]], %[[CST16]], %[[CST0]]) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () // ----- diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir 
b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir index 0996be846..762ccd64b 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir @@ -14,9 +14,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () memref.dealloc %alloc_3 : memref<32x32xf32> } @@ -53,16 +53,16 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], 
offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () %subview_7 = memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_8 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%10) : (i64) -> () - microkernel.brgemm(%10, %subview_7, %subview_8, %alloc_3, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%10, %subview_7, %subview_8, %alloc_3, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%10) : (i64) -> () memref.dealloc %alloc_3 : 
memref<32x32xf32> @@ -109,9 +109,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () memref.dealloc %alloc_3 : memref<32x32xf32> } @@ -120,9 +120,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<32x32xf32>) %subview_7 = memref.subview %alloc[%arg2, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_8 = memref.subview %alloc_0[%arg2, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%0) : 
(i64) -> () - microkernel.brgemm(%0, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () } } @@ -166,9 +166,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () memref.dealloc %alloc_3 : memref<32x32xf32> } diff --git a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir index fe5c833a4..2033942a7 100644 
--- a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir @@ -21,9 +21,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { @@ -43,9 +43,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<32x32xf32>) %subview_7 = memref.subview %alloc[%c0_index, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], 
offset: ?>> %subview_8 = memref.subview %alloc_0[%c0_index, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (bf16, bf16) + %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%10) : (i64) -> () - microkernel.brgemm(%10, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%10, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%10) : (i64) -> () } return diff --git a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir index 9fb762425..53c4036bd 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir @@ -24,9 +24,9 @@ module { linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (stride) data_type = (f32, f32) + %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm(%0, %subview, 
%subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () microkernel.brgemm.epilogue(%0) : (i64) -> () %subview_5 = memref.subview %alloc_1[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3, %subview_5 : memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%alloc_3 : memref<32x32xf32>) { From 0f1fa86d6ef038cd715f0e2e7a33286ad38474fa Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 21 Aug 2024 19:56:13 -0700 Subject: [PATCH 86/93] improve cpu-runner test --- test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir | 6 +++--- test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir index 2033942a7..941dbd614 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s 
--early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { @@ -53,9 +53,9 @@ module { func.func @main() { call @simple_brgemm() : ()->() - // COM: parallelcpu.printf "BRGEMM DONE\n" + cpuruntime.printf "BRGEMM DONE\n" return } - // COM: CHECK: BRGEMM DONE + // CHECK: BRGEMM DONE } diff --git a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir index 53c4036bd..8f9a6696f 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void +// RUN: gc-opt %s --early-dispatch-microkernel --convert-microkernel-to-dnnl-func --cse --microkernel-invariant-code-motion --convert-linalg-to-loops --convert-scf-to-cf --expand-strided-metadata --lower-affine -finalize-memref-to-llvm --convert-cpuruntime-to-llvm --convert-func-to-llvm --convert-arith-to-llvm --convert-cf-to-llvm --convert-complex-to-llvm --canonicalize --cse --reconcile-unrealized-casts --symbol-dce | gc-cpu-runner -e main -entry-point-result=void #map = affine_map<(d0, d1) -> (d0, d1)> module { @@ -48,9 +48,9 @@ module 
{ func.func @main() { call @simple_brgemm() : ()->() - // COM: parallelcpu.printf "BRGEMM DONE\n" + cpuruntime.printf "BRGEMM DONE\n" return } - // COM: CHECK: BRGEMM DONE + // CHECK: BRGEMM DONE } From 90efaf694f35c399c6379f87934346aed70cbfe2 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 21 Aug 2024 22:50:04 -0700 Subject: [PATCH 87/93] refine mlir test --- .../microkernel-invariant-code-motion.mlir | 314 ++++++++++++------ 1 file changed, 215 insertions(+), 99 deletions(-) diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir index 762ccd64b..a03729d5f 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir @@ -1,30 +1,56 @@ -// RUN: gc-opt %s -early-dispatch-microkernel -convert-microkernel-to-dnnl-func -cse -microkernel-invariant-code-motion -split-input-file | FileCheck %s +// RUN: gc-opt %s -microkernel-invariant-code-motion -split-input-file | FileCheck %s -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, 
%c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @parallel_no_hoist() { + %c0 = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> scf.forall (%arg0, %arg1) in (4, 8) { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - 
microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - memref.dealloc %alloc_3 : memref<32x32xf32> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: parallel_no_hoist // CHECK: scf.forall (%arg0, %arg1) in (4, 8) // CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () @@ -33,46 +59,76 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> 
module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 %c2_i64 = arith.constant 2 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @multi_level_conflict() { %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2_i64 = arith.constant 2 : i64 + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment 
= 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - - %subview_7 = memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_8 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) - microkernel.brgemm.prologue(%10) : (i64) -> () - microkernel.brgemm.execute(%10, %subview_7, %subview_8, %alloc_3, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%10) : (i64) -> () - - memref.dealloc 
%alloc_3 : memref<32x32xf32> - } + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + %subview_9 = memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 
32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer_10, %offset_11, %sizes_12:3, %strides_13:3 = memref.extract_strided_metadata %subview_9 : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_14 = memref.extract_aligned_pointer_as_index %subview_9 : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %8 = arith.index_cast %intptr_14 : index to i64 + %9 = llvm.inttoptr %8 : i64 to !llvm.ptr + %subview_15 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_16, %offset_17, %sizes_18:4, %strides_19:4 = memref.extract_strided_metadata %subview_15 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_20 = memref.extract_aligned_pointer_as_index %subview_15 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %10 = arith.index_cast %intptr_20 : index to i64 + %11 = llvm.inttoptr %10 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %9, %offset_11, %11, %offset_17, %3, %c0, %c2_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: multi_level_conflict // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: scf.for %arg1 = %c0 to %c8 step %c1 @@ -89,48 +145,83 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : 
i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 %c2_i64 = arith.constant 2 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @multi_level_partial_hoist() { %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2_i64 = arith.constant 2 : i64 + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to 
%c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - memref.dealloc %alloc_3 : memref<32x32xf32> - } - scf.for %arg2 = %c0_index to %c4_index step %c1_index { - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<32x32xf32>) - %subview_7 = memref.subview %alloc[%arg2, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_8 = memref.subview %alloc_0[%arg2, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm.execute(%0, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : 
(i64) -> () - } + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } + scf.for %arg1 = %c0 to %c4 
step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg1, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c2_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: multi_level_partial_hoist // CHECK: scf.for %arg0 = %c0 to %c4 step %c1 // CHECK-NEXT: call 
@dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () @@ -147,37 +238,62 @@ module { // ----- -#map = affine_map<(d0, d1) -> (d0, d1)> module { - func.func @simple_brgemm() { - %c0_i64 = arith.constant 0 : i64 - %c0_index = arith.constant 0 : index - %c1_index = arith.constant 1 : index - %c4_index = arith.constant 4 : index - %c8_index = arith.constant 8 : index - %c16_i64 = arith.constant 16 : i64 + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %c32_i64 = arith.constant 32 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant 1.000000e+00 : f32 + %c2_i64 = arith.constant 2 : i64 + %1 = func.call @dnnl_brgemm_dispatch(%c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1024_i64, %c1024_i64, %cst, %c2_i64, %c2_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + llvm.store %1, %0 : i64, !llvm.ptr + llvm.return %1 : i64 + } + func.func private @dnnl_brgemm_tilerelease() + func.func private @dnnl_brgemm_execute(i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) + func.func private @dnnl_brgemm_tileconfig(i64) + func.func private @dnnl_brgemm_dispatch(i64, i64, i64, i64, i64, i64, i64, i64, f32, i64, i64) -> i64 + func.func @multi_level_full_hoist() { %cst = arith.constant 0.000000e+00 : f32 + %c16_i64 = arith.constant 16 : i64 + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = 
memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - scf.for %arg0 = %c0_index to %c4_index step %c1_index { - scf.for %arg1 = %c0_index to %c8_index step %c1_index { - %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) - %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> - %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) - microkernel.brgemm.prologue(%0) : (i64) -> () - microkernel.brgemm.execute(%0, %subview, %subview_4, %alloc_3, %c16_i64, %c0_i64) : (i64, memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () - microkernel.brgemm.epilogue(%0) : (i64) -> () - memref.dealloc %alloc_3 : memref<32x32xf32> - } + scf.for %arg0 = %c0 to %c4 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %intptr = memref.extract_aligned_pointer_as_index %alloc_1 : memref<32x32xf32> -> index + %2 = arith.index_cast %intptr : index to i64 + %3 = llvm.inttoptr %2 : i64 to !llvm.ptr + linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<32x32xf32>) + %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<16x32x32xbf16, 
strided<[1024, 32, 1], offset: ?>> -> memref, index, index, index, index, index, index, index + %intptr_2 = memref.extract_aligned_pointer_as_index %subview : memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> -> index + %4 = arith.index_cast %intptr_2 : index to i64 + %5 = llvm.inttoptr %4 : i64 to !llvm.ptr + %subview_3 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %base_buffer_4, %offset_5, %sizes_6:4, %strides_7:4 = memref.extract_strided_metadata %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> memref, index, index, index, index, index, index, index, index, index + %intptr_8 = memref.extract_aligned_pointer_as_index %subview_3 : memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> index + %6 = arith.index_cast %intptr_8 : index to i64 + %7 = llvm.inttoptr %6 : i64 to !llvm.ptr + func.call @dnnl_brgemm_tileconfig(%1) : (i64) -> () + func.call @dnnl_brgemm_execute(%1, %5, %offset, %7, %offset_5, %3, %c0, %c16_i64) : (i64, !llvm.ptr, index, !llvm.ptr, index, !llvm.ptr, index, i64) -> () + func.call @dnnl_brgemm_tilerelease() : () -> () + memref.dealloc %alloc_1 : memref<32x32xf32> + } } return } } -// CHECK-LABEL: simple_brgemm +// CHECK-LABEL: multi_level_full_hoist // CHECK: call @dnnl_brgemm_tileconfig(%[[A:.+]]) : (i64) -> () // CHECK-NEXT: scf.for %arg0 = %c0 to %c4 step %c1 From e4b96f6b0d68e0622e0ab66d22a38f78efd8f6eb Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 30 Aug 2024 01:38:40 -0700 Subject: [PATCH 88/93] code & test refinements --- .../MergeBranchMicrokernelContext.cpp | 18 ++- .../gc/cpu-runner/brgemm-multilevel-for.mlir | 103 ++++++++++++++++-- .../test/gc/cpu-runner/brgemm-parallel.mlir | 102 +++++++++++++++-- .../test/gc/cpu-runner/brgemm-simple-for.mlir | 98 +++++++++++++++-- 4 files changed, 289 insertions(+), 32 deletions(-) diff --git 
a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp index a1a09ea21..9865f5220 100644 --- a/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp +++ b/lib/gc/Transforms/Microkernel/MergeBranchMicrokernelContext.cpp @@ -23,8 +23,6 @@ namespace mlir::microkernel { #define DEBUG_TYPE "merge-branch-microkernel-context" -// enum BrgemmCallType { INAPPLICABLE = -1, DISPATCH, TILECFG, TILERELEASE }; - class BrgemmDispatchAnalysis { private: // A map for tile_config -> tile_dispatch @@ -87,12 +85,11 @@ Operation *BrgemmDispatchAnalysis::traceKernelDispatch(Operation *op) { if (callee != StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME)) return nullptr; return tryCallOp; - } else if (auto tryLoadOp = dyn_cast_or_null(kernelProducer)) { - auto tryAddrOfOp = dyn_cast_or_null( - tryLoadOp.getOperand().getDefiningOp()); - if (!tryAddrOfOp) - return nullptr; - return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName()); + } + if (auto tryLoadOp = dyn_cast_or_null(kernelProducer)) { + if (auto tryAddrOfOp = dyn_cast_or_null( + tryLoadOp.getOperand().getDefiningOp())) + return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName()); } return nullptr; } @@ -144,6 +141,7 @@ extractTileOpsFromRegion(Region ®ion) { return ret; } +static size_t DNNL_BRGEMM_DISPATCH_BETA_PARAM_INDEX = 8; static bool dispatchHasSameContext(Operation *lhs, Operation *rhs) { auto lhsDispatch = dyn_cast_or_null(lhs); auto rhsDispatch = dyn_cast_or_null(rhs); @@ -160,7 +158,7 @@ static bool dispatchHasSameContext(Operation *lhs, Operation *rhs) { assert(lhsOperands.size() == rhsOperands.size() && "Inconsistent operand size"); for (size_t idx = 0; idx < lhsOperands.size(); idx++) { - if (idx == 8) { + if (idx == DNNL_BRGEMM_DISPATCH_BETA_PARAM_INDEX) { // skip `beta` operand in index no.8 // since per dnnl design, it does not affect BRGEMM blocking & palettes continue; @@ -260,7 +258,7 @@ 
class ScfIndexSwitchRewriter : public OpRewritePattern { for (size_t idx = 0; idx < caseRegions.size(); idx++) { auto caseTileDispatch = analysis.getKernelDispatch(caseTilesOps[idx].first); - if (!defaultTileDispatch) + if (!caseTileDispatch) return rewriter.notifyMatchFailure(op, "Cannot find kernel dispatch"); if (!dispatchHasSameContext(defaultTileDispatch, caseTileDispatch)) return rewriter.notifyMatchFailure( diff --git a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir index 941dbd614..759f22776 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-multilevel-for.mlir @@ -9,16 +9,30 @@ module { %c1_index = arith.constant 1 : index %c4_index = arith.constant 4 : index %c8_index = arith.constant 8 : index + %c32_index = arith.constant 32 : index %c16_i64 = arith.constant 16 : i64 - %cst = arith.constant 0.000000e+00 : f32 + %cst0f = arith.constant 0.000000e+00 : f32 + %cstn64f = arith.constant -64.000000e+00 : f32 + %cst1f = arith.constant 1.000000e+00 : bf16 + %cst2f = arith.constant 2.000000e+00 : bf16 + %cst3f = arith.constant 3.000000e+00 : f32 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> + linalg.fill ins(%cst1f : bf16) outs(%alloc : memref<4x16x32x32xbf16>) %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> + linalg.fill ins(%cst2f : bf16) outs(%alloc_0 : memref<8x16x16x32x2xbf16>) %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst3f : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> + linalg.fill ins(%cst0f : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) scf.for %arg0 = %c0_index to %c4_index step %c1_index { scf.for %arg1 = %c0_index to %c8_index step %c1_index { %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) 
outs(%alloc_3 : memref<32x32xf32>) + %arg0i = arith.index_castui %arg0 : index to i64 + %arg1i = arith.index_castui %arg1 : index to i64 + %argmulti = arith.muli %arg0i, %arg1i : i64 + %v = arith.uitofp %argmulti : i64 to f32 + %v1 = arith.mulf %v, %cstn64f : f32 + linalg.fill ins(%v1 : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<16x32x32xbf16, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<16x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) @@ -34,20 +48,31 @@ module { %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { ^bb0(%in: f32, %out: f32): - %1 = arith.maximumf %in, %cst : f32 + %1 = arith.maximumf %in, %cst0f : f32 linalg.yield %1 : f32 } memref.dealloc %alloc_3 : memref<32x32xf32> } - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_4 : memref<32x32xf32>) - %subview_7 = memref.subview %alloc[%c0_index, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - %subview_8 = memref.subview %alloc_0[%c0_index, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> + %subview_c = memref.subview %alloc_2[%arg0, 0, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> + %subview_7 
= memref.subview %alloc[%arg0, 0, 0, 0] [1, 2, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xbf16> to memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>> + %subview_8 = memref.subview %alloc_0[%arg0, 0, 0, 0, 0] [1, 2, 16, 32, 2] [1, 1, 1, 1, 1] : memref<8x16x16x32x2xbf16> to memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> %10 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(bf16, bf16) microkernel.brgemm.prologue(%10) : (i64) -> () - microkernel.brgemm.execute(%10, %subview_7, %subview_8, %alloc_4, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32>, i64, i64) -> () + microkernel.brgemm.execute(%10, %subview_7, %subview_8, %subview_c, %c2_i64, %c0_i64) : (i64, memref<2x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<2x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64, i64) -> () microkernel.brgemm.epilogue(%10) : (i64) -> () } + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + scf.for %arg2 = %c0_index to %c32_index step %c1_index { + scf.for %arg3 = %c0_index to %c32_index step %c1_index { + %elem = memref.load %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<4x8x32x32xf32> + cpuruntime.printf "%f, " %elem : f32 + } + cpuruntime.printf "\n" + } + cpuruntime.printf "==================================\n" + } + } return } @@ -57,5 +82,67 @@ module { return } + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + 
// CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 
963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 
643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, + // CHECK: ================================== + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 
643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, + // CHECK: ================================== + // CHECK: 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 
131.000000, 131.000000, 131.000000, 131.000000, + // CHECK: ================================== + // CHECK: 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, 1155.000000, + // CHECK: ================================== + // CHECK: 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 
259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== // CHECK: BRGEMM DONE } diff --git a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir index 12f63edb9..38cf0dbf0 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-parallel.mlir @@ -5,18 +5,32 @@ module { func.func @simple_brgemm() { %c0_i64 = arith.constant 0 : i64 %c16_i64 = arith.constant 16 : i64 - %cst = arith.constant 0.000000e+00 : f32 + %c0_index = arith.constant 0 
: index + %c1_index = arith.constant 1 : index + %c4_index = arith.constant 4 : index + %c8_index = arith.constant 8 : index + %c32_index = arith.constant 32 : index + %cst0f = arith.constant 0.000000e+00 : f32 + %cstn64f = arith.constant -64.000000e+00 : f32 + %cst1f = arith.constant 1.000000e+00 : f32 + %cst2f = arith.constant 2.000000e+00 : f32 + %cst3f = arith.constant 3.000000e+00 : f32 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) + linalg.fill ins(%cst1f : f32) outs(%alloc : memref<4x16x32x32xf32>) %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) + linalg.fill ins(%cst2f : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + linalg.fill ins(%cst3f : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + linalg.fill ins(%cst0f : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) scf.forall (%arg0, %arg1) in (4, 8) { %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %arg0i = arith.index_castui %arg0 : index to i64 + %arg1i = arith.index_castui %arg1 : index to i64 + %argmulti = arith.muli %arg0i, %arg1i : i64 + %v = arith.uitofp %argmulti : i64 to f32 + %v1 = arith.mulf %v, %cstn64f : f32 + linalg.fill ins(%v1 : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to 
memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) @@ -32,16 +46,90 @@ module { %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { ^bb0(%in: f32, %out: f32): - %1 = arith.maximumf %in, %cst : f32 + %1 = arith.maximumf %in, %cst0f : f32 linalg.yield %1 : f32 } memref.dealloc %alloc_3 : memref<32x32xf32> } + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + scf.for %arg2 = %c0_index to %c32_index step %c1_index { + scf.for %arg3 = %c0_index to %c32_index step %c1_index { + %elem = memref.load %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<4x8x32x32xf32> + cpuruntime.printf "%f, " %elem : f32 + } + cpuruntime.printf "\n" + } + cpuruntime.printf "==================================\n" + } + } return } func.func @main() { call @simple_brgemm() : ()->() + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, + // CHECK: ================================== + // CHECK: 899.000000, 
899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 
579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // 
CHECK: 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, + // CHECK: ================================== + // CHECK: 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: 
================================== + // CHECK: 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== cpuruntime.printf "BRGEMM DONE\n" return } diff --git a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir index 8f9a6696f..b7f3ca6a7 100644 --- a/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir +++ b/test/mlir/test/gc/cpu-runner/brgemm-simple-for.mlir @@ -7,21 +7,31 @@ module { %c1_index = arith.constant 1 : index %c4_index = arith.constant 4 : index %c8_index = arith.constant 8 : index + %c32_index = arith.constant 32 : index %c0_i64 = arith.constant 0 : i64 %c16_i64 = arith.constant 16 : i64 - %cst = arith.constant 0.000000e+00 : f32 + %cst0f = arith.constant 0.000000e+00 : f32 + %cstn64f = arith.constant -64.000000e+00 : f32 + %cst1f = arith.constant 1.000000e+00 : f32 + %cst2f = arith.constant 2.000000e+00 : f32 + %cst3f 
= arith.constant 3.000000e+00 : f32 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc : memref<4x16x32x32xf32>) + linalg.fill ins(%cst1f : f32) outs(%alloc : memref<4x16x32x32xf32>) %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) + linalg.fill ins(%cst2f : f32) outs(%alloc_0 : memref<8x16x32x32xf32>) %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) + linalg.fill ins(%cst3f : f32) outs(%alloc_1 : memref<4x8x32x32xf32>) %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<4x8x32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) + linalg.fill ins(%cst0f : f32) outs(%alloc_2 : memref<4x8x32x32xf32>) scf.for %arg0 = %c0_index to %c4_index step %c1_index { scf.for %arg1 = %c0_index to %c8_index step %c1_index { %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - linalg.fill ins(%cst : f32) outs(%alloc_3 : memref<32x32xf32>) + %arg0i = arith.index_castui %arg0 : index to i64 + %arg1i = arith.index_castui %arg1 : index to i64 + %argmulti = arith.muli %arg0i, %arg1i : i64 + %v = arith.uitofp %argmulti : i64 to f32 + %v1 = arith.mulf %v, %cstn64f : f32 + linalg.fill ins(%v1 : f32) outs(%alloc_3 : memref<32x32xf32>) %subview = memref.subview %alloc[%arg0, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %subview_4 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> %0 = microkernel.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags(stride) data_type(f32, f32) @@ -37,12 +47,24 @@ module { %subview_6 = memref.subview %alloc_2[%arg0, %arg1, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to 
memref<32x32xf32, strided<[32, 1], offset: ?>> linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%alloc_3 : memref<32x32xf32>) outs(%subview_6 : memref<32x32xf32, strided<[32, 1], offset: ?>>) { ^bb0(%in: f32, %out: f32): - %1 = arith.maximumf %in, %cst : f32 + %1 = arith.maximumf %in, %cst0f : f32 linalg.yield %1 : f32 } memref.dealloc %alloc_3 : memref<32x32xf32> } } + scf.for %arg0 = %c0_index to %c4_index step %c1_index { + scf.for %arg1 = %c0_index to %c8_index step %c1_index { + scf.for %arg2 = %c0_index to %c32_index step %c1_index { + scf.for %arg3 = %c0_index to %c32_index step %c1_index { + %elem = memref.load %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<4x8x32x32xf32> + cpuruntime.printf "%f, " %elem : f32 + } + cpuruntime.printf "\n" + } + cpuruntime.printf "==================================\n" + } + } return } @@ -52,5 +74,67 @@ module { return } + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, 963.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 
899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, 707.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, 579.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 
1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, 899.000000, + // CHECK: ================================== + // CHECK: 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, 771.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 
515.000000, 515.000000, 515.000000, 515.000000, 515.000000, 515.000000, + // CHECK: ================================== + // CHECK: 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, 387.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, 131.000000, + // CHECK: ================================== + // CHECK: 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, 1027.000000, + // CHECK: ================================== + // CHECK: 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 
835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, 835.000000, + // CHECK: ================================== + // CHECK: 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, 643.000000, + // CHECK: ================================== + // CHECK: 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, 451.000000, + // CHECK: ================================== + // CHECK: 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, 259.000000, + // CHECK: ================================== + // CHECK: 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 
67.000000, 67.000000, 67.000000, 67.000000, 67.000000, 67.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== + // CHECK: 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, + // CHECK: ================================== // CHECK: BRGEMM DONE } From eeb8e1f00dc63d63a4e91b0c36df9dcec7c69622 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Fri, 30 Aug 2024 02:08:28 -0700 Subject: [PATCH 89/93] add microkernel passes to pipeline --- lib/gc/Transforms/CMakeLists.txt | 1 + lib/gc/Transforms/Microkernel/CMakeLists.txt | 2 +- lib/gc/Transforms/Pipeline.cpp | 16 +++++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 243e17580..484c2ef4b 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Utils) gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRSupport + MLIRMicrokernelTransforms MLIRBufferizationToMemRef MLIRBufferizationPipelines) diff --git a/lib/gc/Transforms/Microkernel/CMakeLists.txt b/lib/gc/Transforms/Microkernel/CMakeLists.txt index a555a752a..9dbac3e94 100644 --- a/lib/gc/Transforms/Microkernel/CMakeLists.txt +++ b/lib/gc/Transforms/Microkernel/CMakeLists.txt @@ -1,4 +1,4 @@ -gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR) 
+gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRMicrokernel) include(onednn) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index fef58eae0..b63886fa0 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -30,6 +30,7 @@ #ifdef GC_HAS_ONEDNN_DIALECT #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif +#include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" namespace mlir::gc { @@ -57,6 +58,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // Fine-grain fusion pass pm.addNestedPass(createIterativeTilingAndFusion()); // todo: fine-grain fusion pass + pm.addNestedPass( + mlir::microkernel::createConvertLinalgToMicrokernel()); // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. Currently we add this @@ -120,13 +123,12 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + memref + func/microkernel void populateMicroKernelPasses(mlir::OpPassManager &pm) { - // todo: ConvertLinalgToMicrokernel pass - // todo: CleanupInvalidMicrokernel pass - // todo: InvariantMicrokernelMotion pass - // todo: ConvertMicrokernelToDnnlFunc to lower brgemm to dnnl call - // todo: ConvertMicrokernelToXsmm, to lower brgemm to libxsmm call - // todo: LowerMicrokernel pass - // todo: DispatchMicrokernel + pm.addNestedPass(mlir::microkernel::createExpandMicrokernel()); + pm.addPass(mlir::microkernel::createEarlyDispatchMicrokernel()); + pm.addPass(mlir::microkernel::createConvertMicrokernelToDnnlFunc()); + pm.addPass(mlir::microkernel::createMergeBranchMicrokernelContext()); + pm.addPass(mlir::microkernel::createMicrokernelInvariantCodeMotion()); + populateCleanUpPasses(pm); } void populateCPURuntimePasses(mlir::OpPassManager &pm) { From 4610480d6859d7f61f2287fb9f00a6518951f24c Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Tue, 3 Sep 2024 01:46:37 -0700 Subject: 
[PATCH 90/93] fix per review --- .../Microkernel/EarlyDispatchMicrokernel.cpp | 16 +++- .../MicrokernelInvariantCodeMotion.cpp | 3 - .../merge-branch-microkernel-context.mlir | 90 +++++++++---------- .../microkernel-invariant-code-motion.mlir | 40 ++++----- 4 files changed, 79 insertions(+), 70 deletions(-) diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp index 4b4b169ef..5fc6074a4 100644 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -33,16 +33,28 @@ createGlobalKernelHandleName(RewriterBase &rewriter, std::stringstream ss; ss << "g_dispatched_microkernel_brgemm"; + bool isInit = false; + bool isStrideMode = false; auto flags = op.getFlagsAttr(); for (auto flag : flags) { auto brgemmFlag = dyn_cast_or_null(flag); if (!brgemmFlag) return failure(); if (brgemmFlag.getValue() == BrgemmFlags::LIST) + // TODO(haixin): Currently not supported. 
Support list brgemm in the + // future return failure(); - if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) - ss << "_init"; + else if (brgemmFlag.getValue() == BrgemmFlags::STRIDE) + isStrideMode = true; + else if (brgemmFlag.getValue() == BrgemmFlags::BETA_0) + isInit = true; } + if (isStrideMode) + ss << "_stride"; + else + ss << "_list"; + if (isInit) + ss << "_init"; // M, N, K, LDA, LDB, LDC, stride_a, stride_b // they are in the same order with BrgemmDispatchOp inputs diff --git a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp index 1eaa18980..ad8a0631f 100644 --- a/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp +++ b/lib/gc/Transforms/Microkernel/MicrokernelInvariantCodeMotion.cpp @@ -171,9 +171,6 @@ class BrgemmTilereleaseRewriter : public OpRewritePattern { if (targetInfoIter->second.hasTilereleased) { rewriter.eraseOp(op); } else { - // auto region = opStructInfo.maxInvariantScope->getRegion(0); - // auto block = ®ion.getBlocks().front(); - // auto enditer = block->end(); // rewriter.moveOpBefore(op, block, enditer); rewriter.moveOpAfter(op, opStructInfo.maxInvariantScope); // Mark all sub scope as released to avoid duplicate Tilerelease diff --git a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir index f2dece5d4..7429ec970 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/merge-branch-microkernel-context.mlir @@ -1,10 +1,10 @@ // RUN: gc-opt %s -merge-branch-microkernel-context -split-input-file | FileCheck %s module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} - 
llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -13,9 +13,9 @@ module { llvm.store %1, %0 : i64, !llvm.ptr llvm.return %1 : i64 } - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -35,8 +35,8 @@ module { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : 
index %c0 = arith.constant 0 : index - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr - %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %2 = llvm.load %1 : !llvm.ptr -> i64 %3 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> @@ -89,10 +89,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -112,7 +112,7 @@ module { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : 
!llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> @@ -159,10 +159,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -171,9 +171,9 @@ module { llvm.store %1, %0 : i64, !llvm.ptr llvm.return %1 : i64 } - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof 
@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c512_i64 = arith.constant 512 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -193,8 +193,8 @@ module { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr - %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %2 = llvm.load %1 : !llvm.ptr -> i64 %3 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> @@ -249,10 +249,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor, 
@g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -261,9 +261,9 @@ module { llvm.store %1, %0 : i64, !llvm.ptr llvm.return %1 : i64 } - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -283,8 +283,8 @@ module { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr - %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof 
@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %2 = llvm.load %1 : !llvm.ptr -> i64 %3 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> @@ -347,10 +347,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32, 2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor, @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor], priorities = [2147483647 : i32, 2147483647 : i32, 2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c512_i64 = arith.constant 512 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -359,9 +359,9 @@ module { llvm.store %1, %0 : i64, !llvm.ptr llvm.return %1 : i64 } - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func 
@g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 0.000000e+00 : f32 @@ -370,9 +370,9 @@ module { llvm.store %1, %0 : i64, !llvm.ptr llvm.return %1 : i64 } - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -392,9 +392,9 @@ module { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr - %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr - %2 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof 
@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %1 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_init_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %2 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_512_512_2_2 : !llvm.ptr %3 = llvm.load %2 : !llvm.ptr -> i64 %4 = llvm.load %1 : !llvm.ptr -> i64 %5 = llvm.load %0 : !llvm.ptr -> i64 diff --git a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir index a03729d5f..6cd0173c8 100644 --- a/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir +++ b/test/mlir/test/gc/Dialect/Microkernel/microkernel-invariant-code-motion.mlir @@ -1,10 +1,10 @@ // RUN: gc-opt %s -microkernel-invariant-code-motion -split-input-file | FileCheck %s module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -21,7 +21,7 @@ module { %c0 = 
arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 %c16_i64 = arith.constant 16 : i64 - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> @@ -60,10 +60,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -84,7 +84,7 @@ module { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c2_i64 = arith.constant 2 : i64 - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = 
memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> @@ -146,10 +146,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} - llvm.mlir.global internal @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -170,7 +170,7 @@ module { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c2_i64 = arith.constant 2 : i64 - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> @@ -239,10 +239,10 @@ module { // ----- module { - llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} - llvm.mlir.global internal 
@g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 - llvm.func @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + llvm.mlir.global_ctors {ctors = [@g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor], priorities = [2147483647 : i32]} + llvm.mlir.global internal @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2() {addr_space = 0 : i32} : i64 + llvm.func @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2_ctor() -> i64 { + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %c32_i64 = arith.constant 32 : i64 %c1024_i64 = arith.constant 1024 : i64 %cst = arith.constant 1.000000e+00 : f32 @@ -262,7 +262,7 @@ module { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr + %0 = llvm.mlir.addressof @g_dispatched_microkernel_brgemm_stride_32_32_32_32_32_32_1024_1024_2_2 : !llvm.ptr %1 = llvm.load %0 : !llvm.ptr -> i64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x16x32x32xbf16> %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x16x16x32x2xbf16> From a853cd8b2f5d5227817bdfe209d75b064ff167fa Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 4 Sep 2024 00:07:14 -0700 Subject: [PATCH 91/93] ignore upstream linalg op with invalid input --- lib/gc/Dialect/Microkernel/MicrokernelOps.cpp | 95 +++++++++++-------- .../ConvertLinalgToMicrokernel.cpp | 20 +++- 2 files changed, 70 insertions(+), 45 deletions(-) diff --git a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp index 20f736fc7..9dc940fcd 100644 --- a/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp +++ 
b/lib/gc/Dialect/Microkernel/MicrokernelOps.cpp @@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op, return success(); } +static bool isTypeSupported(Type outType, Type operandAType, + Type operandBType) { + if (!outType.isF32() && !outType.isSignedInteger(32)) + return false; + + if (outType.isF32()) { + if (!(operandAType.isF32() && operandBType.isF32()) && + !(operandAType.isBF16() && operandBType.isBF16())) + return false; + } + if (outType.isSignedInteger(32)) { + if (!((operandAType.isSignedInteger(8) || + operandAType.isUnsignedInteger(8)) && + (operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8)))) + return false; + } + return true; +} + +// TODO(haixin): could use compiler-wide VNNI utils? +static bool isInVnniLayout(ShapedType type) { + if (!type.getElementType().isBF16() && + !type.getElementType().isSignedInteger(8) && + !type.getElementType().isUnsignedInteger(8)) + return false; + + auto blockingFactor = 0; + if (type.getElementType().isBF16()) + blockingFactor = 2; + else if (type.getElementType().isSignedInteger(8) || + type.getElementType().isUnsignedInteger(8)) + blockingFactor = 4; + + return type.getShape().back() == blockingFactor; +} + ///////////////////////////////////////////////////// // Start of BrgemmOp @@ -308,9 +344,8 @@ static inline ArrayRef getShapedValueShape(Value val) { assert((llvm::isa(val.getType()) || llvm::isa(val.getType())) && "Expecting shaped value"); - if (auto tensorTy = dyn_cast_or_null(val.getType())) { + if (auto tensorTy = dyn_cast_or_null(val.getType())) return tensorTy.getShape(); - } auto memrefTy = dyn_cast_or_null(val.getType()); return memrefTy.getShape(); } @@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() { return op.emitOpError() << "expect inputs and its related info to be size 2\n"; + auto elemTypeA = getElementTypeOrSelf(ins[0]); + auto elemTypeB = getElementTypeOrSelf(ins[1]); + auto elemTypeC = getElementTypeOrSelf(out); + if 
(!isTypeSupported(elemTypeC, elemTypeA, elemTypeB)) + return op.emitOpError() << "unsupported input matrix types\n"; + ArrayRef dimA = getShapedValueShape(ins[0]); ArrayRef dimB = getShapedValueShape(ins[1]); ArrayRef dimC = getShapedValueShape(out); if (dimA.size() != 3) return op.emitOpError() << "expect input A to be 3D\n"; - if (dimB.size() != 3 && dimB.size() != 4) - return op.emitOpError() << "expect input B to be 3D or 4D\n"; - if (dimB.size() == 4 && (dimB[3] != 2 && dimB[3] != 4)) - return op.emitOpError() << "expect input B vnni step to be 2 or 4\n"; + if (!elemTypeB.isF32()) { + if (dimB.size() != 4 || + !isInVnniLayout(dyn_cast(ins[1].getType()))) + return op.emitOpError() + << "expect a 4d VNNI input B for non-F32 operand: " << ins[1]; + } else { + if (dimB.size() != 3) + return op.emitOpError() + << "expect a 3d input B for F32 operand: " << ins[1]; + } if (dimC.size() != 2) return op.emitOpError() << "expect input C to be 2D\n"; for (auto dim : batchDims) @@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() { ///////////////////////////////////////////////////// // Start of BrgemmExecuteOp -// TODO(haixin): could use compiler-wide VNNI utils? 
-static bool isInVnniLayout(MemRefType memref) { - if (!memref.getElementType().isBF16() && - !memref.getElementType().isSignedInteger(8) && - !memref.getElementType().isUnsignedInteger(8)) - return false; - - auto blockingFactor = 0; - if (memref.getElementType().isBF16()) - blockingFactor = 2; - else if (memref.getElementType().isSignedInteger(8) || - memref.getElementType().isUnsignedInteger(8)) - blockingFactor = 4; - - return memref.getShape().back() == blockingFactor; -} - -static bool isTypeSupported(Type outType, Type operandAType, - Type operandBType) { - if (!outType.isF32() && !outType.isSignedInteger(32)) - return false; - - if (outType.isF32()) { - if (!(operandAType.isF32() && operandBType.isF32()) && - !(operandAType.isBF16() && operandBType.isBF16())) - return false; - } - if (outType.isSignedInteger(32)) { - if (!(operandAType.isSignedInteger(8) || - operandAType.isUnsignedInteger(8)) && - (operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8))) - return false; - } - return true; -} - LogicalResult BrgemmExecuteOp::verify() { BrgemmExecuteOp &brgemmOp = *this; diff --git a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp index 980fe8288..f89a31179 100644 --- a/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp @@ -157,6 +157,23 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { else return failure(); + OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; + OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; + Type operandBElemType = getElementTypeOrSelf(operandB->get()); + if (operandBElemType.isF32()) { + if (kAffinePos.size() == 2) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input " + "B, should be non-VNNI\n"); + return failure(); + } + } else { + if (kAffinePos.size() == 1) { + LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong 
dimensions for input " "B, should be VNNI\n"); + return failure(); + } + } + LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: " << "\n"); LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mAffinePos @@ -169,9 +186,6 @@ static FailureOr inferBrgemmDims(linalg::LinalgOp linalgOp) { LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: " << batchAffinePos << "\n"); - OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; - OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - BrgemmDims brgemmDims; #define CHECK_GET_POS_IN_DOMAIN(dim, dimPos, operand) \ From fedd427e4233a7fdace462d9a74dba8c30975c89 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 4 Sep 2024 00:12:11 -0700 Subject: [PATCH 92/93] add TODO comments --- lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp index 5fc6074a4..2f66feee4 100644 --- a/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp +++ b/lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp @@ -58,6 +58,7 @@ createGlobalKernelHandleName(RewriterBase &rewriter, // M, N, K, LDA, LDB, LDC, stride_a, stride_b // they are in the same order with BrgemmDispatchOp inputs + // TODO(haixin): Add order enforcement mechanism for BrgemmDispatchOp ArrayRef inputs = op.getInputsAttr().asArrayRef(); for (auto input : inputs) { ss << "_" << input; From eb94d05026d22a6f518ce6f99940c4df181ce899 Mon Sep 17 00:00:00 2001 From: "Huang, Haixin" Date: Wed, 4 Sep 2024 01:34:32 -0700 Subject: [PATCH 93/93] fix correctness check --- lib/gc/Transforms/Utils/ValueUtils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 8750042ee..0f26bdd42 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ 
b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -120,7 +120,7 @@ std::pair getPtrAndOffset(OpBuilder &builder, Value operand) { auto memrefType = dyn_cast(operand.getType()); assert(memrefType && "Expect a memref value"); - Location loc = operand.getDefiningOp()->getLoc(); + Location loc = operand.getLoc(); OpBuilder::InsertionGuard guard(builder); // Insert right after operand producer for better opt chances. builder.setInsertionPointAfterValue(operand);