From 5c5b64d7b880e3674a77413e9891b30f22635cb9 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Tue, 23 Sep 2025 09:26:53 -0700 Subject: [PATCH] fix linalg.pack canonicalization Signed-off-by: Nirvedh Meshram --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 15 +++++++-------- mlir/test/Dialect/Linalg/canonicalize.mlir | 3 ++- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 578931e1351c6..4bc4d97697a21 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5583,14 +5583,13 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl &srcShape, LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { // Fold an pack(unpack(x)) to x. if (auto unPackOp = packOp.getSource().getDefiningOp()) { - if (unPackOp.getSourceType() != packOp.getDestType()) - return failure(); - if (packOp.getPaddingValue() || - !hasSameInnerOuterAttribute(packOp, unPackOp) || - !haveSameTiles(packOp, unPackOp)) - return failure(); - rewriter.replaceOp(packOp, unPackOp.getSource()); - return success(); + if (unPackOp.getSourceType() == packOp.getDestType() && + !packOp.getPaddingValue() && + hasSameInnerOuterAttribute(packOp, unPackOp) && + haveSameTiles(packOp, unPackOp)) { + rewriter.replaceOp(packOp, unPackOp.getSource()); + return success(); + } } // Fold optional PaddingValue operand away if padding is not needed. diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 5c5f7e861d37d..26d2d98572f47 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1756,10 +1756,11 @@ func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) // CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32> // CHECK: return %[[T]] : tensor<16x16x8x8xf32> func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> { + %cst = arith.constant 0.000000e+00 : f32 %tensor_empty = tensor.empty() : tensor<128x128xf32> %unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32> %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32> - %packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32> + %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32> return %packed : tensor<16x16x8x8xf32> }