Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,37 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.multi_reduction_unrolling",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector multi_reduction operations should be unrolled.
Indicates that vector multi_reduction operations with more than one
reduction dimension should be unrolled in a rank-reducing way.

This populates the patterns from
`populateVectorMultiReductionUnrollingPatterns`, i.e.:
* `UnrollMultiReductionInnerReduction` (inner_reduction)
* `UnrollMultiReductionInnerParallel` (inner_parallel)
}];

let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
"vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
);

let assemblyFormat = [{
(`lowering_strategy` `=` $lowering_strategy^)? attr-dict
}];
}


def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
"apply_patterns.vector.multi_reduction_lowering",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Indicates that vector multi_reduction operations should be lowered.
1-D multi_reductions are converted directly to vector.reduction.
2-D multi_reductions are unrolled into either a sequence of
vector.reduction ops (innerreduction) or element-wise arith ops
(innerparallel).

This populates the patterns from
`populateVectorMultiReductionUnrollingPatterns`, i.e.:
`populateVectorMultiReductionLoweringPatterns`, i.e.:
* `OneDimMultiReductionToReduction`
* `TwoDimMultiReductionToReduction` (innerreduction)
* `TwoDimMultiReductionToElementWise` (innerparallel)
Expand Down
28 changes: 25 additions & 3 deletions mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
///
/// The benefit is set to be higher than the unrolling patterns. Otherwise
/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionReorderPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
PatternBenefit benefit = 2);

/// Populate the pattern set with the following patterns:
///
Expand All @@ -77,7 +80,23 @@ void populateVectorMultiReductionReorderPatterns(
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
///
/// The benefit is set to be higher than the unrolling patterns. Otherwise
/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionFlatteningPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 2);

/// Populate the pattern set with the following patterns:
///
/// [UnrollMultiReductionInnerReduction]
/// Extracts vectors along outer dimension
/// and chains multiple multi_reduction operations.
///
/// [UnrollMultiReductionInnerParallel]
/// Extracts vectors along outer dimension, performs multi_reduction operations
/// and inserts them back to a vector.multi_reduction with a lower rank.
void populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);

Expand All @@ -96,9 +115,12 @@ void populateVectorMultiReductionFlatteningPatterns(
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
void populateVectorMultiReductionUnrollingPatterns(
///
/// The benefit is set to be higher than the unrolling patterns. Otherwise
/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit = 1);
PatternBenefit benefit = 2);

/// Populate the pattern set with the following patterns:
///
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}

void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
vector::populateVectorMultiReductionLoweringPatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}

void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
Expand Down
219 changes: 207 additions & 12 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,196 @@ struct TwoDimMultiReductionToReduction
}
};

/// Unrolls outermost dimension for vector.multi_reduction.
/// Matches when the outermost dimension is not the only
/// reduction dimension.
///
/// ```mlir
/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
/// vector<NxMx...xf32> to vector<Ix...xf32>
/// ```
///
/// ```mlir
/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
/// ...
/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
/// vector<NxMx...xf32>
///
/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
/// vector<Mx...xf32> to vector<Ix...xf32>
/// ...
/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
/// vector<Mx...xf32> to vector<Ix...xf32>
/// ```
struct UnrollMultiReductionInnerParallel
: public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

FailureOr<Value>
matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
if (!multiReductionOp.isReducedDim(0))
return failure();

ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
if (reductionDims.size() <= 1)
return failure();

Location loc = multiReductionOp.getLoc();
Value source = multiReductionOp.getSource();

ArrayRef<int64_t> srcShape =
multiReductionOp.getSourceVectorType().getShape();
int64_t outerDimSize = srcShape.front();

Value mask = maskingOp ? maskingOp.getMask() : nullptr;

SmallVector<Value> vectors(outerDimSize);
for (int64_t i = 0; i < outerDimSize; ++i)
vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);

SmallVector<Value> masks(outerDimSize);
if (mask)
for (int64_t i = 0; i < outerDimSize; ++i)
masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);

SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
ArrayRef<bool> reductionMask =
ArrayRef<bool>(fullReductionMask).drop_front();
Value result = multiReductionOp.getAcc();
for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
auto reductionOp = vector::MultiDimReductionOp::create(
rewriter, loc, innerVector, result, reductionMask,
multiReductionOp.getKind());

if (innerMask) {
Operation *maskOp =
vector::maskOperation(rewriter, reductionOp, innerMask);
result = maskOp->getResult(0);
} else {
result = reductionOp.getResult();
}
}

return result;
}
};

