Skip to content

Commit

Permalink
[StructuredOps][Linalg] Add a primitive pattern to rewrite the linalg…
Browse files Browse the repository at this point in the history
….generic form of matmul to vector form.

This CL uses the newly expanded matcher support to easily detect when a linalg.generic has a multiply-accumulate body. A linalg.generic with such a body is rewritten as a vector contraction.
This CL additionally limits the rewrite to the case of matrix multiplication on contiguous and statically shaped memrefs for now.

Before expanding further, we should harden the infrastructure for expressing custom ops with the structured ops abstraction.

PiperOrigin-RevId: 284566659
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Dec 9, 2019
1 parent 70aeb45 commit 91c0074
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 6 deletions.
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
Expand Up @@ -484,9 +484,10 @@ def GenericOp : GenericOpBase<"generic"> {
The external library is assumed to be dynamically linked and no strong
compile-time guarantees are provided. In the absence of such a library
call, linalg.generic will always lower to loops.
- iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
the list represents and iterator of one of the following types:
parallel, reduction, window
- iterator_types: an ArrayAttr specifying the type of the enclosing loops.
Each element of the list represents an iterator of one of the following
types:
parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.

Expand Down
Expand Up @@ -30,6 +30,8 @@ def HasNoLinalgTransformMarker : CPred<[{
}]>;

// Predicate: the op carries the internal Linalg transform marker attribute
// (a StringAttr) and its value is exactly `str`. The attribute is looked up
// twice because CPred bodies are single C++ expressions.
class HasLinalgTransformMarker<string str> : CPred<[{
  $0.getAttrOfType<StringAttr>(
    LinalgTransforms::kLinalgTransformMarker) &&
  $0.getAttrOfType<StringAttr>(
    LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;

Expand Down Expand Up @@ -77,4 +79,11 @@ class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
" return matchFailure();">;

//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
// Wraps linalg::vectorizeGenericOp in a DRR NativeCodeCall that fails the
// match when vectorization does not apply.
// NOTE(review): the OpType template parameter is currently unused in the
// generated code; it is kept for signature consistency with the other
// LinalgOpTo* helpers above.
class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
  "if (failed(vectorizeGenericOp($_builder, $0))) " #
  " return matchFailure();">;

#endif // LINALG_TRANSFORMS
Expand Up @@ -87,6 +87,9 @@ LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);
template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);

// Rewrite a linalg.generic into a suitable vector.contract op.
LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);

} // namespace linalg
} // namespace mlir

Expand Down
11 changes: 10 additions & 1 deletion mlir/include/mlir/Dialect/VectorOps/VectorOps.td
Expand Up @@ -127,7 +127,16 @@ def Vector_ContractionOp :
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
"Value *acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
let extraClassDeclaration = [{
static constexpr StringLiteral getIndexingMapsAttrName() {
return "indexing_maps";
}
static constexpr StringLiteral getIteratorTypesAttrName() {
return "iterator_types";
}
VectorType getLhsType() {
return lhs()->getType().cast<VectorType>();
}
Expand All @@ -148,7 +157,7 @@ def Vector_ContractionOp :
VectorType getResultType() {
return getResult()->getType().cast<VectorType>();
}
SmallVector<StringRef, 2> getTraitAttrNames();
ArrayRef<StringRef> getTraitAttrNames();
SmallVector<AffineMap, 4> getIndexingMaps();
static StringRef getReductionIteratorTypeName() {
return "reduction";
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/CMakeLists.txt
Expand Up @@ -25,4 +25,5 @@ add_dependencies(MLIRLinalg
MLIRLinalgTransformPatternsIncGen
MLIRStandardOps
MLIRStandardToLLVM
MLIRVectorOps
)
87 changes: 87 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
Expand Up @@ -23,12 +23,22 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/VectorOps/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

using llvm::dbgs;

// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
Expand Down Expand Up @@ -106,3 +116,80 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
}
return false;
}

// Returns true iff the single block of `op`'s region consists of exactly a
// mulf of the first two block arguments, an addf with the third block
// argument, and a linalg.yield of that sum (in any operand order).
static bool hasMultiplyAddBody(linalg::GenericOp op) {
  auto &region = op.region();
  if (region.empty() || region.getBlocks().size() != 1)
    return false;
  auto &block = region.front();
  auto &bodyOps = block.getOperations();
  // Exactly: mulf, addf, linalg.yield.
  if (bodyOps.size() != 3)
    return false;

  using mlir::matchers::m_Val;
  auto argA = m_Val(block.getArgument(0));
  auto argB = m_Val(block.getArgument(1));
  auto argC = m_Val(block.getArgument(2));
  // TODO(ntv) Update this detection once we have matcher support for
  // specifying that any permutation of operands matches.
  Operation *terminator = &bodyOps.back();
  return m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(argA, argB), argC))
             .match(terminator) ||
         m_Op<YieldOp>(m_Op<AddFOp>(argC, m_Op<MulFOp>(argA, argB)))
             .match(terminator) ||
         m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(argB, argA), argC))
             .match(terminator) ||
         m_Op<YieldOp>(m_Op<AddFOp>(argC, m_Op<MulFOp>(argB, argA)))
             .match(terminator);
}

