Skip to content

Commit

Permalink
[VectorOps] Update vector transfer_read/write ops to operatate on mem…
Browse files Browse the repository at this point in the history
…refs with vector element type.

Update vector transfer_read/write ops to operatate on memrefs with vector element type.
This handle cases where the memref vector element type represents the minimal memory transfer unit (or multiple of the minimal memory transfer unit).

PiperOrigin-RevId: 286482115
  • Loading branch information
andydavis1 authored and tensorflower-gardener committed Dec 20, 2019
1 parent 6685282 commit 8020ad3
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 47 deletions.
38 changes: 29 additions & 9 deletions mlir/include/mlir/Dialect/VectorOps/VectorOps.td
Expand Up @@ -746,10 +746,15 @@ def Vector_TransferReadOp :

let description = [{
The `vector.transfer_read` op performs a blocking read from a slice within
a scalar [MemRef](../LangRef.md#memref-type) supplied as its first operand
into a [vector](../LangRef.md#vector-type) of the same elemental type. The
slice is further defined by a full-rank index within the MemRef, supplied as
the operands `2 .. 1 + rank(memref)`. The permutation_map
a [MemRef](../LangRef.md#memref-type) supplied as its first operand
into a [vector](../LangRef.md#vector-type) of the same base elemental type.

A vector memref operand must have its vector element type match a suffix
(shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
vector<1x1x4x3xf32>).

The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
[attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The size of the slice is specified by the
Expand Down Expand Up @@ -854,6 +859,11 @@ def Vector_TransferReadOp :
memref<?x?xf32>, vector<128xf32>
}
}

// Read from a memref with vector element type.
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0
{permutation_map = (d0, d1)->(d0, d1)}
: memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
```
}];

Expand All @@ -878,10 +888,15 @@ def Vector_TransferWriteOp :
let description = [{
The `vector.transfer_write` performs a blocking write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
slice within a scalar [MemRef](../LangRef.md#memref-type) of the same
elemental type, supplied as its second operand. The slice is further defined
by a full-rank index within the MemRef, supplied as the operands
`3 .. 2 + rank(memref)`.
slice within a [MemRef](../LangRef.md#memref-type) of the same base
elemental type, supplied as its second operand.

A vector memref operand must have its vector element type match a suffix
(shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
vector<1x1x4x3xf32>).

The slice is further defined by a full-rank index within the MemRef,
supplied as the operands `3 .. 2 + rank(memref)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The size of the slice is specified by the
Expand Down Expand Up @@ -915,6 +930,11 @@ def Vector_TransferWriteOp :
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
vector<16x32x64xf32>, memref<?x?x?x?xf32>
}}}}

// write to a memref with vector element type.
vector.transfer_write %4, %arg1[%c3, %c3]
{permutation_map = (d0, d1)->(d0, d1)}
: vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
```
}];

Expand Down Expand Up @@ -1048,7 +1068,7 @@ def Vector_TupleOp :
Note that this operation is used during the vector op unrolling
transformation and should be removed before lowering to lower-level
dialects.


Examples:
```
Expand Down
121 changes: 87 additions & 34 deletions mlir/lib/Dialect/VectorOps/VectorOps.cpp
Expand Up @@ -1420,6 +1420,59 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
return success();
}

static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
VectorType vectorType,
AffineMap permutationMap) {
auto memrefElementType = memrefType.getElementType();
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.

// Check that 'memrefVectorElementType' and vector element types match.
if (memrefVectorElementType.getElementType() != vectorType.getElementType())
return op->emitOpError(
"requires memref and vector types of the same elemental type");

// Check that memref vector type is a suffix of 'vectorType.
unsigned memrefVecEltRank = memrefVectorElementType.getRank();
unsigned resultVecRank = vectorType.getRank();
if (memrefVecEltRank > resultVecRank)
return op->emitOpError(
"requires memref vector element and vector result ranks to match.");
// TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h.
unsigned rankOffset = resultVecRank - memrefVecEltRank;
auto memrefVecEltShape = memrefVectorElementType.getShape();
auto resultVecShape = vectorType.getShape();
for (unsigned i = 0; i < memrefVecEltRank; ++i)
if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
return op->emitOpError(
"requires memref vector element shape to match suffix of "
"vector result shape.");
// Check that permutation map results match 'rankOffset' of vector type.
if (permutationMap.getNumResults() != rankOffset)
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");
} else {
// Memref has scalar element type.

// Check that memref and vector element types match.
if (memrefType.getElementType() != vectorType.getElementType())
return op->emitOpError(
"requires memref and vector types of the same elemental type");

// Check that permutation map results match rank of vector type.
if (permutationMap.getNumResults() != vectorType.getRank())
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");
}

