[mlir][Linalg] Introduce folding patterns to remove certain MemRefCastOp
Summary:
Canonicalization and folding patterns in StandardOps may interfere with the needs
of Linalg. This revision introduces specific foldings for dynamic memrefs that can
be proven to be static.

Very concretely, the new foldings determine whether a `memref_cast` can be
folded away into its consuming Linalg op, rewriting:

```mlir
  %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
  %2 = linalg.slice %1 ... : memref<?x?xf32> ...
  // or
  %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
         to memref<?x?xf32>
  linalg.generic(%1 ...) : memref<?x?xf32> ...
```

into

```mlir
  %2 = linalg.slice %0 ... : memref<8x16xf32> ...
  // or
  linalg.generic(%0 ...) : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>> ...
```
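For contrast, a hypothetical example (not taken from this patch): when the
source layout is not the canonical strided form for its shape, the cast is
left untouched and the Linalg op keeps operating on the dynamic memref:

```mlir
  // Not folded: 32 * i + j is not the canonical strided layout
  // (16 * i + j) of memref<8x16xf32>.
  %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(32 * i + j)>>
         to memref<?x?xf32>
  linalg.generic(%1 ...) : memref<?x?xf32> ...
```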

Reviewers: ftynse, aartbik, jsetoain, tetuante, asaadaldien

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73565
Nicolas Vasilache committed Jan 29, 2020
1 parent 02adfb5 commit ea1e336
Showing 4 changed files with 180 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -117,6 +117,8 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
static StringRef getReassociationAttrName() { return "reassociation"; }
MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
}];

let hasFolder = 1;
}

def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
@@ -188,6 +190,8 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
return res;
}
}];

let hasFolder = 1;
}

def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
@@ -222,6 +226,8 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
static StringRef getPermutationAttrName() { return "permutation"; }
ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
}];

let hasFolder = 1;
}

def Linalg_YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -270,6 +270,8 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
}
}];
let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
}

def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -287,6 +289,8 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
}
}];
let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
}

def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
@@ -302,6 +306,8 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
}
}];

let hasFolder = 1;
}

def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
@@ -319,6 +325,8 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
return ArrayAttr::get(iters, ctx);
}
}];

let hasFolder = 1;
}

def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
@@ -337,6 +345,8 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
return ArrayAttr::get(iters, ctx);
}
}];

let hasFolder = 1;
}

def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
@@ -406,7 +416,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
.cast<IntegerAttr>().getValue().getSExtValue();
}
}];

let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
}

def LinalgOperand: Type<
@@ -583,7 +596,10 @@ def GenericOp : GenericOpBase<"generic"> {
tensor SSA values are expected to be useful and will be added in the near
future.
}];

let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
}

def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
@@ -710,7 +726,10 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
tensor SSA values are expected to be useful and will be added in the near
future.
}];

let verifier = [{ return ::verify(*this); }];

let hasFolder = 1;
}

#endif // LINALG_STRUCTURED_OPS
135 changes: 135 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -12,6 +12,7 @@

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -31,6 +32,89 @@
using namespace mlir;
using namespace mlir::linalg;

/// Determines whether a MemRefCastOp can be folded away into the Linalg op
/// that consumes it:
///
/// ```mlir
/// %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
/// %2 = linalg.slice %1 ... : memref<?x?xf32> ...
/// // or
/// %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// to memref<?x?xf32>
/// linalg.generic(%1 ...) : memref<?x?xf32> ...
/// ```
///
/// into
///
/// ```mlir
/// %2 = linalg.slice %0 ... : memref<8x16xf32> ...
/// // or
/// linalg.generic(%0 ...) : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>> ...
/// ```
///
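/// Folding additionally requires the source layout, if present, to be the
/// canonical strided form for its shape: e.g. for memref<8x16xf32>,
/// getStridesAndOffset yields strides [16, 1] and offset 0, i.e. the map
/// affine_map<(i, j)->(16 * i + j)>; any other source layout is rejected.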
static bool canFold(MemRefCastOp castOp) {
MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

// Bail out unless both source and result types are MemRefType.
if (!sourceType || !resultType)
return false;

// If resultType has a map, it needs to be the same as the source type to
// canonicalize.
if (!resultType.getAffineMaps().empty() &&
sourceType.getAffineMaps() != resultType.getAffineMaps())
return false;

// Ensure that:
// 1. source is static
// 2. source and target have the same rank (will be extended when needed)
// 3. if result is partially static, ensure sizes match.
if (!sourceType.hasStaticShape() ||
sourceType.getRank() != resultType.getRank())
return false;

for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto sourceSize = std::get<0>(it);
auto resultSize = std::get<1>(it);
if (ShapedType::isDynamic(resultSize))
continue;
if (sourceSize != resultSize)
return false;
}

// If source has a map, it can only canonicalize if it is the canonical
// strided layout map.
if (sourceType.getAffineMaps().empty())
return true;

int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = getStridesAndOffset(sourceType, strides, offset);
(void)res;
assert(succeeded(res));
auto stridedMap =
makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
AffineMap sourceMap = sourceType.getAffineMaps().front();
return sourceMap == stridedMap;
}

/// This is a common helper used for patterns of the form
/// ```
/// someop(memrefcast) -> someop
/// ```
/// It folds the source of any memref_cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
if (castOp && canFold(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
}
return success(folded);
}

///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
@@ -1077,3 +1161,54 @@ ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
return getIndexingMaps(getOperation());
}

// TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate
// with Tablegen. This seems a desirable property in the context of OpInterfaces
// where a Linalg "named" op **isa** LinalgOp.
LogicalResult ConvOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult CopyOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult DotOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult FillOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
}
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -0,0 +1,20 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c8 = constant 8 : index
%c16 = constant 16 : index
%1 = alloc (%b) : memref<?xi8>
%2 = view %1[][] : memref<?xi8> to memref<16x16xf32>
%3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
%r0 = linalg.range %c0:%c8:%c1 : !linalg.range

// CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
%4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>

// CHECK: linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return %4: memref<?x?xf32>
}
