Skip to content

Commit

Permalink
[mlir][Vector] Add a LowerVectorsOp to VectorTransformOps
Browse files Browse the repository at this point in the history
This op significantly improves transfor dialect usage when using vector abstractions.
It also brings us closer to writing simple end-to-end unit tests that guard against subtle regressions in how patterns combine.

Differential Revision: https://reviews.llvm.org/D138723
  • Loading branch information
nicolasvasilache committed Nov 25, 2022
1 parent c757780 commit 6ff1233
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,31 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Indicates that the vector operations nested under the isolated from above op
`target` should be lowered to finer-grained vector primitives.

At this time, the transform is all or nothing.

This is usally a late step that is run after bufferization as part of the
process of lowering to e.g. LLVM or NVVM.
}];

// TODO: evolve this to proper enums.
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering,
DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers
);
let results = (outs PDL_Operation:$results);

let assemblyFormat = "$target attr-dict";
}

#endif // VECTOR_TRANSFORM_OPS
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ add_mlir_dialect_library(MLIRVectorTransformOps
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRVectorDialect
MLIRVectorToSCF
MLIRX86VectorTransforms
)
125 changes: 125 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,143 @@

#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::transform;

//===----------------------------------------------------------------------===//
// LowerVectorsOp
//===----------------------------------------------------------------------===//

void transform::LowerVectorsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
producesHandle(getResults(), effects);
modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {

SmallVector<Operation *> results;
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
for (Operation *target : payloadOps) {
// This check can't be part of the verifier because payload IR is
// independent from transform IR and may not even exist.
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
return mlir::emitDefiniteFailure(target,
"applies only to isolated-from-above "
"targets because it needs to apply "
"patterns greedily");
}

MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
vector::VectorTransposeLowering vectorTransposeLowering =
llvm::StringSwitch<vector::VectorTransposeLowering>(
getTransposeLowering())
.Case("eltwise", vector::VectorTransposeLowering::EltWise)
.Case("flat_transpose", vector::VectorTransposeLowering::Flat)
.Case("shuffle", vector::VectorTransposeLowering::Shuffle)
.Default(vector::VectorTransposeLowering::EltWise);
vector::VectorMultiReductionLowering vectorMultiReductionLowering =
llvm::StringSwitch<vector::VectorMultiReductionLowering>(
getMultireductionLowering())
.Case("innerreduction",
vector::VectorMultiReductionLowering::InnerReduction)
.Default(vector::VectorMultiReductionLowering::InnerParallel);
vector::VectorContractLowering vectorContractLowering =
llvm::StringSwitch<vector::VectorContractLowering>(
getContractionLowering())
.Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
.Case("dot", vector::VectorContractLowering::Dot)
.Case("outerproduct", vector::VectorContractLowering::OuterProduct)
.Default(vector::VectorContractLowering::OuterProduct);
vector::VectorTransferSplit vectorTransferSplit =
llvm::StringSwitch<vector::VectorTransferSplit>(getSplitTransfers())
.Case("none", vector::VectorTransferSplit::None)
.Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
.Case("vector-transfers",
vector::VectorTransferSplit::VectorTransfer)
.Default(vector::VectorTransferSplit::None);

vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)
.setVectorMultiReductionLowering(vectorMultiReductionLowering)
.setVectorTransposeLowering(vectorTransposeLowering)
.setVectorTransferSplit(vectorTransferSplit);

VectorTransferToSCFOptions vectorTransferToSCFOptions =
VectorTransferToSCFOptions()
.enableFullUnroll(getUnrollVectorTransfers())
.enableLowerPermutationMaps();

int maxTransferRank = 1;

auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
x86vector::avx2::TransposeLoweringOptions()
.lower4x8xf32(getTransposeAvx2Lowering())
.lower8x8xf32(getTransposeAvx2Lowering()));

vector::populateVectorToVectorCanonicalizationPatterns(patterns);

// In the future we may want to more finely select particular stages.
// Stage 1: contraction lowerings.
patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
mlir::vector::ContractionOpToMatmulOpLowering,
mlir::vector::ContractionOpLowering>(vectorTransformOptions,
ctx);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

// Stage 2: multi-reduction lowerings.
vector::populateVectorMultiReductionLoweringPatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);

// Stage 3: Rewrite vector.transfer into full and partial parts.
patterns.add<vector::VectorTransferFullPartialRewriter>(
ctx, vectorTransformOptions);

// Stage 4: Lower vector transfers.
vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);

// Stage 5: Vector to scf patterns.
populateVectorToSCFConversionPatterns(
patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank));

// Stage 6: Lower vector.shape_cast.
vector::populateVectorShapeCastLoweringPatterns(patterns);

// Stage 7: Lower vector.transpose.
vector::populateVectorTransposeLoweringPatterns(patterns,
vectorTransformOptions);
if (getTransposeAvx2Lowering())
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10);

// Apply everything.
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
return DiagnosedSilenceableFailure::definiteFailure();

results.push_back(target);
}

transformResults.set(getResults().cast<OpResult>(), results);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Vector/transform-vector.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: func @matmul_tensors
func.func @matmul_tensors(
%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>)
-> tensor<8x32xf32> {
// CHECK-NOT: linalg
// CHECK: vector.extract {{.*}} : vector<8x4xf32>
// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>)
outs(%arg2: tensor<8x32xf32>)
-> tensor<8x32xf32>
return %0 : tensor<8x32xf32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
%1, %loops:3 = transform.structured.tile %0 [8, 4, 2]
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2
transform.bufferization.one_shot_bufferize %module_op

%func = transform.structured.match ops{["func.func"]} in %module_op
transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"}
}
8 changes: 5 additions & 3 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3363,14 +3363,16 @@ cc_library(
":ArithDialect",
":AsmParser",
":IR",
":VectorDialect",
":VectorTransformOpsIncGen",
":VectorTransforms",
":PDLDialect",
":Parser",
":SideEffectInterfaces",
":TransformDialect",
":TransformUtils",
":VectorDialect",
":VectorToSCF",
":VectorTransformOpsIncGen",
":VectorTransforms",
":X86VectorTransforms",
"//llvm:Support",
],
)
Expand Down

0 comments on commit 6ff1233

Please sign in to comment.