From ea1e3369f7a8aa9729f8e2fc208b8f6a79392874 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Tue, 28 Jan 2020 13:44:37 -0500
Subject: [PATCH] [mlir][Linalg] Introduce folding patterns to remove certain
 MemRefCastOp

Summary:
Canonicalization and folding patterns in StandardOps may interfere with the
needs of Linalg. This revision introduces specific foldings for dynamic
memrefs that can be proven to be static.

Very concretely, this determines whether a memref_cast producing a dynamic
memref can be folded away in the parent Linalg op:

```mlir
  %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
  %2 = linalg.slice %1 ... : memref<?x?xf32> ...
  // or
  %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
         to memref<?x?xf32>
  linalg.generic(%1 ...) : memref<?x?xf32> ...
```

into

```mlir
  %2 = linalg.slice %0 ... : memref<8x16xf32> ...
  // or
  linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
```

Reviewers: ftynse, aartbik, jsetoain, tetuante, asaadaldien

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst,
  arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73565
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |   6 +
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  19 +++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 135 ++++++++++++++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir    |  20 +++
 4 files changed, 180 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/canonicalize.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 4c2344482164c7..0dec1d1b74fb9c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -117,6 +117,8 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
     static StringRef getReassociationAttrName() { return "reassociation"; }
     MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
   }];
+
+  let hasFolder = 1;
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
@@ -188,6 +190,8 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
       return res;
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
@@ -222,6 +226,8 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
     static StringRef getPermutationAttrName() { return "permutation"; }
     ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
   }];
+
+  let hasFolder = 1;
 }
 
 def Linalg_YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2a2ef55d4fab86..03318fa48b9893 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -270,6 +270,8 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
     }
   }];
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -287,6 +289,8 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
     }
   }];
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
@@ -302,6 +306,8 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
                  StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
@@ -319,6 +325,8 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
       return ArrayAttr::get(iters, ctx);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
@@ -337,6 +345,8 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
       return ArrayAttr::get(iters, ctx);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
@@ -406,7 +416,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
           .cast<IntegerAttr>().getValue().getSExtValue();
     }
   }];
+
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def LinalgOperand: Type<
@@ -583,7 +596,10 @@ def GenericOp : GenericOpBase<"generic"> {
     tensor SSA values are expected to be useful and will be added in the near
    future.
   }];
+
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
@@ -710,7 +726,10 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     tensor SSA values are expected to be useful and will be added in the near
    future.
   }];
+
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 #endif // LINALG_STRUCTURED_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b1ffce62c1c4f1..8ed7e79df89e7e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -31,6 +32,89 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Determines whether the memref_cast can be folded away in the parent Linalg
+/// op:
+///
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
+///   // or
+///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   linalg.generic(%1 ...) : memref<?x?xf32> ...
+/// ```
+///
+/// into
+///
+/// ```mlir
+///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
+///   // or
+///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+///
+static bool canFold(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+  // If we don't have MemRefType as source and destination, bail out.
+  if (!sourceType || !resultType)
+    return false;
+
+  // If resultType has a map, it needs to be the same as the source type to
+  // canonicalize.
+  if (!resultType.getAffineMaps().empty() &&
+      sourceType.getAffineMaps() != resultType.getAffineMaps())
+    return false;
+
+  // Ensure that:
+  //   1. source is static
+  //   2. source and target have the same rank (will be extended when needed)
+  //   3. if result is partially static, ensure sizes match.
+  if (!sourceType.hasStaticShape() ||
+      sourceType.getRank() != resultType.getRank())
+    return false;
+
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto sourceSize = std::get<0>(it);
+    auto resultSize = std::get<1>(it);
+    if (ShapedType::isDynamic(resultSize))
+      continue;
+    if (sourceSize != resultSize)
+      return false;
+  }
+
+  // If source has a map, it can only canonicalize if it is the canonical
+  // strided layout map.
+  if (sourceType.getAffineMaps().empty())
+    return true;
+
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(sourceType, strides, offset);
+  (void)res;
+  assert(succeeded(res));
+  auto stridedMap =
+      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
+  AffineMap sourceMap = sourceType.getAffineMaps().front();
+  return sourceMap == stridedMap;
+}
+
+/// This is a common utility used for patterns of the form
+/// ```
+///    someop(memrefcast) -> someop
+/// ```
+/// It folds the source of any foldable memref_cast into the root operation
+/// directly.
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+    if (castOp && canFold(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 ///////////////////// Operations defined with Tablegen /////////////////////////
 // For such operations that do not correspond to library calls (i.e. defined in
 // LinalgOps.td), we define an overloaded `print` function and a
@@ -1077,3 +1161,54 @@ ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
 ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
   return getIndexingMaps(getOperation());
 }
+
+// TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate
+// with Tablegen. This seems a desirable property in the context of OpInterfaces
+// where a Linalg "named" op **isa** LinalgOp.
+LogicalResult ConvOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult CopyOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult DotOp::fold(ArrayRef<Attribute>,
+                          SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult FillOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult GenericOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
+                                     SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
+OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
new file mode 100644
index 00000000000000..370cf45311698c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @memref_cast(
+func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c8 = constant 8 : index
+  %c16 = constant 16 : index
+  %1 = alloc (%b) : memref<?xi8>
+  %2 = view %1[][] : memref<?xi8> to memref<16x16xf32>
+  %3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
+  %r0 = linalg.range %c0:%c8:%c1 : !linalg.range
+
+  // CHECK:  linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
+  %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
+
+  // CHECK:  linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
+  linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return %4: memref<?x?xf32>
+}
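
For downstream users, the same hook can be reused outside the tree by following the per-op boilerplate this patch adds. The sketch below is illustrative only and is not part of this patch: it assumes a hypothetical structured op named `MyPoolOp` that sets `let hasFolder = 1;` in its ODS definition and delegates its generated `fold` hook to `foldMemRefCast`, exactly as the in-tree ops above do.

```cpp
// Hypothetical downstream op (the name, traits, and mnemonic are assumptions);
// the ODS side would mirror the in-tree ops in LinalgStructuredOps.td:
//
//   def MyPoolOp : LinalgStructured_Op<"my_pool", [NInputs<1>, NOutputs<1>]> {
//     ...
//     let hasFolder = 1;
//   }
//
// The generated `fold` declaration is then implemented like ConvOp, CopyOp,
// etc.: replace any operand produced by a provably-static memref_cast with the
// cast's source, and report success iff something changed.
LogicalResult MyPoolOp::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}
```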