Skip to content

Commit

Permalink
[mlir:Linalg] Populate LinalgOp patterns on LinalgDialect as opposed …
Browse files Browse the repository at this point in the history
…to each op

Interface patterns are unique in that they get added to every operation that also implements that interface, given that they aren't tied to individual operations. When the same interface pattern gets added to multiple operations (such as the current behavior with Linalg), an reference to each of these patterns is added to every op (meaning that an operation will now have N references to effectively the same pattern). This revision fixes this problematic behavior in Linalg, and can bring upwards of a 25% reduction in compile time in Linalg based workloads.

Differential Revision: https://reviews.llvm.org/D104160
  • Loading branch information
River707 committed Jun 14, 2021
1 parent 75d3b46 commit 66e2708
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 43 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def Linalg_Dialect : Dialect {
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
let hasCanonicalizer = 1;
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Attribute name used to to memoize indexing maps for named ops.
Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
let skipDefaultBuilders = 1;
}

Expand Down Expand Up @@ -230,7 +229,6 @@ def FillOp : LinalgStructured_Op<"fill", []> {
let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

/// A base class for pooling operation such as conv. The arguments must contain
Expand Down Expand Up @@ -427,7 +425,6 @@ def ConvOp : PoolingBase_Op<"conv", []> {
let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

// Only support buffer semantics.
Expand Down Expand Up @@ -490,7 +487,6 @@ class SingleInputPoolingBase_Op<string mnemonic>
let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
Expand Down Expand Up @@ -673,7 +669,6 @@ def GenericOp : GenericOpBase<"generic"> {
let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

/// GenericOp with Indexing (i.e. multi-for style in which the region is passed
Expand Down
37 changes: 18 additions & 19 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2787,11 +2787,6 @@ DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)

namespace {
struct EraseDeadLinalgOp;
struct FoldTensorCastOp;
} // namespace

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

Expand Down Expand Up @@ -3374,25 +3369,29 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
};
} // namespace

#define CANONICALIZERS_AND_FOLDERS(XXX) \
void XXX::getCanonicalizationPatterns(RewritePatternSet &results, \
MLIRContext *context) { \
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
RemoveIdentityLinalgOps>(context); \
} \
\
#define LINALGOP_FOLDERS(XXX) \
LogicalResult XXX::fold(ArrayRef<Attribute>, \
SmallVectorImpl<OpFoldResult> &) { \
return foldMemRefCast(*this); \
}

CANONICALIZERS_AND_FOLDERS(ConvOp)
CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
CANONICALIZERS_AND_FOLDERS(CopyOp)
CANONICALIZERS_AND_FOLDERS(FillOp)
CANONICALIZERS_AND_FOLDERS(GenericOp)
LINALGOP_FOLDERS(ConvOp)
LINALGOP_FOLDERS(PoolingMaxOp)
LINALGOP_FOLDERS(PoolingMinOp)
LINALGOP_FOLDERS(PoolingSumOp)
LINALGOP_FOLDERS(CopyOp)
LINALGOP_FOLDERS(FillOp)
LINALGOP_FOLDERS(GenericOp)

// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
RemoveIdentityLinalgOps>(getContext());
}
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
patterns);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
Expand Down
11 changes: 2 additions & 9 deletions mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1959,7 +1959,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
Expand Down Expand Up @@ -2094,13 +2093,7 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,

void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
StringRef cppOpName) {
const char *canonicalizersAndFoldersFmt = R"FMT(
void {0}::getCanonicalizationPatterns(
RewritePatternSet &results,
MLIRContext *context) {{
results.add<EraseDeadLinalgOp>(context);
results.add<FoldTensorCastOp>(context);
}
const char *foldersFmt = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
Expand All @@ -2112,7 +2105,7 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
getGenericEffectsImpl(effects,
getOperation()->getResults(), inputBuffers, outputBuffers);
})FMT";
os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
os << llvm::formatv(foldersFmt, cppOpName);
}

// Prints methods for querying whether the current named op has attributes that
Expand Down
13 changes: 3 additions & 10 deletions mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
Expand Down Expand Up @@ -535,16 +534,10 @@ ArrayAttr {0}::iterator_types() {
}
)FMT";

// Implementations of getCanonicalizationPatterns, fold and getEffects.
// Implementations of fold and getEffects.
// Parameters:
// {0}: Class name
const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
void {0}::getCanonicalizationPatterns(
RewritePatternSet &results,
MLIRContext *context) {{
results.add<EraseDeadLinalgOp>(context);
results.add<FoldTensorCastOp>(context);
}
const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
Expand Down Expand Up @@ -880,7 +873,7 @@ void {0}::regionBuilder(
}

// Canonicalizers and folders.
os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
os << llvm::formatv(structuredOpFoldersFormat, className);

return success();
}
Expand Down

0 comments on commit 66e2708

Please sign in to comment.