Skip to content

Commit

Permalink
[mlir][std] Canonicalize a dim(memref_reshape) into a load from the s…
Browse files Browse the repository at this point in the history
…hape operand

This canonicalization helps propagate shape information through the program.

Differential Revision: https://reviews.llvm.org/D91854
  • Loading branch information
Stephan Herhut committed Nov 20, 2020
1 parent dfd2858 commit a89e55c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
Optional<int64_t> getConstantIndex();
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
}

Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,34 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;

LogicalResult matchAndRewrite(DimOp dim,
PatternRewriter &rewriter) const override {
auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();

if (!reshape)
return failure();

// Place the load directly after the reshape to ensure that the shape memref
// was not mutated.
rewriter.setInsertionPointAfter(reshape);
rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
llvm::makeArrayRef({dim.index()}));
return success();
}
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Standard/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,23 @@ func @cmpi_equal_operands(%arg0: i64)
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}

// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: store
// CHECK-NOT: dim
// CHECK: return %[[DIM]] : index
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-> index {
%c3 = constant 3 : index
%0 = memref_reshape %arg0(%arg1)
: (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
// Update the shape to test that he load ends up in the right place.
store %c3, %arg1[%c3] : memref<?xindex>
%1 = dim %0, %c3 : memref<*xf32>
return %1 : index
}

0 comments on commit a89e55c

Please sign in to comment.