Skip to content

Commit

Permalink
[mlir][Vector] Add vector contraction to outerproduct lowering
Browse files Browse the repository at this point in the history
This revision adds the additional lowering and exposes the patterns at a finer granularity for better programmatic reuse. The unit test makes use of the finer grained pattern for simpler checks.

As the ContractionOpLowering is exposed programmatically, cleanup opportunities appear and static class methods are turned into free functions with static visibility.

Differential Revision: https://reviews.llvm.org/D80375
  • Loading branch information
Nicolas Vasilache committed May 26, 2020
1 parent a3b5ccd commit 9578a54
Show file tree
Hide file tree
Showing 7 changed files with 598 additions and 356 deletions.
21 changes: 14 additions & 7 deletions mlir/include/mlir/Dialect/Vector/VectorOps.h
Expand Up @@ -25,13 +25,6 @@ class MLIRContext;
class OwningRewritePatternList;
namespace vector {

/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Let vector.contract lower to vector.matrix_multiply and LLVM matrix
/// intrinsics.
bool lowerToLLVMMatrixIntrinsics = false;
};

/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
Expand All @@ -51,6 +44,20 @@ void populateVectorToVectorTransformationPatterns(
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
MLIRContext *context);

/// Enum to control the lowering of `vector.contract` operations.
enum class VectorContractLowering {
/// Progressively lower to finer grained `vector.contract` and `vector.fma`.
FMA = 0,
/// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
Matmul = 1,
/// Lower to `vector.outerproduct`.
OuterProduct = 2,
};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
};

/// Collect a set of transformation patterns that are related to contracting
/// or expanding vector operations:
/// ContractionOpLowering,
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/VectorOps.td
Expand Up @@ -686,6 +686,11 @@ def Vector_OuterProductOp :
return %3: vector<4x8xf32>
```
}];
let builders = [
// Build an op without mask, use the type of `acc` as the return type.
OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"Value acc">];
let extraClassDeclaration = [{
VectorType getOperandVectorTypeLHS() {
return lhs().getType().cast<VectorType>();
Expand Down
117 changes: 110 additions & 7 deletions mlir/include/mlir/Dialect/Vector/VectorTransforms.h
Expand Up @@ -9,6 +9,7 @@
#ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_
#define DIALECT_VECTOR_VECTORTRANSFORMS_H_

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
Expand All @@ -22,13 +23,6 @@ void populateVectorToVectorConversionPatterns(
ArrayRef<int64_t> coarseVectorShape = {},
ArrayRef<int64_t> fineVectorShape = {});

////////////////////////////////////////////////////////////////////////////////
// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
// patterns. As such, they must not call into `rewriter.erase/replace` APIs and
// it is the responsibility of the enclosing PatternRewriter to erase on
// success.
////////////////////////////////////////////////////////////////////////////////

namespace vector {

// Entry point for unrolling declarative pattern rewrites.
Expand Down Expand Up @@ -69,6 +63,115 @@ unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);

} // namespace vector

//===----------------------------------------------------------------------===//
// Finer-grained patterns exposed for more control over individual lowerings.
//===----------------------------------------------------------------------===//

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %d = vector.shape_cast %%flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformsOptions,
MLIRContext *context)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions) {}

LogicalResult match(vector::ContractionOp op) const override;
void rewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformsOptions,
MLIRContext *context)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions) {}

LogicalResult match(vector::ContractionOp op) const override;
void rewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
};

/// Progressive lowering of ContractionOp.
///
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a fma/reduction op.
///
/// This only kicks in when either VectorTransformsOptions is set to FMA or when
/// other contraction patterns fail.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
MLIRContext *context)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions) {}

LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;

private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const;
// Lower one reduction dimension.
Value lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const;
};

} // namespace mlir

#endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Vector/VectorOps.cpp
Expand Up @@ -957,6 +957,13 @@ static LogicalResult verify(InsertStridedSliceOp op) {
// OuterProductOp
//===----------------------------------------------------------------------===//

/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
Value lhs, Value rhs, Value acc) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
}

static void print(OpAsmPrinter &p, OuterProductOp op) {
p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
if (!op.acc().empty())
Expand Down

0 comments on commit 9578a54

Please sign in to comment.