[mlir][tensor] Fold padding_value away for pack ops when possible. (#74005)

If we can infer statically that there are no incomplete tiles, we can
remove the optional padding operand.

Fixes iree-org/iree#15417
hanhanW committed Dec 1, 2023
1 parent 8c1d476 commit 171cac9
Showing 2 changed files with 98 additions and 10 deletions.
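An illustrative sketch of the new fold, distilled from the positive test added below (not itself part of the commit): every source dimension tiles evenly, 500000 / 16 = 31250 and 1200 / 1 = 1200, so no incomplete tile can ever read the padding value. Before canonicalization:

    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
    %pack = tensor.pack %arg0 padding_value(%cst : f32)
        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
        into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>

After canonicalization, the same pack with the dead padding_value operand removed:

    %pack = tensor.pack %arg0
        outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
        into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>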
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (43 changes: 33 additions & 10 deletions)

@@ -16,6 +16,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -3800,17 +3801,39 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
   return true;
 }

-/// Fold an unpack(pack(x)) to x.
+/// Returns true if the pack op does not need a padding value.
+static bool paddingIsNotNeeded(PackOp op) {
+  auto srcType = op.getSourceType();
+  if (llvm::any_of(op.getInnerDimsPos(),
+                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
+    return false;
+  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
+    return false;
+  return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
+                                      op.getMixedTiles());
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
-  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
-  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(packOp, unPackOp.getSource());
-  return success();
+  // Fold an unpack(pack(x)) to x.
+  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
+    if (unPackOp.getSourceType() != packOp.getDestType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(packOp, unPackOp.getSource());
+    return success();
+  }
+
+  // Fold optional PaddingValue operand away if padding is not needed.
+  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+    rewriter.startRootUpdate(packOp);
+    packOp.getPaddingValueMutable().clear();
+    rewriter.finalizeRootUpdate(packOp);
+    return success();
+  }
+  return failure();
 }

 template <typename PackOrUnpackOp>
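How the new check plays out on the tests below (a worked example, not commit text): paddingIsNotNeeded first bails out if any packed source dimension or any inner tile size is dynamic, since divisibility cannot then be proven statically; otherwise it defers to PackOp::requirePaddingValue. For tensor<1200x500000xf32> with inner_dims_pos = [1, 0] and inner_tiles = [16, 1], we get 500000 mod 16 == 0 and 1200 mod 1 == 0, so no padding is needed and the operand is erased in place via a root update. For tensor<1200x499999xf32>, 499999 mod 16 == 15, so the last tile is incomplete and the padding value must stay.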
mlir/test/Dialect/Tensor/canonicalize.mlir (65 changes: 65 additions & 0 deletions)

@@ -719,6 +719,71 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {

 // -----

+func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack
+// CHECK-NOT:     padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative1
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
+  return %pack : tensor<?x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative2
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [%tile, 1]
+    into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
+  return %pack : tensor<?x1200x?x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative3
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
 // CHECK-LABEL: func @fold_unpack_constant_splat
 // CHECK-NOT:   tensor.unpack
 // CHECK:       arith.constant dense<1.000000e-01> : tensor<128x256xf32>
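For reference, a minimal sketch of the pre-existing unpack(pack(x)) -> x fold that the restructured canonicalize keeps intact. The shapes and the destinations %d0 and %d1 are illustrative assumptions, not taken from the commit:

    // %arg0 : tensor<16x8x8x32xf32>
    %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32]
        into %d0 : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
    %pack = tensor.pack %unpack inner_dims_pos = [0, 1] inner_tiles = [8, 32]
        into %d1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
    // The pack has no padding_value, and its dims/tiles match the unpack's,
    // so all uses of %pack canonicalize to %arg0.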