diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 3bb5f8af821c0..a1499824fde15 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -1309,9 +1309,10 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { paddingVal, padOp.getNofold()); // Inject the linalg.unpack right after the packed padOp. - Value outputUnPack = - tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(), - padOp.getResultType().getElementType()); + // Compute the unpacked output size directly from the padded packed tensor. + Value outputUnPack = linalg::UnPackOp::createDestinationTensor( + rewriter, loc, newPadOp.getResult(), unpackOp.getMixedTiles(), + innerDimsPos, outerDimsPerm); Value replacement = linalg::UnPackOp::create( rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos, diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 7a16bc0a4faee..6121b69a3ecd8 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1634,3 +1634,33 @@ func.func @push_extract_through_generic_secondextract(%arg0: tensor<128x128xf32> // CHECK-SAME: ins(%[[PAD]], %[[ARG0]] // CHECK: %[[EXTRACT2:.+]] = tensor.extract_slice %[[GENERIC]] // CHECK: scf.yield %[[EXTRACT2]] + +// ----- + +func.func @test_dynamic_unpack_pad(%arg0: tensor, %dim: index) -> tensor { + %dest = tensor.empty(%dim) : tensor + %unpack = linalg.unpack %arg0 inner_dims_pos = [1] inner_tiles = [8] + into %dest : tensor -> tensor + + %c0 = arith.constant 0.0 : f32 + %pad = tensor.pad %unpack low[2, 0] high[3, 0] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %c0 : f32 + } : tensor to tensor + + return %pad : tensor +} + +// CHECK-LABEL: func.func @test_dynamic_unpack_pad +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DIM:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[2, 0, 0] high[3, 0, 0] +// CHECK: tensor.yield %[[CST]] +// CHECK: %[[DIM_VAL:.+]] = tensor.dim %[[PAD]], %[[C0]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM_VAL]]) +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[PAD]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] +// CHECK-SAME: into %[[EMPTY]] +// CHECK: return %[[UNPACK]]