diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
index 3cb8e6554891..cc726479cdd5 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
@@ -694,6 +694,80 @@ struct DynamicBroadcastInDimAllDimsNonExpanding final
   }
 };
 
+struct NoopReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    // No dimensions to reduce.
+    if (op.getDimensions().empty()) {
+      rewriter.replaceOp(op, op.getInputs());
+      return success();
+    }
+
+    // If all returned values in the ReduceOp region exist outside the
+    // region, replace the ReduceOp with those values.
+    if (auto retOp = dyn_cast<mlir::stablehlo::ReturnOp>(
+            op.getBody().front().getTerminator())) {
+      Region *retRegion = retOp->getParentRegion();
+      if (llvm::any_of(retOp.getResults(), [retRegion](Value result) {
+            return result.getParentRegion() == retRegion;
+          })) {
+        return failure();
+      }
+
+      rewriter.replaceOp(op, retOp.getResults());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+struct EmptyReduceOpCanon final : OpRewritePattern<mlir::stablehlo::ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    // We require all reduce shapes to be the same, up to the element types, so
+    // we can just use the first operand and the first result as a representative.
+    auto elemTy = dyn_cast<RankedTensorType>(op.getInputs().getType().front());
+    if (!elemTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unranked input unsupported");
+    }
+
+    if (!llvm::is_contained(elemTy.getShape(), 0)) return failure();
+
+    Location loc = op.getLoc();
+    DenseIntElementsAttr empty = rewriter.getI64TensorAttr({});
+    if (elemTy.hasStaticShape()) {
+      SmallVector<Value> broadcasts(op.getNumResults());
+      for (auto [bcast, init, outTy] : llvm::zip_equal(
+               broadcasts, op.getInitValues(), op.getResultTypes())) {
+        bcast = rewriter.create<mlir::stablehlo::BroadcastInDimOp>(loc, outTy,
+                                                                   init, empty);
+      }
+      rewriter.replaceOp(op, broadcasts);
+      return success();
+    }
+
+    SmallVector<Value> shapes;
+    if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), shapes))) {
+      return failure();
+    }
+
+    SmallVector<Value> broadcasts(op.getNumResults());
+    for (auto [bcast, init, shape, outTy] : llvm::zip_equal(
+             broadcasts, op.getInitValues(), shapes, op.getResultTypes())) {
+      bcast = rewriter.create<mlir::stablehlo::DynamicBroadcastInDimOp>(
+          loc, outTy, init, shape, empty);
+    }
+    rewriter.replaceOp(op, broadcasts);
+    return success();
+  }
+};
+
 struct DynamicReshapeOpCanon final
     : OpRewritePattern<mlir::stablehlo::DynamicReshapeOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -922,6 +996,8 @@ void populateCanonicalizationPatterns(MLIRContext *context,
       BroadcastInDimOpCanon, DynamicBroadcastInDimOpNotActuallyDynamic,
       ChainedDynamicBroadcastInDimCanonicalization,
       DynamicBroadcastInDimAllDimsNonExpanding,
+      // Reduce op.
+      NoopReduceOpCanon, EmptyReduceOpCanon,
       // Shape manipulation(-ish) ops.
       ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon, GatherOpCanon,
       ReshapeOpCanon, TransposeOpCanon>(context, benefit);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
index 911b21ad3cb3..abc297a38e08 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
@@ -427,7 +427,7 @@ func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>
 
 // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape
 func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor<i32>) -> tensor<4x32xi32> {
-  %0 = mhlo.constant dense<[4, 32]> : tensor<2xi32>
+  %0 = stablehlo.constant dense<[4, 32]> : tensor<2xi32>
   // CHECK: %[[RESULT:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<i32>) -> tensor<4x32xi32>
   %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [] : (tensor<i32>, tensor<2xi32>) -> tensor<?x?xi32>
   %2 = stablehlo.dynamic_reshape %1, %0 : (tensor<?x?xi32>, tensor<2xi32>) -> tensor<4x32xi32>
@@ -584,3 +584,48 @@ func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) ->
 // CHECK-NEXT:    %[[V1:.*]] = stablehlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32>
 // CHECK-NEXT:    return %[[V1]] : tensor<2xui32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @reduce_noop_1
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xf32>)
+func.func @reduce_noop_1(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
+   reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) {
+    %4 = stablehlo.add %arg1, %arg2 : tensor<f32>
+    stablehlo.return %4 : tensor<f32>
+  }
+  // CHECK: return [[ARG0]] : tensor<4x8xf32>
+  func.return %1 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func.func @reduce_noop_2
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<4x8xi32>, [[ARG1:%.+]]: tensor<i32>)
+func.func @reduce_noop_2(%arg0: tensor<4x8xi32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %0 = stablehlo.constant dense<0> : tensor<i32>
+  %1 = stablehlo.reduce(%arg0 init: %0) across dimensions = [0, 1] : (tensor<4x8xi32>, tensor<i32>) -> tensor<i32>
+   reducer(%b1: tensor<i32>, %b2: tensor<i32>) {
+    stablehlo.return %arg1 : tensor<i32>
+  }
+  // CHECK: return [[ARG1]] : tensor<i32>
+  func.return %1 : tensor<i32>
+}
+
+// CHECK-LABEL: func.func @reduce_zero_ext
+func.func @reduce_zero_ext(%arg0: tensor<0xi1>) -> tensor<i32> {
+  %0 = stablehlo.constant dense<false> : tensor<i1>
+  %1 = stablehlo.constant dense<false> : tensor<0xi1>
+  %2 = stablehlo.compare NE, %arg0, %1, UNSIGNED : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+  %3 = stablehlo.convert %2 : (tensor<0xi1>) -> tensor<0xi32>
+  %4 = stablehlo.constant dense<0> : tensor<i32>
+  %5 = stablehlo.reduce(%3 init: %4) across dimensions = [0] : (tensor<0xi32>, tensor<i32>) -> tensor<i32>
+  reducer(%arg1: tensor<i32>, %arg2: tensor<i32>) {
+    %6 = stablehlo.add %arg1, %arg2 : tensor<i32>
+    stablehlo.return %6 : tensor<i32>
+  }
+
+  // CHECK: [[CST:%.+]] = stablehlo.constant dense<0> : tensor<i32>
+  // CHECK: return [[CST]] : tensor<i32>
+  return %5 : tensor<i32>
+}