/// Unrolls vector.multi_reduction along the outermost parallel dimension
/// when the innermost dimension is a reduction dimension.
///
/// This pattern matches operations where:
/// - The innermost dimension is a reduction dimension
/// - The outermost dimension is a parallel dimension
///
/// The transformation extracts slices along the outermost parallel dimension,
/// creates smaller multi_reductions, and assembles the results:
///
/// ```mlir
/// %res = vector.multi_reduction <add> %src, %acc [1, 3]
/// : vector<AxBxCxDxf32> to vector<AxCxf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// %result = arith.constant dense<0.0> : vector<AxCxf32>
/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
/// : vector<BxCxDxf32> to vector<Cxf32>
/// %res0 = vector.insert %red0, %result[0]
/// : vector<Cxf32> into vector<AxCxf32>
/// // ... repeat for indices 1 to A-1
/// ```
struct UnrollMultiReductionInnerReduction
: public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

FailureOr<Value>
matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto srcRank = multiReductionOp.getSourceVectorType().getRank();

if (srcRank < 2)
return rewriter.notifyMatchFailure(multiReductionOp,
"expected source rank >= 2.");

if (!multiReductionOp.isReducedDim(srcRank - 1))
return rewriter.notifyMatchFailure(
multiReductionOp,
"expected innermost dimension to be a reduction dimension.");

if (multiReductionOp.isReducedDim(0))
return rewriter.notifyMatchFailure(
multiReductionOp,
"expected outermost dimension to be a parallel dimension.");

Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
if (!elementType.isIntOrIndexOrFloat())
return rewriter.notifyMatchFailure(
multiReductionOp, "expected integer or float element type.");

Location loc = multiReductionOp.getLoc();
Value source = multiReductionOp.getSource();
Value acc = multiReductionOp.getAcc();

ArrayRef<int64_t> srcShape =
multiReductionOp.getSourceVectorType().getShape();
int64_t numSlices = srcShape.front();

Value mask = maskingOp ? maskingOp.getMask() : Value();

SmallVector<Value> srcSlices;
for (int64_t i = 0; i < numSlices; ++i)
srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));

SmallVector<Value> accSlices;
for (int64_t i = 0; i < numSlices; ++i)
accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));

SmallVector<Value> maskSlices;
for (int64_t i = 0; i < numSlices; ++i)
if (mask)
maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
else
maskSlices.push_back(nullptr);

// Compute new reduction mask by dropping the first element (dimension 0).
// Since dimension 0 is parallel (not reduced), all reduction indices shift
// down by 1.
SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
ArrayRef<bool> newReductionMask =
ArrayRef<bool>(fullReductionMask).drop_front();

SmallVector<Value> reductionResults;
for (auto [srcSlice, accSlice, maskSlice] :
llvm::zip(srcSlices, accSlices, maskSlices)) {
Operation *newReductionOp = vector::MultiDimReductionOp::create(
rewriter, loc, srcSlice, accSlice, newReductionMask,
multiReductionOp.getKind());

if (maskSlice)
newReductionOp =
mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);

reductionResults.push_back(newReductionOp->getResult(0));
}

Value result = arith::ConstantOp::create(
rewriter, loc, multiReductionOp.getDestType(),
rewriter.getZeroAttr(multiReductionOp.getDestType()));

for (int64_t i = 0; i < numSlices; ++i)
result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
result, i);

return result;
}
};

/// Converts 1D vector.multi_reduction directly to vector.reduction.
///
/// Example:
Expand Down Expand Up @@ -496,19 +686,13 @@ struct LowerVectorMultiReductionPass
RewritePatternSet patterns(context);
mlir::vector::populateVectorMultiReductionReorderPatterns(
patterns, this->loweringStrategy);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();

RewritePatternSet flatteningPatterns(context);
mlir::vector::populateVectorMultiReductionFlatteningPatterns(
flatteningPatterns, this->loweringStrategy);
if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
signalPassFailure();

RewritePatternSet unrollingPatterns(context);
patterns, this->loweringStrategy);
mlir::vector::populateVectorMultiReductionUnrollingPatterns(
unrollingPatterns, this->loweringStrategy);
if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
patterns, this->loweringStrategy);
mlir::vector::populateVectorMultiReductionLoweringPatterns(
patterns, this->loweringStrategy);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();
}

Expand All @@ -532,7 +716,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
Expand All @@ -544,6 +728,17 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
benefit);
}

void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
RewritePatternSet &patterns, VectorMultiReductionLowering options,
PatternBenefit benefit) {
if (options == VectorMultiReductionLowering ::InnerReduction)
patterns.add<UnrollMultiReductionInnerReduction>(patterns.getContext(),
benefit);
else
patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
vector::VectorMultiReductionLowering option) {
return std::make_unique<LowerVectorMultiReductionPass>(option);
Expand Down
1 change: 1 addition & 0 deletions mlir/test/Dialect/LLVM/transform-e2e.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
Expand Down
1 change: 1 addition & 0 deletions mlir/test/Dialect/Vector/transform-vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.any_op

transform.apply_patterns to %f {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
} : !transform.op<"func.func">
transform.yield
}

transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
} : !transform.op<"func.func">
transform.yield
}
Expand Down
Loading