[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization #84225

Merged · 8 commits · Mar 12, 2024
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (32 changes: 31 additions & 1 deletion)
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
 
     if (!reshape)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+    // cheaply check that either of the following conditions holds:
+    //   1. dim.getIndex() is defined in the same block as reshape but
+    //      before reshape.
+    //   2. dim.getIndex() is defined in a parent block of
+    //      reshape.
+
+    // Check condition 1
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp)) {
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+        }
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+    } // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`.
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.
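
For illustration, here is a hypothetical input (modeled on the @dim_of_memref_reshape_undominated test added below; the function name is made up) showing why the dominance check is needed. The old pattern unconditionally rewrote the dim to a load placed directly after the reshape, so when the index is computed after the reshape, the rewrite would reference the index before its definition and produce IR that violates SSA dominance:

func.func @undominated_fold(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // The old canonicalization would insert `memref.load %arg1[%0]` here,
  // i.e. before the definition of %0 below.
  %0 = arith.muli %arg2, %c4 : index
  %dim = memref.dim %reshape, %0 : memref<*xf32>
  return %dim : index
}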
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (28 changes: 27 additions & 1 deletion)
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract from the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable, we don't need to worry about where to
+    // place the extract call.
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
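
As a before/after sketch (illustrative IR, not part of the patch), DimOfReshapeOp rewrites a dim of a reshape result into an extract from the reshape's shape operand; because tensors are immutable, the extract can be placed at the dim itself:

// Before:
%0 = tensor.reshape %arg0(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%d = tensor.dim %0, %idx : tensor<*xf32>

// After canonicalization:
%d = tensor.extract %shape[%idx] : tensor<?xindex>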
mlir/test/Dialect/MemRef/canonicalize.mlir (53 changes: 53 additions & 0 deletions)
@@ -284,6 +284,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
  return %dim : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_for(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_for(%arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index

  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>

  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
    %2 = memref.dim %0, %arg2 : memref<*xf32>
    %3 = arith.muli %arg3, %2 : index
    scf.yield %3 : index
  }
  return %1 : index
}
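
For contrast, a sketch of a nested case that is still folded (illustrative, not one of the tests in this patch): here dim and reshape share a block while the index comes from an enclosing block, so the index is guaranteed to dominate the reshape and the pattern's cheap checks accept the fold:

func.func @dim_of_memref_reshape_nested(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %idx: index) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %0 = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %c1) -> (index) {
    // %idx is a function argument, so it dominates the reshape in this
    // nested block and the dim can fold to memref.load %arg1[%idx].
    %r = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
    %d = memref.dim %r, %idx : memref<*xf32>
    %m = arith.muli %acc, %d : index
    scf.yield %m : index
  }
  return %0 : index
}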

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %0 = arith.muli %arg2, %c4 : index
  %dim = memref.dim %reshape, %0 : memref<*xf32>
  return %dim : index
}

// -----

// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
mlir/test/Dialect/Tensor/canonicalize.mlir (80 changes: 80 additions & 0 deletions)
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
// CHECK-NOT: tensor.store
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
// CHECK: return %[[DIM]] : index
func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  // Update the shape to test that the extract ends up in the right place.
  %inserted = tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
  %1 = tensor.dim %0, %c3 : tensor<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_i32(
// CHECK: tensor.extract
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
// CHECK: return %[[CAST]] : index
func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
  %1 = tensor.dim %0, %c3 : tensor<*xf32>
  return %1 : index
}
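
For reference, a sketch of the canonicalized body this test expects (value names are illustrative): the extract yields an i32 element from the shape tensor, which is then cast back to index:

%c3 = arith.constant 3 : index
%extracted = tensor.extract %arg1[%c3] : tensor<?xi32>
%0 = arith.index_cast %extracted : i32 to index
return %0 : index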

// -----

// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_for(
// CHECK: scf.for
// CHECK-NEXT: tensor.extract
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_for(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index

  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>

  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
    %2 = tensor.dim %0, %arg2 : tensor<*xf32>
    %3 = arith.muli %arg3, %2 : index
    scf.yield %3 : index
  }
  return %1 : index
}

// -----

// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_undominated(
// CHECK: arith.muli
// CHECK-NEXT: tensor.extract
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = arith.muli %arg2, %c4 : index
  %dim = tensor.dim %reshape, %0 : tensor<*xf32>
  return %dim : index
}