[vector][multi_reduction] Add unrolling for vector.multi_reduction#185033
amd-eochoalo wants to merge 6 commits
Stated previously:
> Below are some notes on naming
> * RewriteXAsY or LowerXToY (if we're changing op kind)
>
> Concretely, TwoDimMultiReductionToReduction looks like a lowering
> (it rewrites to vector.reduction), not an unrolling,

llvm#182301 (comment)
When unrolling a vector.multi_reduction whose outermost dimension is reduced, extract the vectors along the outermost dimension and chain them through lower-rank vector.multi_reduction ops.
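As a rough illustration of this chaining, here is a minimal plain C++ model. It is not MLIR code: the fixed 3-D shape, the reduction dimensions `[0, 2]`, integer addition as the combining kind, and the names `reduceDim1`, `unrolledChain` and `referenceReduce` are all assumptions made for the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec1 = std::vector<int>;
using Vec2 = std::vector<Vec1>;
using Vec3 = std::vector<Vec2>;

// Hypothetical stand-in for a lower-rank multi_reduction:
// reduces dim 1 of an M x K slice into an accumulator of size M.
Vec1 reduceDim1(const Vec2 &slice, const Vec1 &acc) {
  Vec1 res = acc;
  for (std::size_t m = 0; m < slice.size(); ++m)
    for (int x : slice[m])
      res[m] += x;
  return res;
}

// Models the unrolled form: one lower-rank reduction per outer slice,
// each one taking the previous result as its accumulator.
Vec1 unrolledChain(const Vec3 &src, const Vec1 &acc) {
  Vec1 result = acc;
  for (const Vec2 &slice : src) // plays the role of vector.extract %src[i]
    result = reduceDim1(slice, result);
  return result;
}

// Reference: reduce dims 0 and 2 of an N x M x K value in one go.
Vec1 referenceReduce(const Vec3 &src, const Vec1 &acc) {
  Vec1 res = acc;
  for (std::size_t n = 0; n < src.size(); ++n)
    for (std::size_t m = 0; m < src[n].size(); ++m)
      for (std::size_t k = 0; k < src[n][m].size(); ++k)
        res[m] += src[n][m][k];
  return res;
}
```

In this model the accumulator is threaded through every step, so the unrolled steps form a sequential dependency chain along the outermost dimension.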
Unrolls multi_reduction into a series of extractions, lower-rank multi_reductions, and insertions.
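The extract / reduce / insert shape can be sketched in plain C++ as well. This is only a model under stated assumptions: a 3-D source whose outermost dimension is parallel and whose innermost dimension is reduced, integer addition as the combining kind, and the hypothetical names `reduceInnermost` and `unrollOuterParallel`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec1 = std::vector<int>;
using Vec2 = std::vector<Vec1>;
using Vec3 = std::vector<Vec2>;

// Hypothetical stand-in for the lower-rank multi_reduction:
// reduces the innermost dim of a B x C slice into an accumulator of size B.
Vec1 reduceInnermost(const Vec2 &slice, const Vec1 &acc) {
  Vec1 res = acc;
  for (std::size_t b = 0; b < slice.size(); ++b)
    for (int x : slice[b])
      res[b] += x;
  return res;
}

// Models the unrolled form for an A x B x C source reduced along dim 2.
Vec2 unrollOuterParallel(const Vec3 &src, const Vec2 &acc) {
  // Zero-initialized result, playing the role of the arith.constant.
  Vec2 result(acc.size(), Vec1(acc.empty() ? 0 : acc[0].size(), 0));
  for (std::size_t a = 0; a < src.size(); ++a) {
    const Vec2 &srcSlice = src[a]; // vector.extract %src[a]
    const Vec1 &accSlice = acc[a]; // vector.extract %acc[a]
    // Reduce the slice, then "insert" it at index a of the result.
    result[a] = reduceInnermost(srcSlice, accSlice);
  }
  return result;
}
```

Unlike the chained form, each slice here is reduced independently of the others, which is why the results need to be reassembled with insertions.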
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed.
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-llvm Author: Erick Ochoa Lopez (amd-eochoalo) Changes
The pattern benefits of the vector.multi_reduction expand, flattening, unrolling and lowering patterns are set so that existing behaviour is preserved. The benefit of adding unrolling is that we can now unroll without going through expand or flattening. This gives different backends finer control over which patterns to apply, for example to better handle differences between the SPIR-V and LLVM lowerings.

This will break downstream projects. To migrate, apply multi_reduction_lowering wherever multi_reduction_unrolling was applied before.

Patch is 28.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/185033.diff

13 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcd5f6ff3ad74..b91bd09053c7e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,14 +266,37 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.multi_reduction_unrolling",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector multi_reduction operations should be unrolled.
+ Indicates that vector multi_reduction operations with more than one
+ reduction dimension should be unrolled in a rank-reducing way.
+
+ This populates the patterns from
+ `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ * `UnrollMultiReductionInnerReduction` (innerreduction)
+ * `UnrollMultiReductionInnerParallel` (innerparallel)
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
+
+def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_lowering",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector multi_reduction operations should be lowered.
1-D multi_reductions are converted directly to vector.reduction.
2-D multi_reductions are unrolled into either a sequence of
vector.reduction ops (innerreduction) or element-wise arith ops
(innerparallel).
This populates the patterns from
- `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+ `populateVectorMultiReductionLoweringPatterns`, i.e.:
* `OneDimMultiReductionToReduction`
* `TwoDimMultiReductionToReduction` (innerreduction)
* `TwoDimMultiReductionToElementWise` (innerparallel)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..f9eeaf4136370 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionReorderPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
@@ -77,7 +80,23 @@ void populateVectorMultiReductionReorderPatterns(
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionFlatteningPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit = 2);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollMultiReductionInnerReduction]
+/// Extracts vectors along the outer (parallel) dimension, reduces each with
+/// a lower-rank multi_reduction, and inserts the results into one vector.
+///
+/// [UnrollMultiReductionInnerParallel]
+/// Extracts vectors along the outer (reduction) dimension and chains
+/// lower-rank multi_reduction operations through the accumulator.
+void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
@@ -96,9 +115,12 @@ void populateVectorMultiReductionFlatteningPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
-void populateVectorMultiReductionUnrollingPatterns(
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
+void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 2);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 312bd28ad48cf..55953aa22dd88 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -153,6 +153,14 @@ void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 76599822fbfe4..3f2b399e55343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -443,6 +443,196 @@ struct TwoDimMultiReductionToReduction
}
};
+/// Unrolls outermost dimension for vector.multi_reduction.
+/// Matches when the outermost dimension is a reduction dimension,
+/// but not the only one.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...xf32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction <add>, %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction <add>, %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallel
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ if (!multiReductionOp.isReducedDim(0))
+ return failure();
+
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+ if (reductionDims.size() <= 1)
+ return failure();
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t outerDimSize = srcShape.front();
+
+ Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+
+ SmallVector<Value> vectors(outerDimSize);
+ for (int64_t i = 0; i < outerDimSize; ++i)
+ vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+
+ SmallVector<Value> masks(outerDimSize);
+ if (mask)
+ for (int64_t i = 0; i < outerDimSize; ++i)
+ masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> reductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
+ Value result = multiReductionOp.getAcc();
+ for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
+ auto reductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, innerVector, result, reductionMask,
+ multiReductionOp.getKind());
+
+ if (innerMask) {
+ Operation *maskOp =
+ vector::maskOperation(rewriter, reductionOp, innerMask);
+ result = maskOp->getResult(0);
+ } else {
+ result = reductionOp.getResult();
+ }
+ }
+
+ return result;
+ }
+};
+
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+///
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add>, %src, %acc [1, 3]
+/// : vector<AxBxCxDxf32> to vector<AxCxf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+/// : vector<BxCxDxf32> to vector<Cxf32>
+/// %res0 = vector.insert %red0, %result[0]
+/// : vector<Cxf32> into vector<AxCxf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerReduction
+ : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
+ auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+ if (srcRank < 2)
+ return rewriter.notifyMatchFailure(multiReductionOp,
+ "expected source rank >= 2.");
+
+ if (!multiReductionOp.isReducedDim(srcRank - 1))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected innermost dimension to be a reduction dimension.");
+
+ if (multiReductionOp.isReducedDim(0))
+ return rewriter.notifyMatchFailure(
+ multiReductionOp,
+ "expected outermost dimension to be a parallel dimension.");
+
+ Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+ if (!elementType.isIntOrIndexOrFloat())
+ return rewriter.notifyMatchFailure(
+ multiReductionOp, "expected integer or float element type.");
+
+ Location loc = multiReductionOp.getLoc();
+ Value source = multiReductionOp.getSource();
+ Value acc = multiReductionOp.getAcc();
+
+ ArrayRef<int64_t> srcShape =
+ multiReductionOp.getSourceVectorType().getShape();
+ int64_t numSlices = srcShape.front();
+
+ Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+ SmallVector<Value> srcSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+ SmallVector<Value> accSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+ SmallVector<Value> maskSlices;
+ for (int64_t i = 0; i < numSlices; ++i)
+ if (mask)
+ maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+ else
+ maskSlices.push_back(nullptr);
+
+ // Compute new reduction mask by dropping the first element (dimension 0).
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift
+ // down by 1.
+ SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+ ArrayRef<bool> newReductionMask =
+ ArrayRef<bool>(fullReductionMask).drop_front();
+
+ SmallVector<Value> reductionResults;
+ for (auto [srcSlice, accSlice, maskSlice] :
+ llvm::zip(srcSlices, accSlices, maskSlices)) {
+ Operation *newReductionOp = vector::MultiDimReductionOp::create(
+ rewriter, loc, srcSlice, accSlice, newReductionMask,
+ multiReductionOp.getKind());
+
+ if (maskSlice)
+ newReductionOp =
+ mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+ reductionResults.push_back(newReductionOp->getResult(0));
+ }
+
+ Value result = arith::ConstantOp::create(
+ rewriter, loc, multiReductionOp.getDestType(),
+ rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+ for (int64_t i = 0; i < numSlices; ++i)
+ result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+ result, i);
+
+ return result;
+ }
+};
+
/// Converts 1D vector.multi_reduction directly to vector.reduction.
///
/// Example:
@@ -496,19 +686,13 @@ struct LowerVectorMultiReductionPass
RewritePatternSet patterns(context);
mlir::vector::populateVectorMultiReductionReorderPatterns(
patterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(patterns))))
- signalPassFailure();
-
- RewritePatternSet flatteningPatterns(context);
mlir::vector::populateVectorMultiReductionFlatteningPatterns(
- flatteningPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
- signalPassFailure();
-
- RewritePatternSet unrollingPatterns(context);
+ patterns, this->loweringStrategy);
mlir::vector::populateVectorMultiReductionUnrollingPatterns(
- unrollingPatterns, this->loweringStrategy);
- if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+ patterns, this->loweringStrategy);
+ mlir::vector::populateVectorMultiReductionLoweringPatterns(
+ patterns, this->loweringStrategy);
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();
}
@@ -532,7 +716,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
}
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
@@ -544,6 +728,17 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
benefit);
}
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+ RewritePatternSet &patterns, VectorMultiReductionLowering options,
+ PatternBenefit benefit) {
+ if (options == VectorMultiReductionLowering::InnerReduction)
+ patterns.add<UnrollMultiReductionInnerReduction>(patterns.getContext(),
+ benefit);
+ else
+ patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+ benefit);
+}
+
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
vector::VectorMultiReductionLowering option) {
return std::make_unique<LowerVectorMultiReductionPass>(option);
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index bf7eba6e50174..b4cb4b75a5f4d 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -33,6 +33,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4dc11c26e83f1..b8729d06c9cfc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -42,6 +42,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.any_op
transform.apply_patterns to %f {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index 447416ccba637..a7ec4b9206c33 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}
@@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+ transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.op<"func.func">
transform.yield
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
new file mode 100644
index 0000000000000..4f67d5678d6d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel
+// CHECK-SAME: %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
+ // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vect...
[truncated]
Closing due to change of direction.
vector.multi_reduction's unrolling patterns are currently a misnomer. They don't unroll, but instead lower multi_reduction to either arith operations or vector.reduction. This PR:
* Renames multi_reduction_unrolling to multi_reduction_lowering.
* Adds a new multi_reduction_unrolling where actual unrolling happens.

The pattern benefits of the vector.multi_reduction expand, flattening, unrolling and lowering patterns are set so that existing behaviour is preserved. The benefit of adding unrolling is that we can now unroll without going through expand or flattening. This gives different backends finer control over which patterns to apply, for example to better handle differences between the SPIR-V and LLVM lowerings.
This will break downstream projects. All you need to do is apply multi_reduction_lowering where multi_reduction_unrolling was applied before.
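To make the role of the pattern benefits more concrete, here is a toy C++ model of benefit-ordered pattern selection. It is a sketch only and not the MLIR rewrite driver: the `ToyPattern` struct, the `pickPattern` and `toyPatterns` helpers, and the rank-based match predicates are hypothetical simplifications. The only values taken from the patch are the default benefits (2 for flattening and lowering, 1 for unrolling).

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified pattern: a name, a benefit, and a match
// predicate on the source rank only (real patterns inspect the whole op).
struct ToyPattern {
  std::string name;
  int benefit;
  bool (*matches)(int rank);
};

// Returns the name of the highest-benefit pattern that matches,
// or an empty string if none does.
std::string pickPattern(std::vector<ToyPattern> patterns, int rank) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns)
    if (p.matches(rank))
      return p.name;
  return "";
}

// Builds a toy pattern set; the match conditions are assumptions.
std::vector<ToyPattern> toyPatterns(bool unrollingOnly) {
  std::vector<ToyPattern> set = {
      {"unrolling", 1, [](int rank) { return rank >= 2; }}};
  if (!unrollingOnly) {
    set.push_back({"flattening", 2, [](int rank) { return rank > 2; }});
    set.push_back({"lowering", 2, [](int rank) { return rank <= 2; }});
  }
  return set;
}
```

Under these assumptions, a set containing all three patterns still picks flattening for rank 3 and lowering for rank 2, as before the change, while a set containing only unrolling picks unrolling directly.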