diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp index 670865de6031f..73b3a1cbf7263 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp @@ -40,11 +40,14 @@ struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { !llvm::hasSingleElement(resultShapes)) return failure(); + Attribute encoding; + if (auto tensorTy = dyn_cast(reshapeOp.getResultType())) + encoding = tensorTy.getEncoding(); + // Create new tensor.empty op. - // TODO: Do not drop tensor type encoding. Value emptyTensor = EmptyOp::create(rewriter, loc, resultShapes[0], - reshapeOp.getResultType().getElementType()); + reshapeOp.getResultType().getElementType(), encoding); if (emptyTensor.getType() != reshapeOp.getResultType()) { rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), emptyTensor); diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir index 7b11c9f43c7ec..62ee7e8c2d5ca 100644 --- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -37,6 +37,27 @@ func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> { // CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]]) // CHECK-NEXT: return %[[INIT]] +#encoding = #test.tensor_encoding<"encoding"> + +func.func @empty_expand_encoding() -> tensor<2x3x4x2xf32, #encoding> { + %0 = tensor.empty() : tensor<6x8xf32, #encoding> + %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 3, 4, 2] : tensor<6x8xf32, #encoding> into tensor<2x3x4x2xf32, #encoding> + return %1 : tensor<2x3x4x2xf32, #encoding> +} +// CHECK-LABEL: func.func @empty_expand_encoding +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x4x2xf32, #test.tensor_encoding<"encoding">> +// CHECK-NEXT: return %[[INIT]] + +func.func @empty_collapse_encoding() -> tensor<6x8xf32, #encoding> { + %0 = tensor.empty() : tensor<2x3x4x2xf32, #encoding> + %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] + : tensor<2x3x4x2xf32, #encoding> into tensor<6x8xf32, #encoding> + return %1 : tensor<6x8xf32, #encoding> +} +// CHECK-LABEL: func.func @empty_collapse_encoding +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<6x8xf32, #test.tensor_encoding<"encoding">> +// CHECK-NEXT: return %[[EMPTY_0]] + func.func @fold_empty_tensor_with_slice (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32> {