diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index f9c8589683ba7..27988a451173c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -5010,15 +5010,14 @@ template SmallVector getPackedOuterShapeWithoutTransposition(UnPackOp); // Given the (potentially) updated packed type, `newPackedTy`, generates an -// updated mixed-tile-sizes attribute. A tile size is updated only -// when: -// * a dim from newPackedTy is static, and -// * the corresponding size from mixedTiles is still dynamic. -// Otherwise, the original tile size is preserved. +// updated mixed-tile-sizes list. For each inner packed dimension that is static +// in `newPackedTy`, the tile is set to that static size (replacing SSA values +// or mismatched constants). Dynamic packed dimensions preserve the original +// tile. The folded tensor type is treated as authoritative for static extents. // Note - packed-type-dim and mixed-tile-size should always match! static SmallVector getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, - SmallVector mixedTiles) { + ArrayRef mixedTiles) { SmallVector newMixedTileSizes; for (auto it : llvm::zip(cast(newPackedTy) .getShape() @@ -5029,19 +5028,7 @@ getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, newMixedTileSizes.push_back(std::get<1>(it)); continue; } - - // If the current result dim is static, update the dynamic mixed-size - // (provided the original value is dynamic). - OpFoldResult tile = std::get<1>(it); - if (Attribute attr = llvm::dyn_cast_if_present(tile)) { - // Already a constant - newMixedTileSizes.push_back(tile); - } else { - assert(getConstantIntValue(tile).value() == dimSize && - "tile size and dim size don't match!"); - newMixedTileSizes.push_back( - (rewriter.getIntegerAttr(rewriter.getIndexType(), dimSize))); - } + newMixedTileSizes.push_back(rewriter.getIndexAttr(dimSize)); } return newMixedTileSizes; diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 0c5a1c6108ae3..019b7433b2777 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -2179,3 +2179,82 @@ func.func @negative_unpack_pack_memref_no_canonicalization(%packed: memref<16x8x linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32> return } + +// ----- +// CHECK-LABEL: func.func @fold_unpack_cast_inner_tile_dynamic_arg +// CHECK-SAME: %[[SRC:.+]]: tensor<1x3x8x1xi32>, %[[TILE:.+]]: index +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32> +// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]] +// CHECK-SAME: inner_dims_pos = [0, 1] +// CHECK-SAME: inner_tiles = [8, 1] +// CHECK-SAME: into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32> +// CHECK: return %[[UNPACK]] : tensor<7x3xi32> +func.func @fold_unpack_cast_inner_tile_dynamic_arg(%arg0: tensor<1x3x8x1xi32>, %arg1: index) -> tensor<7x3xi32> { + %0 = tensor.empty() : tensor<7x3xi32> + %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor + %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%arg1, 1] into %0 : tensor -> tensor<7x3xi32> + return %unpack : tensor<7x3xi32> +} + + +// ----- +// Mismatched constant tile vs static packed shape: fold still drops the cast and +// takes inner tile sizes from the refined packed type. +// CHECK-LABEL: func.func @fold_unpack_cast_inner_tile_inlined_mismatch +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<7x3xi32> +// CHECK-NOT: tensor.cast +// CHECK: %[[UNPACK:.+]] = linalg.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 1] +// CHECK-SAME: into %[[EMPTY]] : tensor<1x3x8x1xi32> -> tensor<7x3xi32> +// CHECK: return %[[UNPACK]] : tensor<7x3xi32> +func.func @fold_unpack_cast_inner_tile_inlined_mismatch(%arg0: tensor<1x3x8x1xi32>) -> tensor<7x3xi32> { + %c256 = arith.constant 256 : index + %1 = tensor.empty() : tensor<7x3xi32> + %cast = tensor.cast %arg0 : tensor<1x3x8x1xi32> to tensor + %unpack = linalg.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c256, 1] into %1 : tensor -> tensor<7x3xi32> + return %unpack : tensor<7x3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @no_fold_pack_cast_inner_tile_dynamic_arg +// CHECK-SAME: %[[SRC:.+]]: tensor<8x3xi32>, %[[TILE:.+]]: index, %[[DEST:.+]]: tensor +// CHECK: %[[PACK:.+]] = linalg.pack +// CHECK: padding_value +// CHECK: inner_dims_pos = [0, 1] +// CHECK: inner_tiles = [%[[TILE]], 1] +// CHECK: into %[[DEST]] : tensor +// CHECK: return %[[PACK]] : tensor +func.func @no_fold_pack_cast_inner_tile_dynamic_arg(%arg0: tensor<8x3xi32>, %arg1: index, + %dest: tensor) -> tensor { + %c0 = arith.constant 0 : i32 + %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor + %pack = linalg.pack %cast + padding_value(%c0 : i32) + inner_dims_pos = [0, 1] + inner_tiles = [%arg1, 1] + into %dest : tensor -> tensor + return %pack : tensor +} + +// ----- + +// CHECK-LABEL: func.func @no_fold_pack_cast_inner_tile_inlined_mismatch +// CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index +// CHECK: %[[PACK:.+]] = linalg.pack +// CHECK: padding_value +// CHECK: inner_dims_pos = [0, 1] +// CHECK: inner_tiles = [%[[C256]], 1] +// CHECK: into %{{.+}} : tensor +// CHECK: return %[[PACK]] : tensor +func.func @no_fold_pack_cast_inner_tile_inlined_mismatch(%arg0: tensor<8x3xi32>, + %dest: tensor) -> tensor { + %c0 = arith.constant 0 : i32 + %c256 = arith.constant 256 : index + %cast = tensor.cast %arg0 : tensor<8x3xi32> to tensor + %pack = linalg.pack %cast + padding_value(%c0 : i32) + inner_dims_pos = [0, 1] + inner_tiles = [%c256, 1] + into %dest : tensor -> tensor + return %pack : tensor +}