diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3bd763ea00cd7..f27175a1f91e3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2200,6 +2200,63 @@ struct RemoveOutsDependency : public OpRewritePattern { } }; +/// Drops an unused result from an elementwise `linalg.generic` by +/// reclassifying its tied `outs` operand as an extra input operand. +struct DropRedundantResultsFromGenericOps + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { + if (!linalg::isElementwise(op) || op.getNumResults() < 2U) + return failure(); + + // Given that the op has no reductions, there is no need to preserve an + // unused result: transform it into an input instead. + auto maybeUnusedRes = llvm::find_if( + op.getResults(), [](OpResult res) { return res.use_empty(); }); + if (maybeUnusedRes == op.getResults().end()) + return failure(); + + OpResult unusedRes = *maybeUnusedRes; + const unsigned resIdx = unusedRes.getResultNumber(); + auto resTypes = llvm::to_vector(op.getResultTypes()); + resTypes.erase(resTypes.begin() + resIdx); + SmallVector resValues = llvm::to_vector_of(op.getResults()); + resValues.erase(resValues.begin() + resIdx); + const int64_t numInputs = op.getNumDpsInputs(); + OpOperand *resOperand = op.getTiedOpOperand(unusedRes); + AffineMap map = op.getIndexingMapMatchingResult(unusedRes); + const unsigned operandIdx = resOperand->getOperandNumber(); + + // Remove the output operand and add it as an input operand with the same + // map. + SmallVector outs(op.getOutputs()); + outs.erase(outs.begin() + resIdx); + SmallVector ins(op.getInputs()); + ins.insert(ins.begin() + numInputs, resOperand->get()); + SmallVector maps = op.getIndexingMapsArray(); + maps.erase(maps.begin() + operandIdx); + maps.insert(maps.begin() + numInputs, map); + rewriter.setInsertionPoint(op); + + auto newGenericOp = rewriter.create( + op.getLoc(), TypeRange(resTypes), ins, outs, maps, + op.getIteratorTypesArray()); + + op->setDiscardableAttrs(op->getDiscardableAttrDictionary()); + op.getBody()->getTerminator()->eraseOperands(resIdx); + newGenericOp.getRegion().takeBody(op.getBodyRegion()); + + // Replace the remaining results of the old op with the results of the new + // op. + rewriter.replaceAllUsesWith(resValues, newGenericOp.getResults()); + + // Remove the old op. + rewriter.eraseOp(op); + return success(); + } +}; + /// Fold linalg.fill into linalg.generic struct FoldFillWithGenericOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -2262,6 +2319,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns( RemoveOutsDependency>(context); // Add the patterns that clean up dead operands and results. populateEraseUnusedOperandsAndResultsPatterns(patterns); + patterns.add(context); } void mlir::linalg::populateCollapseDimensions( diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index bc55c12c02f29..9f1fd4609b00e 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -1079,4 +1079,49 @@ module { // CHECK-NOT: linalg.generic // CHECK: tensor.expand_shape // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} -// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) \ No newline at end of file +// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>) + +// ----- + +// CHECK-LABEL: @drop_unused_results +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: tensor<64xf32>, [[ARG1:%[a-zA-Z0-9]+]]: tensor<1x56x56x64xf32> +func.func @drop_unused_results(%arg0: tensor<64xf32>, %arg1: tensor<1x56x56x64xf32>) -> tensor<1x56x56x64xf32> { + %cst = arith.constant 3.40282347E+38 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + // CHECK: [[OUT:%[a-zA-Z0-9]+]] = tensor.empty() : tensor<1x56x56x64xf32> + %0 = tensor.empty() : tensor<1x56x56x64xf32> + // CHECK: [[RES:%[0-9]+]] = linalg.generic {{.*}} ins([[ARG0]], [[ARG1]] : tensor<64xf32>, tensor<1x56x56x64xf32>) outs([[OUT]] : tensor<1x56x56x64xf32>) + %1:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<64xf32>) outs(%arg1, %0 : tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>) { + ^bb0(%in: f32, %out: f32, %out_1: f32): + %2 = arith.addf %in, %out : f32 + %3 = arith.minimumf %2, %cst : f32 + %4 = arith.maximumf %3, %cst_0 : f32 + linalg.yield %2, %4 : f32, f32 + } -> (tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>) + // CHECK: -> tensor<1x56x56x64xf32> + // CHECK: return [[RES]] : tensor<1x56x56x64xf32> + return %1#1 : tensor<1x56x56x64xf32> +} + +// ----- + +// CHECK-LABEL: @swap_drop_unused_results +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]: tensor<64xf32>, [[ARG1:%[a-zA-Z0-9]+]]: tensor<1x56x56x64xf32> +func.func @swap_drop_unused_results(%arg0: tensor<64xf32>, %arg1: tensor<1x56x56x64xf32>) -> tensor<1x56x56x64xf32> { + %cst = arith.constant 3.40282347E+38 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + // CHECK: [[OUT:%[a-zA-Z0-9]+]] = tensor.empty() : tensor<1x56x56x64xf32> + %0 = tensor.empty() : tensor<1x56x56x64xf32> + // CHECK: [[RES:%[0-9]+]] = linalg.generic {{.*}} ins([[ARG0]] : tensor<64xf32>) outs([[OUT]] : tensor<1x56x56x64xf32>) + %1:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<64xf32>) outs(%arg1, %0 : tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>) { + ^bb0(%in: f32, %out_1: f32, %out: f32): + %2 = arith.addf %in, %out : f32 + %3 = arith.minimumf %2, %cst : f32 + %4 = arith.maximumf %3, %cst_0 : f32 + linalg.yield %2, %4 : f32, f32 + } -> (tensor<1x56x56x64xf32>, tensor<1x56x56x64xf32>) + // CHECK: -> tensor<1x56x56x64xf32> + // CHECK: return [[RES]] : tensor<1x56x56x64xf32> + return %1#0 : tensor<1x56x56x64xf32> +} +