diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 5c72b5ea5c664..42d88156300c2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 679c4e8bbba3f..62cf424e6fef5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
 
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
-
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
 
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, ToMemrefOfCast, TensorLoadToMemref>(
-      context);
+  results.add<DimOfCastOp, ToMemrefOfCast, ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ad5afa9c36015..da7d043133020 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
 
   }
   // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index df55b8373e0ee..f24048e60e07c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -109,8 +109,7 @@
 // CHECK:             scf.yield %[[VAL_84]] : f64
 // CHECK:           }
 // CHECK:           memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK:           %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK:           return %[[VAL_87]] : tensor<f64>
+// CHECK:           return %[[VAL_0]] : tensor<f64>
 // CHECK:         }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                              %arga: tensor<64x32xf64, #SparseMatrix>,