
[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization #84225

Merged 8 commits on Mar 12, 2024
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,6 +629,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
    Speculation::Speculatability getSpeculatability();
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

46 changes: 46 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,6 +1069,52 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp))
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
      } // else dim.getIndex is a block argument to reshape->getBlock
    } else if (!dim.getIndex().getParentRegion()->isProperAncestor(
                   reshape->getParentRegion()))
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
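As a quick sketch of the rewrite implemented by DimOfMemRefReshape above (the SSA names %src, %shape, and %idx are illustrative, not taken from the patch), a dim of a reshape result becomes a load from the reshape's shape operand, and the load is inserted immediately after the reshape:

  // Before canonicalization:
  %r = memref.reshape %src(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %d = memref.dim %r, %idx : memref<*xf32>

  // After canonicalization (%r remains until dead-code elimination removes it):
  %r = memref.reshape %src(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %d = memref.load %shape[%idx] : memref<?xindex>

If the shape operand's element type is not index (for example memref<?xi32>), the pattern additionally inserts an arith.index_cast of the loaded value, as the _i32 test below exercises.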
95 changes: 95 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,6 +242,101 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: memref.store
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}
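// For illustration: after canonicalization the CHECK lines above expect a body
// roughly like the one below. The load is created immediately after the
// reshape, so it reads the shape buffer before the memref.store mutates it,
// and the reshape itself becomes dead and is cleaned up:
//
//   %c3 = arith.constant 3 : index
//   %dim = memref.load %arg1[%c3] : memref<?xindex>
//   memref.store %c3, %arg1[%c3] : memref<?xindex>
//   return %dim : index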

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
  return %dim : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_for(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index

  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>

  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
    %2 = memref.dim %0, %arg2 : memref<*xf32>
    %3 = arith.muli %arg3, %2 : index
    scf.yield %3 : index
  }
  return %1 : index
}
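// For illustration: the fold is rejected here because DimOfMemRefReshape would
// have to create the load immediately after the reshape, where the dim index
// %arg2 (a block argument of the scf.for body) is not visible.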

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %0 = arith.muli %arg2, %c4 : index
  %dim = memref.dim %reshape, %0 : memref<*xf32>
  return %dim : index
}
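// For illustration: the fold is rejected here because the dim index
// %0 = arith.muli ... is defined after the reshape in the same block, so a
// load inserted immediately after the reshape could not use it.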

// -----

// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2294,7 +2294,7 @@ func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)

// -----

-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_for(
// CHECK: scf.for
// CHECK-NEXT: tensor.extract
@@ -2317,7 +2317,7 @@ func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) ->

// -----

-// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_undominated(
// CHECK: arith.muli
// CHECK-NEXT: tensor.extract
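For the tensor::DimOp canonicalization named in the PR title, the analogous fold replaces tensor.dim of a tensor.reshape result with a tensor.extract from the shape tensor. Because the shape operand is an immutable SSA tensor, there is no mutation hazard, and the extract can be created at the dim itself; this is why the two test comments above change from "is not folded" to "is folded". A minimal sketch with illustrative names (%src, %shape, %idx are not taken from the patch):

  // Before canonicalization:
  %r = tensor.reshape %src(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %d = tensor.dim %r, %idx : tensor<*xf32>

  // After canonicalization:
  %r = tensor.reshape %src(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %d = tensor.extract %shape[%idx] : tensor<?xindex>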