if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
if (permutationMap.getNumInputs() != memrefType.getRank())
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type");
return success();
}

static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding() << " ";
Expand Down Expand Up @@ -1459,26 +1512,35 @@ static LogicalResult verify(TransferReadOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
if (memrefType.getElementType() != vectorType.getElementType())
return op.emitOpError(
"requires memref and vector types of the same elemental type");
auto elementalType = op.padding()->getType();
if (!VectorType::isValidElementType(elementalType))
return op.emitOpError("requires valid padding vector elemental type");
if (elementalType != vectorType.getElementType())
return op.emitOpError(
"requires formal padding and vector of the same elemental type");
if (llvm::size(op.indices()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
auto paddingType = op.padding()->getType();
auto permutationMap = op.permutation_map();
if (permutationMap.getNumSymbols() != 0)
return op.emitOpError("requires permutation_map without symbols");
if (permutationMap.getNumInputs() != memrefType.getRank())
return op.emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type");
if (permutationMap.getNumResults() != vectorType.getRank())
return op.emitOpError("requires a permutation_map with result dims of the "
"same rank as the vector type");
auto memrefElementType = memrefType.getElementType();

if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";

if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
permutationMap)))
return failure();

if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
// Check that 'memrefVectorElementType' and 'paddingType' types match.
if (memrefVectorElementType != paddingType)
return op.emitOpError(
"requires memref element type and padding type to match.");

} else {
// Check that 'paddingType' is valid to store in a vector type.
if (!VectorType::isValidElementType(paddingType))
return op.emitOpError("requires valid padding vector elemental type");

// Check that padding type and vector element types match.
if (paddingType != vectorType.getElementType())
return op.emitOpError(
"requires formal padding and vector of the same elemental type");
}

return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
Expand Down Expand Up @@ -1519,24 +1581,15 @@ static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
if (memrefType.getElementType() != vectorType.getElementType())
return op.emitOpError(
"requires memref and vector types of the same elemental type");
auto permutationMap = op.permutation_map();

if (llvm::size(op.indices()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";

// Consistency of AffineMap attribute.
auto permutationMap = op.permutation_map();
if (permutationMap.getNumSymbols() != 0)
return op.emitOpError("requires a symbol-less permutation_map");
if (permutationMap.getNumInputs() != memrefType.getRank())
return op.emitOpError("requires a permutation_map with input dims of the "
"same rank as the memref type: ")
<< permutationMap.getNumInputs() << " vs " << memrefType;
if (permutationMap.getNumResults() != vectorType.getRank())
return op.emitOpError("requires a permutation_map with result dims of the "
"same rank as the vector type.")
<< permutationMap.getNumResults() << " vs " << vectorType;
if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
permutationMap)))
return failure();

return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/VectorOps/invalid.mlir
Expand Up @@ -308,6 +308,36 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {

// -----

func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{requires memref and vector types of the same elemental type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
}

// -----

func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{requires memref vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}

// -----

func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
}

// -----

func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant dense<3.0> : vector<128 x f32>
Expand Down
19 changes: 15 additions & 4 deletions mlir/test/Dialect/VectorOps/ops.mlir
@@ -1,24 +1,35 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s

// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1)

// CHECK-LABEL: func @vector_transfer_ops(
func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%arg1 : memref<?x?xvector<4x3xf32>>) {
// CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>

//
// CHECK: %0 = vector.transfer_read
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32>
// CHECK: %1 = vector.transfer_read
// CHECK: vector.transfer_read
%1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d1, d0)} : memref<?x?xf32>, vector<3x7xf32>
// CHECK: vector.transfer_read
%2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d1)} : memref<?x?xf32>, vector<128xf32>
//
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>

// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write
vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {permutation_map = #[[MAP0]]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>

return
}

Expand Down

0 comments on commit 8020ad3

Please sign in to comment.