Skip to content

Commit

Permalink
Update VectorContractionOp to take iterator types and index mapping a…
Browse files Browse the repository at this point in the history
…ttributes compatible with linalg ops.

PiperOrigin-RevId: 282412311
  • Loading branch information
andydavis1 authored and tensorflower-gardener committed Nov 25, 2019
1 parent d60133f commit 8fc44a4
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 76 deletions.
86 changes: 53 additions & 33 deletions mlir/include/mlir/Dialect/VectorOps/VectorOps.td
Expand Up @@ -46,12 +46,11 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :

// TODO(andydavis, ntv) Add an attribute to specify a different algebra
// with operators other than the current set: {*, +}.
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
// and free dimension pairs.
def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
Variadic<TupleOf<[Index]>>:$masks)>,
Variadic<TupleOf<[Index]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
Results<(outs AnyVector)> {
let summary = "vector contraction operation";
let description = [{
Expand All @@ -64,39 +63,59 @@ def Vector_ContractionOp :
Optional vector mask arguments specify the dynamic dimension sizes of
valid data within the lhs/rhs vector arguments.

Dimensions for the arguments and result type fall into three categories:
*) Contracting: contracting dimensions are present in the lhs and rhs
An iterator type attribute list must be specified, where each element of
the list represents an iterator with one of the following types:

*) "reduction": reduction dimensions are present in the lhs and rhs
arguments but not in the output (or optional accumulator
argument). These are the dimensions along which the vector
contraction op computes the sum of products, and contracting
dimension pair dimension sizes must match between lhs/rhs.
*) Batch: batch dimensions are non-contracting dimensions and so are
present in the output and in the accumulator argument. The lhs
and rhs co-iterate along the batch dimension and so dimension
sizes must match across all arguments and result.
*) Free: free dimensions are non-contraction, non-batch dimensions and
are present in the output and accumulator argument. The lhs and
rhs free dimensions are unrelated to each other and do not
co-iterate.

Contracting and batch dimensions are specified as dimension pairs
of logical dimension numbers: the first in the pair represents the lhs
logical dimension number and the second in the pair represents the
associated rhs logical dimension number. A dimension pair binds together
logical dimension numbers from the lhs/rhs which co-iterate together, either
as contracting or batch dimensions.
contraction op computes the sum of products, and
contracting dimension pair dimension sizes must match
between lhs/rhs.
*) "parallel": Batch dimensions are iterator type "parallel", and
are non-contracting dimensions present in the lhs, rhs and
output. The lhs/rhs co-iterate along the batch dimensions,
which should be expressed in their indexing maps.

Free dimensions are iterator type "parallel", and are
non-contraction, non-batch dimensions accessed by either the
lhs or rhs (but not both). The lhs and rhs free dimensions
are unrelated to each other and do not co-iterate, which
should be expressed in their indexing maps.

An indexing map attribute list must be specified with an entry for lhs, rhs
and acc arguments. An indexing map attribute specifies a mapping from each
iterator in the iterator type list, to each dimension of an N-D vector.

Examples:

// 2D vector contraction with one contracting dimension (matmul).
%3 = vector.contract %0, %1, %2
{ contracting_dim_map = [[1, 0]] }
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
#contraction_accesses = [
(i, j, k) -> (i, k),
(i, j, k) -> (k, j),
(i, j, k) -> (i, j)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [parallel, parallel, reduction]
}

%3 = vector.contract #contraction_trait %0, %1, %2
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>

// 4D to 3D vector contraction with two contracting dimensions and
// one batch dimension.
%4 = vector.contract %0, %1, %2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [parallel, parallel, parallel, reduction, reduction]
}

%4 = vector.contract #contraction_trait %0, %1, %2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>

// 4D vector contraction with two contracting dimensions and optional
Expand All @@ -106,8 +125,7 @@ def Vector_ContractionOp :
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
: tuple<index, index, index, index>

