Skip to content

Commit

Permalink
[mlir][Vector] NFC - Add option to hook vector.transpose lowering to …
Browse files Browse the repository at this point in the history
…strategies.

This revision also moves some code around to improve overall structure.

Differential Revision: https://reviews.llvm.org/D112437
  • Loading branch information
nicolasvasilache committed Oct 25, 2021
1 parent 3b1165b commit d054b80
Show file tree
Hide file tree
Showing 12 changed files with 482 additions and 441 deletions.
44 changes: 28 additions & 16 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -15,7 +15,7 @@
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
Expand Down Expand Up @@ -846,6 +846,9 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
: LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
};

//===----------------------------------------------------------------------===//
// Transformation and lowering options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Options to control the application of enabling transformations.
/// Hoisting transformations are always deemed beneficial and must be disabled
/// explicitly.
Expand Down Expand Up @@ -887,10 +890,16 @@ struct LinalgVectorLoweringOptions {
transferLowering = val;
return *this;
}
/// Trigger full / partial vector.transfer splits.
bool transferPartialRewrite = false;
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
transferPartialRewrite = val;
/// Enable lowering of vector.transpose.
bool transposeLowering = false;
LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
transposeLowering = val;
return *this;
}
/// Enable lowering of vector.multi_reduce.
bool multiReductionLowering = false;
LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
multiReductionLowering = val;
return *this;
}
/// Enable lowering of vector.contract.
Expand All @@ -899,10 +908,10 @@ struct LinalgVectorLoweringOptions {
contractionLowering = val;
return *this;
}
/// Enable lowering of vector.multi_reduce.
bool multiReductionLowering = false;
LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
multiReductionLowering = val;
/// Trigger full / partial vector.transfer splits.
bool transferPartialRewrite = false;
LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
transferPartialRewrite = val;
return *this;
}
/// Enable lowering of vector.transfer to scf.
Expand All @@ -911,13 +920,6 @@ struct LinalgVectorLoweringOptions {
transferToSCFConversion = val;
return *this;
}
/// Configure late vector transformations.
vector::VectorTransformsOptions vectorTransformOptions;
LinalgVectorLoweringOptions &
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
vectorTransformOptions = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf
/// conversion.
VectorTransferToSCFOptions vectorTransferToSCFOptions;
Expand All @@ -926,8 +928,18 @@ struct LinalgVectorLoweringOptions {
vectorTransferToSCFOptions = options;
return *this;
}
/// Configure late vector transformations.
vector::VectorTransformsOptions vectorTransformOptions;
LinalgVectorLoweringOptions &
setVectorTransformsOptions(vector::VectorTransformsOptions options) {
vectorTransformOptions = options;
return *this;
}
};

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
/// Trait to check if T provides a `getOperationName` method.
template <typename T, typename... Args>
using has_get_operation_name = decltype(T::getOperationName());
Expand Down
112 changes: 0 additions & 112 deletions mlir/include/mlir/Dialect/Vector/VectorOps.h
Expand Up @@ -40,76 +40,6 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Enum to control the lowering of `vector.contract` operations.
enum class VectorContractLowering {
/// Progressively lower to finer grained `vector.contract` and dot-products.
Dot = 0,
/// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
Matmul = 1,
/// Lower to `vector.outerproduct`.
OuterProduct = 2,
};
/// Enum to control the lowering of `vector.multi_reduction` operations.
enum class VectorMultiReductionLowering {
/// Lower multi_reduction into outer-reduction and inner-parallel ops.
InnerParallel = 0,
/// Lower multi_reduction into outer-parallel and inner-reduction ops.
InnerReduction = 1,
};
/// Enum to control the lowering of `vector.transpose` operations.
enum class VectorTransposeLowering {
/// Lower transpose into element-wise extract and inserts.
EltWise = 0,
/// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
/// intrinsics.
Flat = 1,
};
/// Enum to control the splitting of `vector.transfer` operations into
/// in-bounds and out-of-bounds variants.
enum class VectorTransferSplit {
/// Do not split vector transfer operations.
None = 0,
/// Split using in-bounds + out-of-bounds vector.transfer operations.
VectorTransfer = 1,
/// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
/// operations.
LinalgCopy = 2,
/// Do not split vector transfer operation but instead mark it as "in-bounds".
ForceInBounds = 3
};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
return *this;
}
/// Option to control the lowering of vector.multi_reduction.
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions &
setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
vectorMultiReductionLowering = opt;
return *this;
}
/// Option to control the lowering of vector.transpose.
VectorTransposeLowering vectorTransposeLowering =
VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransposeLowering(VectorTransposeLowering opt) {
vectorTransposeLowering = opt;
return *this;
}
/// Option to control the splitting of vector transfers.
VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
vectorTransferSplit = opt;
return *this;
}
};

/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
Expand Down Expand Up @@ -161,33 +91,6 @@ void populateVectorTransferPermutationMapLoweringPatterns(
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);

/// Collect a set of patterns to convert vector.multi_reduction op into
/// a sequence of vector.reduction ops. The patterns comprise:
/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
/// that all reduction dimensions are either innermost or outermost, by adding
/// the proper vector.transpose operations.
/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
/// back.
/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
/// form, with an **outermost** reduction dimension, unroll the outer dimension
/// to obtain a sequence of 1-D vector ops. This also has an opportunity for
/// tree-reduction (in the future).
/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
/// with an **innermost** reduction dimension, unroll the outer dimension to
/// obtain a sequence of extract + vector.reduction + insert. This can further
/// lower to horizontal reduction ops.
/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
/// reduction (and are thus missing either a parallel or a reduction), we lift
/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
/// the other patterns can kick in, thus fully exiting out of the
/// vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
RewritePatternSet &patterns,
VectorMultiReductionLowering options =
vector::VectorMultiReductionLowering::InnerParallel);

/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
/// chain.
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
Expand All @@ -212,12 +115,6 @@ class CombiningKindAttr
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);

/// Collects patterns to progressively lower vector contraction ops on high-D
/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());

/// Collects patterns to progressively lower vector mask ops into elementary
/// selection and insertion ops.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
Expand All @@ -227,15 +124,6 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
/// ops.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);

/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);

/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

Expand Down

0 comments on commit d054b80

Please sign in to comment.