Commit 16e82d8

[mlir] Add primitive transform pattern to rewrite linalg.fill into vector.broadcast form.

Summary:
This diff adds a transformation pattern to rewrite linalg.fill as broadcasting a scalar into a vector.
It uses the same preconditioning as matmul (memory is contiguous).
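
To illustrate the intent (a sketch only; the SSA names are hypothetical and the 8x16 shape is arbitrary), the pattern turns

  linalg.fill(%A, %cst) : memref<8x16xf32>, f32

into

  %vA = vector.type_cast %A : memref<8x16xf32> to memref<vector<8x16xf32>>
  %old = load %vA[] : memref<vector<8x16xf32>>
  %res = vector.broadcast %cst : f32 to vector<8x16xf32>
  store %res, %vA[] : memref<vector<8x16xf32>>

The load of the previous contents is emitted only so the EDSC builder can infer the broadcast's result type; it is otherwise dead.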

Reviewers: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73391
asaadaldien committed Jan 28, 2020
1 parent 60b8842 commit 16e82d8
Showing 3 changed files with 47 additions and 18 deletions.
51 changes: 33 additions & 18 deletions mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -156,8 +158,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
          genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
 }
 
-// TODO(ntv): This is in fact much more general than just vectorization for
-// matmul ops.
+// TODO(ntv, ataei): This is in fact much more general than just vectorization
+// for matmul and fill ops.
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -167,7 +169,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp>(op))
+  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
     return success();
 
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
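
For illustration (hypothetical IR, not part of the commit): the static-shape requirement above is intended to reject cases like a fill on a dynamically shaped buffer, which would be left unvectorized:

  func @fill_dynamic(%A : memref<?x16xf32>, %cst : f32) {
    linalg.fill(%A, %cst) : memref<?x16xf32>, f32
    return
  }
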
@@ -189,28 +191,41 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
 
 SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
                                                       Operation *op) {
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
+  using edsc::intrinsics::std_load;
+  using edsc::intrinsics::std_store;
+  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
+  using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>;
+  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
 
   assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
          "DRR failure case must be a precondition");
 
   auto linalgOp = cast<linalg::LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   edsc::ScopedContext scope(rewriter, op->getLoc());
-  using edsc::intrinsics::std_load;
-  using edsc::intrinsics::std_store;
-  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
-  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
-  auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
-  auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
-  auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
-  auto vC = std_load(vectorMemRefC);
-  auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
-                              linalgOp.iterator_types());
-  std_store(vRes, vectorMemRefC);
 
+  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+    // Vectorize fill as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg.fill as vector.broadcast: "
+                      << *op << ":\n");
+    auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
+    auto dstVec = std_load(dstMemrefVec);
+    auto resVec = vector_broadcast(dstVec, fillOp.value());
+    std_store(resVec, dstMemrefVec);
+  } else {
+    // Vectorize other ops as vector contraction (currently only matmul).
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg op as vector.contract: "
+                      << *op << ":\n");
+    auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
+    auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
+    auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
+    auto vC = std_load(vectorMemRefC);
+    auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
+                                linalgOp.iterator_types());
+    std_store(vRes, vectorMemRefC);
+  }
   return {};
 }

7 changes: 7 additions & 0 deletions mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -205,6 +205,13 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 // CHECK: vector.contract {{.*}} :
 //          vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
 
+func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+  linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, f32
+  return
+}
+// CHECK-LABEL: func @test_vectorize_fill
+// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+
 func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
   %d = mulf %a, %b: f32
   %e = addf %c, %d: f32
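
For reference, tests in this file are driven through mlir-opt and FileCheck; assuming the file's usual RUN header (not shown in this excerpt), the new case would be exercised as:

  // RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s

The __internal_linalg_transform__ = "VECTORIZE" attribute is the marker matched by HasLinalgTransformMarker<"VECTORIZE"> in the DRR patterns below.
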
@@ -105,13 +105,20 @@ def : Pattern<(MatmulOp:$op $_, $_, $_),
                 HasLinalgTransformMarker<"VECTORIZE">,
                 PreconditionVectorizeLinalgOp
               ]>>)]>;
+def : Pattern<(FillOp:$op $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint<And<[
+                HasLinalgTransformMarker<"VECTORIZE">,
+                PreconditionVectorizeLinalgOp
+              ]>>)]>;
 def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
               [(VectorizeLinalgOp)],
               [(Constraint<And<[
                 HasLinalgTransformMarker<"VECTORIZE">,
                 PreconditionVectorizeLinalgOp
               ]>>)]>;
+
 
 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns.
 //===----------------------------------------------------------------------===//