// TODO(ntv) should be Tablegen'd from a single source that generates the op
// itself.
// Returns true iff `genericOp` is structurally a matmul: two input views and
// one output view, the canonical indexing maps (m,k), (k,n) -> (m,n), and a
// multiply-accumulate body.
// TODO(ntv) should be Tablegen'd from a single source that generates the op
// itself.
static bool isMatmul(linalg::GenericOp genericOp) {
  if (genericOp.getNumInputs() != 2 || genericOp.getNumOutputs() != 1)
    return false;
  MLIRContext *context = genericOp.getContext();
  AffineExpr m = getAffineDimExpr(0, context);
  AffineExpr n = getAffineDimExpr(1, context);
  AffineExpr k = getAffineDimExpr(2, context);
  // Expected maps for C(m, n) += A(m, k) * B(k, n).
  Attribute mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
  Attribute mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
  Attribute mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
  auto expectedMaps = ArrayAttr::get({mapA, mapB, mapC}, context);
  return genericOp.indexing_maps() == expectedMaps &&
         hasMultiplyAddBody(genericOp);
}

// Rewrites a linalg.generic that has been recognized as a matmul on
// contiguous, statically shaped memrefs into a single vector.contract:
// each operand memref is reinterpreted as a memref of one big vector
// (vector.type_cast), loaded whole, contracted, and stored back.
// Returns failure() (no IR change) when the op does not qualify.
LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
                                               Operation *op) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
                       "]: Rewrite linalg op as vector.contract: "
                    << *op << ":\n");

  // TODO(ntv): This is in fact much more general than just vectorization for
  // matmul ops.
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp || !isMatmul(genericOp))
    return failure();

  // TODO(ntv): non-identity layout.
  // Only static shapes with the default (identity) layout are accepted:
  // vector.type_cast below assumes a contiguous buffer.
  auto isStaticMemRefWithIdentityLayout = [](Value *v) {
    auto m = v->getType().dyn_cast<MemRefType>();
    if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
      return false;
    return true;
  };
  if (!llvm::all_of(genericOp.getInputsAndOutputs(),
                    isStaticMemRefWithIdentityLayout))
    return failure();

  // Build the replacement IR at the original op's location via EDSC sugar;
  // ScopedContext routes all builders below through `rewriter`.
  edsc::ScopedContext scope(rewriter, op->getLoc());
  using edsc::intrinsics::std_load;
  using edsc::intrinsics::std_store;
  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
  // memref<AxBxf32> -> memref<vector<AxBxf32>>, then load the whole vector.
  auto vA = std_load(vector_type_cast(genericOp.getInput(0)));
  auto vB = std_load(vector_type_cast(genericOp.getInput(1)));
  auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0));
  auto vC = std_load(vectorMemRefC);
  // The contraction reuses the generic op's indexing maps and iterator types
  // verbatim (already validated by isMatmul).
  auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
                              genericOp.iterator_types());
  std_store(vRes, vectorMemRefC);
  return success();
}
16 changes: 14 additions & 2 deletions mlir/lib/Dialect/VectorOps/VectorOps.cpp
Expand Up @@ -51,6 +51,16 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
// ContractionOp
//===----------------------------------------------------------------------===//

// Custom builder: constructs a contraction from lhs/rhs/acc plus its two
// trait attributes, inferring the result type from the accumulator (a
// contraction's result always has the accumulator's type).
void vector::ContractionOp::build(Builder *builder, OperationState &result,
                                  Value *lhs, Value *rhs, Value *acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc->getType());
  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
}

static ParseResult parseContractionOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType lhsInfo;
Expand Down Expand Up @@ -235,8 +245,10 @@ static LogicalResult verify(ContractionOp op) {
return success();
}

SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
// Returns the names of the attributes that carry the contraction trait
// (indexing maps and iterator types). Backed by constexpr static storage so
// the returned ArrayRef never dangles.
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
  static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
                                         getIteratorTypesAttrName()};
  return ArrayRef<StringRef>(names);
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Linalg/transform-patterns.mlir
Expand Up @@ -2,6 +2,9 @@

// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0)
// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
// CHECK-DAG: #[[mk:.*]] = (d0, d1, d2) -> (d0, d2)
// CHECK-DAG: #[[kn:.*]] = (d0, d1, d2) -> (d2, d1)
// CHECK-DAG: #[[mn:.*]] = (d0, d1, d2) -> (d0, d1)

func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
Expand Down Expand Up @@ -158,3 +161,33 @@ func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] {
// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>

// Trait describing a plain matmul C(m,n) += A(m,k) * B(k,n), carrying the
// internal transform marker so the test pattern
// (HasLinalgTransformMarker<"_marked_matmul_">) fires on this op.
#matmul_trait = {
  indexing_maps = [
    (m, n, k) -> (m, k),
    (m, n, k) -> (k, n),
    (m, n, k) -> (m, n)
  ],
  n_views = [2, 1],
  iterator_types = ["parallel", "parallel", "reduction"],
  __internal_linalg_transform__ = "_marked_matmul_"
}
// Statically shaped, identity-layout memrefs: the only case the
// vectorization rewrite currently accepts.
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                         %C: memref<8x32xf32>) {
  linalg.generic #matmul_trait %A, %B, %C {
  ^bb(%a: f32, %b: f32, %c: f32) :
    // Multiply-accumulate body, recognized by hasMultiplyAddBody.
    %d = mulf %a, %b: f32
    %e = addf %c, %d: f32
    linalg.yield %e : f32
  } : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
  return
}

// CHECK-LABEL: func @vectorization_test
// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
Expand Up @@ -80,4 +80,11 @@ def : Pattern<(DotOp:$op $a, $b, $c),
[(LinalgOpToLoops<"DotOp"> $op)],
[(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;

//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
// Rewrite any linalg.generic carrying the "_marked_matmul_" transform marker
// into a vector.contract via vectorizeGenericOp (which re-verifies that the
// op really is a matmul before rewriting).
def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7),
              [(LinalgOpToVectorContraction<"GenericOp"> $op)],
              [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;

#endif // TEST_LINALG_TRANSFORMS_PATTERNS

0 comments on commit 91c0074

Please sign in to comment.