[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization #84225

Merged: 8 commits, Mar 12, 2024
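In brief: this PR removes the DimOfMemRefReshape canonicalization from memref::DimOp and adds the equivalent pattern, DimOfReshapeOp, to tensor::DimOp, where immutability makes the rewrite safe. A minimal sketch of the new tensor canonicalization (illustrative IR written for this summary, not taken from the diff; @example is a hypothetical function):

func.func @example(%t: tensor<*xf32>, %shape: tensor<?xindex>) -> index {
  %c1 = arith.constant 1 : index
  %r = tensor.reshape %t(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  // Canonicalizes to: %d = tensor.extract %shape[%c1] : tensor<?xindex>
  %d = tensor.dim %r, %c1 : tensor<*xf32>
  return %d : index
}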
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
Speculation::Speculatability getSpeculatability();
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
}

1 change: 0 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
MLIRContext *context = enclosingOp->getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
tensor::DimOp::getCanonicalizationPatterns(patterns, context);
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
33 changes: 0 additions & 33 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;

LogicalResult matchAndRewrite(DimOp dim,
PatternRewriter &rewriter) const override {
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

if (!reshape)
return failure();

// Place the load directly after the reshape to ensure that the shape memref
// was not mutated.
rewriter.setInsertionPointAfter(reshape);
Location loc = dim.getLoc();
Value load =
rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
if (load.getType() != dim.getType())
load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
rewriter.replaceOp(dim, load);
return success();
}
};

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape>(context);
}
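The removed pattern placed the load immediately after the reshape so that later writes to the shape memref could not be observed. That insertion point appears to be the source of incorrectness: the dim's index operand may be defined after the reshape, in which case the newly created load would use a value that does not dominate it. A hedged sketch of the failure mode (hypothetical IR mirroring the @dim_of_reshape_undominated test added below):

func.func @undominated(%m: memref<*xf32>, %shape: memref<?xindex>, %i: index) -> index {
  %c4 = arith.constant 4 : index
  %r = memref.reshape %m(%shape) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // %idx is defined after the reshape, so a memref.load %shape[%idx] inserted
  // right after the reshape would reference %idx before its definition,
  // producing IR that violates dominance.
  %idx = arith.muli %i, %c4 : index
  %d = memref.dim %r, %idx : memref<*xf32>
  return %d : index
}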

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
28 changes: 27 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};

/// Fold dim of a tensor reshape operation to an extract into the reshape's shape
/// operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;

LogicalResult matchAndRewrite(DimOp dim,
PatternRewriter &rewriter) const override {
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

if (!reshape)
return failure();

// Since tensors are immutable, we don't need to worry about where to place
// the extract call.
rewriter.setInsertionPointAfter(dim);
Location loc = dim.getLoc();
Value extract =
rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
if (extract.getType() != dim.getType())
extract =
rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
rewriter.replaceOp(dim, extract);
return success();
}
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
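When the shape operand's element type is not index, the new pattern inserts an arith.index_cast after the extract, as exercised by the @dim_of_reshape_i32 test below. A minimal sketch (illustrative IR, assuming an unranked result as in the tests; @i32_shape is a hypothetical function):

func.func @i32_shape(%t: tensor<*xf32>, %shape: tensor<?xi32>) -> index {
  %c0 = arith.constant 0 : index
  %r = tensor.reshape %t(%shape) : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
  // Canonicalizes to:
  //   %e = tensor.extract %shape[%c0] : tensor<?xi32>
  //   %d = arith.index_cast %e : i32 to index
  %d = tensor.dim %r, %c0 : tensor<*xf32>
  return %d : index
}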

//===----------------------------------------------------------------------===//
42 changes: 0 additions & 42 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: memref.store
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-> index {
%c3 = arith.constant 3 : index
%0 = memref.reshape %arg0(%arg1)
: (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
// Update the shape to test that the load ends up in the right place.
memref.store %c3, %arg1[%c3] : memref<?xindex>
%1 = memref.dim %0, %c3 : memref<*xf32>
return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-> index {
%c3 = arith.constant 3 : index
%0 = memref.reshape %arg0(%arg1)
: (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
%1 = memref.dim %0, %c3 : memref<*xf32>
return %1 : index
}

// -----

// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
80 changes: 80 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
// CHECK-NOT: tensor.store
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
// CHECK: return %[[DIM]] : index
func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
-> index {
%c3 = arith.constant 3 : index
%0 = tensor.reshape %arg0(%arg1)
: (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// Update the shape to test that the extract ends up in the right place.
tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
%1 = tensor.dim %0, %c3 : tensor<*xf32>
return %1 : index
}

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_i32(
// CHECK: tensor.extract
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
// CHECK: return %[[CAST]] : index
func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
-> index {
%c3 = arith.constant 3 : index
%0 = tensor.reshape %arg0(%arg1)
: (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
%1 = tensor.dim %0, %c3 : tensor<*xf32>
return %1 : index
}

// -----

// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded to tensor.extract %shp[%idx];
// the extract is created at the dim, inside the loop, since the induction-variable index is only defined there.
// CHECK-LABEL: func @dim_of_reshape_for(
// CHECK: scf.for
// CHECK-NEXT: tensor.extract
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index

%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>

%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
%2 = tensor.dim %0, %arg2 : tensor<*xf32>
%3 = arith.muli %arg3, %2 : index
scf.yield %3 : index
}
return %1 : index
}

// -----

// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded to tensor.extract %shp[%idx];
// the extract is created at the dim, after the index computation, so dominance is preserved.
// CHECK-LABEL: func @dim_of_reshape_undominated(
// CHECK: arith.muli
// CHECK-NEXT: tensor.extract
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
%c4 = arith.constant 4 : index
%reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = tensor.dim %reshape, %0 : tensor<*xf32>
return %dim : index
}