diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index a0c7e40c20a46..d085e7cb72c5c 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1244,12 +1244,33 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern { } }; +template +struct FoldEmptyTensorWithCollapseExpandOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + auto producer = op.getSrc().template getDefiningOp(); + if (!producer) + return failure(); + if (!producer.getType().hasStaticShape()) + return failure(); + + auto resultType = cast(op.getResultType()); + rewriter.replaceOpWithNewOp(op, resultType.getShape(), + resultType.getElementType()); + return success(); + } +}; + } // namespace void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + ReplaceEmptyTensorStaticShapeDims, + FoldEmptyTensorWithCollapseExpandOp, + FoldEmptyTensorWithCollapseExpandOp>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index f4020ede4854e..f2a708b1d747d 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -389,9 +389,9 @@ func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) { func.func @fold_fill_reshape() -> tensor<6x4xf32> { %zero = arith.constant 0.0 : f32 %empty = tensor.empty() : tensor<1x2x3x4xf32> - // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape + // CHECK: %[[COLLAPSED_EMPTY:.+]] = tensor.empty() // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32) - // CHECK-SAME: outs(%[[COLLAPSE]] : tensor<6x4xf32>) + // CHECK-SAME: outs(%[[COLLAPSED_EMPTY]] : tensor<6x4xf32>) %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]] : tensor<1x2x3x4xf32> into tensor<6x4xf32> @@ -512,7 +512,7 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) { // ----- // CHECK-LABEL: func @no_fold_fill_like_memref -// CHECK-NEXT: linalg.generic +// CHECK-NEXT: linalg.generic func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) { linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -528,7 +528,7 @@ func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) // ----- // CHECK-LABEL: func @no_fold_fill_like_tensor -// CHECK-NEXT: linalg.generic +// CHECK-NEXT: linalg.generic func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> { %result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 95c5b8c91edf5..fe0ea58282149 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2218,6 +2218,30 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> { // ----- +func.func @fold_empty_tensor_with_collapse() -> tensor<12xf32> { + %0 = tensor.empty() : tensor<1x12xf32> + %1 = tensor.collapse_shape %0 [[0, 1]]: tensor<1x12xf32> into tensor<12xf32> + return %1 : tensor<12xf32> +} + +// CHECK: func @fold_empty_tensor_with_collapse() +// CHECK: %[[T0:.+]] = tensor.empty() : tensor<12xf32> +// CHECK: return %[[T0]] : tensor<12xf32> + +// ----- + +func.func @fold_empty_tensor_with_expand() -> tensor<1x12xf32> { + %0 = tensor.empty() : tensor<12xf32> + %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 12] : tensor<12xf32> into tensor<1x12xf32> + return %1 : tensor<1x12xf32> +} + +// CHECK: func @fold_empty_tensor_with_expand() +// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32> +// CHECK: return %[[T0]] : tensor<1x12xf32> + +// ----- + func.func private @some_use(%i : index, %j : index) // CHECK-LABEL: func @empty_tensor_canonicalize @@ -2523,8 +2547,8 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>) -> tensor<10x1x10xf32> { - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10] : tensor into tensor @@ -2549,7 +2573,7 @@ func.func @sink_expand_of_cast(%arg0 : tensor) // CHECK-LABEL: func.func @sink_expand_of_cast // CHECK-DAG: %[[C10:.*]] = arith.constant 10 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10] // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]] // CHECK: return %[[RES]] @@ -2567,7 +2591,7 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, // CHECK-LABEL: func.func @partial_sink_expand_of_cast // CHECK: %[[CAST:.+]] = tensor.cast // CHECK-SAME: tensor<10x10xf32> to tensor -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10] // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]] // CHECK-SAME: tensor to tensor