%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask
{ contracting_dim_map = [[0, 2], [2, 1]] }
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
}];
let extraClassDeclaration = [{
Expand All @@ -131,11 +149,13 @@ def Vector_ContractionOp :
// Returns the op's result as a VectorType (the op declares a single
// AnyVector result, so this cast is safe once the op verifies).
VectorType getResultType() {
return getResult()->getType().cast<VectorType>();
}
static StringRef getContractingDimMapAttrName() {
return "contracting_dim_map";
SmallVector<StringRef, 2> getTraitAttrNames();
SmallVector<AffineMap, 4> getIndexingMaps();
// Canonical string used in the 'iterator_types' attribute for a
// contracting (sum-of-products) dimension.
static StringRef getReductionIteratorTypeName() {
return "reduction";
}
static StringRef getBatchDimMapAttrName() {
return "batch_dim_map";
static StringRef getParallelIteratorTypeName() {
return "parallel";
}
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
Expand Down
104 changes: 87 additions & 17 deletions mlir/lib/Dialect/VectorOps/VectorOps.cpp
Expand Up @@ -27,6 +27,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringSet.h"

using namespace mlir;
using namespace mlir::vector;
Expand Down Expand Up @@ -56,7 +57,10 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
SmallVector<Type, 2> types;
Type resultVectorType;
auto loc = parser.getCurrentLocation();
if (parser.parseOperand(lhsInfo) || parser.parseComma() ||
DictionaryAttr dictAttr;
// TODO(andydavis, ntv) Unify linalg op attribute parsing.
if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
parser.parseOperand(lhsInfo) || parser.parseComma() ||
parser.parseOperand(rhsInfo) || parser.parseComma() ||
parser.parseOperand(accInfo) ||
parser.parseTrailingOperandList(masksInfo) ||
Expand All @@ -68,7 +72,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types))
return failure();

result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
if (masksInfo.empty())
return success();
if (masksInfo.size() != 2)
Expand All @@ -90,13 +95,23 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
}

static void print(OpAsmPrinter &p, ContractionOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
p << ", " << *op.acc();
// TODO(andydavis, ntv) Unify printing code with linalg ops.
auto attrNames = op.getTraitAttrNames();
llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op.getAttrs()) {
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
p << *op.rhs() << ", " << *op.acc();
if (llvm::size(op.masks()) == 2) {
p << ", " << **op.masks().begin();
p << ", " << **(op.masks().begin() + 1);
}
p.printOptionalAttrDict(op.getAttrs());
p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
<< op.getResultType();
}
Expand Down Expand Up @@ -159,6 +174,34 @@ static LogicalResult verify(ContractionOp op) {
auto rhsType = op.getRhsType();
auto accType = op.getAccType();
auto resType = op.getResultType();

// Verify that an indexing map was specified for each vector operand.
if (op.indexing_maps().size() != 3)
return op.emitOpError("expected an indexing map for each vector operand");

// Verify that each index map has 'numIterators' inputs, no symbols, and
// that the number of map outputs equals the rank of its associated
// vector operand.
unsigned numIterators = op.iterator_types().getValue().size();
for (auto it : llvm::enumerate(op.indexing_maps())) {
auto index = it.index();
auto map = it.value().cast<AffineMapAttr>().getValue();
if (map.getNumSymbols() != 0)
return op.emitOpError("expected indexing map ")
<< index << " to have no symbols";
if (map.getNumDims() != numIterators)
return op.emitOpError("expected indexing map ")
<< index << " to have " << numIterators << " number of inputs";
auto operandType = op.getOperand(index)->getType().cast<VectorType>();
unsigned rank = operandType.getShape().size();
if (map.getNumResults() != rank)
return op.emitOpError("expected indexing map ")
<< index << " to have " << rank << " number of outputs";
if (!map.isProjectedPermutation())
return op.emitOpError("expected indexing map ")
<< index << " to be a projected permutation of its inputs";
}

auto contractingDimMap = op.getContractingDimMap();
auto batchDimMap = op.getBatchDimMap();

Expand Down Expand Up @@ -198,27 +241,54 @@ static LogicalResult verify(ContractionOp op) {
return success();
}

static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) {
// Returns the names of the attributes that form the op's "trait"
// dictionary (the attributes printed inline in the custom syntax).
SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
  SmallVector<StringRef, 2> names;
  names.push_back("indexing_maps");
  names.push_back("iterator_types");
  return names;
}

// Returns the position of 'targetExpr' within 'map's results, or -1 if
// the expression does not appear among them.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  const int64_t numResults = map.getNumResults();
  for (int64_t pos = 0; pos < numResults; ++pos) {
    if (map.getResult(pos) == targetExpr)
      return pos;
  }
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
StringRef targetIteratorTypeName, MLIRContext *context) {
std::vector<std::pair<int64_t, int64_t>> dimMap;
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>();
if (!dimPairs)
return dimMap;
for (auto dimPairAttr : dimPairs) {
auto dimPair = dimPairAttr.cast<ArrayAttr>();
assert(dimPair.size() == 2);
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt();
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt();
dimMap.push_back({lhsDim, rhsDim});
for (auto it : llvm::enumerate(iteratorTypes)) {
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
if (iteratorTypeName != targetIteratorTypeName)
continue;
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), context);
int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
if (lhsDim >= 0 && rhsDim >= 0)
dimMap.push_back({lhsDim, rhsDim});
}
return dimMap;
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
return getDimMap(getAttr(getContractingDimMapAttrName()));
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
return getDimMap(indexingMaps, iterator_types(),
getReductionIteratorTypeName(), getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
return getDimMap(getAttr(getBatchDimMapAttrName()));
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
return getDimMap(indexingMaps, iterator_types(),
getParallelIteratorTypeName(), getContext());
}

// Materializes the AffineMaps stored in the op's 'indexing_maps' array
// attribute (one map per vector operand: lhs, rhs, acc).
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
  auto attrRange = indexing_maps().getValue();
  SmallVector<AffineMap, 4> maps;
  maps.reserve(attrRange.size());
  for (auto attr : attrRange)
    maps.push_back(attr.cast<AffineMapAttr>().getValue());
  return maps;
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8fc44a4

Please sign in to comment.