Skip to content

Commit

Permalink
Add VectorOps.StridedSliceOp
Browse files Browse the repository at this point in the history
The `vector.strided_slice` takes an n-D vector, k-D `offsets` integer array attribute, a
k-D `sizes` integer array attribute, a k-D `strides` integer array attribute and extracts
the n-D subvector at the proper offset.

Returns an n-D vector where the first k-D dimensions match the `sizes` attribute.
The returned subvector contains the elements starting at offset `offsets` and ending at
`offsets + sizes`.

Example:
```
  %1 = vector.strided_slice %0
      {offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
    vector<4x8x16xf32> // returns a vector<2x4x16xf32>
```

This op will be useful for progressive lowering within the VectorOp dialect.

PiperOrigin-RevId: 281352749
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Nov 19, 2019
1 parent 3732ba4 commit ee95f6f
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 2 deletions.
42 changes: 42 additions & 0 deletions mlir/include/mlir/Dialect/VectorOps/VectorOps.td
Expand Up @@ -76,6 +76,48 @@ def VectorExtractElementOp :
}];
}

def VectorStridedSliceOp :
Vector_Op<"strided_slice", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
Results<(outs AnyVector)> {
let summary = "strided_slice operation";
let description = [{
Takes an n-D vector, k-D `offsets` integer array attribute, a k-D `sizes`
integer array attribute, a k-D `strides` integer array attribute and
extracts the n-D subvector at the proper offset.

At the moment strides must contain only 1s.

Returns an n-D vector where the first k-D dimensions match the `sizes`
attribute. The returned subvector contains the elements starting at offset
`offsets` and ending at `offsets + sizes`.

Examples:
```
%1 = vector.strided_slice %0
{offsets : [0, 2], sizes : [2, 4], strides : [1, 1]}:
vector<4x8x16xf32> to vector<2x4x16xf32>
```

// TODO(Evolve to a range form syntax):
%1 = vector.strided_slice %0[0:2:1][2:4:1]
vector<4x8x16xf32> to vector<2x4x16xf32>
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value *source, " #
"ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, " #
"ArrayRef<int64_t> strides">];
let extraClassDeclaration = [{
static StringRef getOffsetsAttrName() { return "offsets"; }
static StringRef getSizesAttrName() { return "sizes"; }
static StringRef getStridesAttrName() { return "strides"; }
VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
}];
}

def VectorOuterProductOp :
Vector_Op<"outerproduct", [NoSideEffect, SameOperandsAndResultElementType]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, Variadic<AnyVector>:$acc)>,
Expand Down
178 changes: 176 additions & 2 deletions mlir/lib/Dialect/VectorOps/VectorOps.cpp
Expand Up @@ -92,7 +92,7 @@ static ParseResult parseVectorExtractElementOp(OpAsmParser &parser,
static_cast<int64_t>(positionAttr.size()) > vectorType.getRank())
return parser.emitError(
attributeLoc,
"expected position attribute of rank smaller than vector");
"expected position attribute of rank smaller than vector rank");

Type resType = inferExtractOpResultType(vectorType, positionAttr);
result.attributes = attrs;
Expand All @@ -106,7 +106,7 @@ static LogicalResult verify(VectorExtractElementOp op) {
return op.emitOpError("expected non-empty position attribute");
if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
return op.emitOpError(
"expected position attribute of rank smaller than vector");
"expected position attribute of rank smaller than vector rank");
for (auto en : llvm::enumerate(positionAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 ||
Expand All @@ -119,6 +119,180 @@ static LogicalResult verify(VectorExtractElementOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// VectorStridedSliceOp
//===----------------------------------------------------------------------===//

static Type inferVectorExtractRangeOpResultType(VectorType vectorType,
ArrayAttr offsets,
ArrayAttr sizes,
ArrayAttr strides) {
assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
SmallVector<int64_t, 4> shape;
shape.reserve(vectorType.getRank());
unsigned idx = 0;
for (unsigned e = offsets.size(); idx < e; ++idx)
shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
shape.push_back(vectorType.getShape()[idx]);

return VectorType::get(shape, vectorType.getElementType());
}

void VectorStridedSliceOp::build(Builder *builder, OperationState &result,
Value *source, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
result.addOperands(source);
auto offsetsAttr = builder->getI64ArrayAttr(offsets);
auto sizesAttr = builder->getI64ArrayAttr(sizes);
auto stridesAttr = builder->getI64ArrayAttr(strides);
result.addTypes(
inferVectorExtractRangeOpResultType(source->getType().cast<VectorType>(),
offsetsAttr, sizesAttr, stridesAttr));
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
result.addAttribute(getSizesAttrName(), sizesAttr);
result.addAttribute(getStridesAttrName(), stridesAttr);
}

static void print(OpAsmPrinter &p, VectorStridedSliceOp op) {
p << op.getOperationName() << " " << *op.vector();
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
}

static ParseResult parseVectorStridedSliceOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc attributeLoc, typeLoc;
OpAsmParser::OperandType vector;
VectorType vectorType, resultVectorType;
return failure(parser.parseOperand(vector) ||
parser.getCurrentLocation(&attributeLoc) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typeLoc) ||
parser.parseColonType(vectorType) ||
parser.parseKeywordType("to", resultVectorType) ||
parser.resolveOperand(vector, vectorType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types));
}

// TODO(ntv) Should be moved to Tablegen Confined attributes.
static bool isIntegerArrayAttrSmallerThanShape(VectorStridedSliceOp op,
ArrayAttr arrayAttr,
ShapedType shape,
StringRef attrName) {
if (arrayAttr.size() > static_cast<unsigned>(shape.getRank())) {
op.emitOpError("expected ")
<< attrName << " attribute of rank smaller than vector rank";
return false;
}
return true;
}

// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
// interval. If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
static bool isIntegerArrayAttrConfinedToRange(VectorStridedSliceOp op,
ArrayAttr arrayAttr, int64_t min,
int64_t max, StringRef attrName,
bool halfOpen = true) {
for (auto attr : arrayAttr) {
auto val = attr.cast<IntegerAttr>().getInt();
auto upper = max;
if (!halfOpen)
upper += 1;
if (val < min || val >= upper) {
op.emitOpError("expected ")
<< attrName << " to be confined to [" << min << ", " << upper << ")";
return false;
}
}
return true;
}

// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
// interval. If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
static bool
isIntegerArrayAttrConfinedToShape(VectorStridedSliceOp op, ArrayAttr arrayAttr,
ShapedType shape, StringRef attrName,
bool halfOpen = true, int64_t min = 0) {
assert(arrayAttr.size() <= static_cast<unsigned>(shape.getRank()));
for (auto it : llvm::zip(arrayAttr, shape.getShape())) {
auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
auto max = std::get<1>(it);
if (!halfOpen)
max += 1;
if (val < min || val >= max) {
op.emitOpError("expected ")
<< attrName << " to be confined to [" << min << ", " << max << ")";
return false;
}
}
return true;
}

// Returns true if all integers in `arrayAttr` are in the interval [min, max}.
// interval. If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
static bool isSumOfIntegerArrayAttrConfinedToShape(
VectorStridedSliceOp op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
ShapedType shape, StringRef attrName1, StringRef attrName2,
bool halfOpen = true, int64_t min = 1) {
assert(arrayAttr1.size() <= static_cast<unsigned>(shape.getRank()));
assert(arrayAttr2.size() <= static_cast<unsigned>(shape.getRank()));
for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape.getShape())) {
auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
auto max = std::get<2>(it);
if (!halfOpen)
max += 1;
if (val1 + val2 < 0 || val1 + val2 >= max) {
op.emitOpError("expected sum(")
<< attrName1 << ", " << attrName2 << ") to be confined to [" << min
<< ", " << max << ")";
return false;
}
}
return true;
}

static LogicalResult verify(VectorStridedSliceOp op) {
auto type = op.getVectorType();
auto offsets = op.offsets();
auto sizes = op.sizes();
auto strides = op.strides();
if (offsets.size() != sizes.size() || offsets.size() != strides.size()) {
op.emitOpError(
"expected offsets, sizes and strides attributes of same size");
return failure();
}

auto offName = VectorStridedSliceOp::getOffsetsAttrName();
auto sizesName = VectorStridedSliceOp::getSizesAttrName();
auto stridesName = VectorStridedSliceOp::getStridesAttrName();
if (!isIntegerArrayAttrSmallerThanShape(op, offsets, type, offName) ||
!isIntegerArrayAttrSmallerThanShape(op, sizes, type, sizesName) ||
!isIntegerArrayAttrSmallerThanShape(op, strides, type, stridesName) ||
!isIntegerArrayAttrConfinedToShape(op, offsets, type, offName) ||
!isIntegerArrayAttrConfinedToShape(op, sizes, type, sizesName,
/*halfOpen=*/false, /*min=*/1) ||
!isIntegerArrayAttrConfinedToRange(op, strides, 1, 1, stridesName,
/*halfOpen=*/false) ||
!isSumOfIntegerArrayAttrConfinedToShape(op, offsets, sizes, type, offName,
sizesName, /*halfOpen=*/false))
return failure();

auto resultType = inferVectorExtractRangeOpResultType(
op.getVectorType(), op.offsets(), op.sizes(), op.strides());
if (op.getResult()->getType() != resultType) {
op.emitOpError("expected result type to be ") << resultType;
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// VectorOuterProductOp
//===----------------------------------------------------------------------===//
Expand Down
64 changes: 64 additions & 0 deletions mlir/test/Dialect/VectorOps/invalid.mlir
Expand Up @@ -231,6 +231,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
// expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}}
vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0 + 1)} : vector<128xf32>, memref<?x?xf32>
}

// -----

func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
Expand All @@ -239,3 +240,66 @@ func @test_vector.transfer_write(%arg0: memref<?x?x?xf32>) {
// expected-error@+1 {{requires a permutation_map that is a permutation (found one dim used more than once)}}
vector.transfer_write %cst, %arg0[%c3, %c3, %c3] {permutation_map = (d0, d1, d2)->(d0, d0)} : vector<3x7xf32>, memref<?x?x?xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected offsets, sizes and strides attributes of same size}}
%1 = vector.strided_slice %arg0 {offsets = [100], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected offsets attribute of rank smaller than vector rank}}
%1 = vector.strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{expected offsets attribute of rank smaller than vector rank}}
%1 = vector.strided_slice %arg0 {offsets = [2, 2, 2, 2], sizes = [2, 2, 2, 2], strides = [1, 1, 1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected offsets to be confined to [0, 4)}}
%1 = vector.strided_slice %arg0 {offsets = [100], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected sizes to be confined to [1, 5)}}
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [100], strides = [100]} : vector<4x8x16xf32> to vector<100x8x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected sum(offsets, sizes) to be confined to [1, 5)}}
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
}

// -----

func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/VectorOps/ops.mlir
Expand Up @@ -41,3 +41,10 @@ func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8
%1 = vector.outerproduct %arg0, %arg1, %arg2 : vector<4xf32>, vector<8xf32>
return %1 : vector<4x8xf32>
}

// CHECK-LABEL: strided_slice
func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
// CHECK: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32>
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
return %1: vector<2x2x16xf32>
}

0 comments on commit ee95f6f

Please sign in to comment.