Skip to content

Commit

Permalink
[VectorOps] Refine BroadcastOp in VectorOps dialect
Browse files Browse the repository at this point in the history
Since second argument is always fully overwritten and
shape is define in "to" clause, it is not needed.
Also renamed "into" to "to" now that arg is dropped.

PiperOrigin-RevId: 282686475
  • Loading branch information
aartbik authored and tensorflower-gardener committed Nov 27, 2019
1 parent f27ceb7 commit e2232fb
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 28 deletions.
18 changes: 8 additions & 10 deletions mlir/include/mlir/Dialect/VectorOps/VectorOps.td
Expand Up @@ -165,27 +165,25 @@ def Vector_ContractionOp :
def Vector_BroadcastOp :
Vector_Op<"broadcast", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
Arguments<(ins AnyType:$source, AnyVector:$dest)>,
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyType:$source)>,
Results<(outs AnyVector:$vector)> {
let summary = "broadcast operation";
let description = [{
Broadcasts the scalar or k-D vector value in the source to the n-D
destination vector of a proper shape such that the broadcast makes sense.
Broadcasts the scalar or k-D vector value in the source operand
to a n-D result vector such that the broadcast makes sense.

Examples:
```
%0 = constant 0.0 : f32
%1 = vector.broadcast %0, %x : f32 into vector<16xf32>
%2 = vector.broadcast %1, %y : vector<16xf32> into vector<4x16xf32>
%1 = vector.broadcast %0 : f32 to vector<16xf32>
%2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
```
}];
let extraClassDeclaration = [{
Type getSourceType() { return source()->getType(); }
VectorType getDestVectorType() {
return dest()->getType().cast<VectorType>();
VectorType getVectorType() {
return vector()->getType().cast<VectorType>();
}
}];
}
Expand Down
18 changes: 8 additions & 10 deletions mlir/lib/Dialect/VectorOps/VectorOps.cpp
Expand Up @@ -373,14 +373,14 @@ static LogicalResult verify(ExtractElementOp op) {
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, BroadcastOp op) {
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest();
p << op.getOperationName() << " " << *op.source();
p << " : " << op.getSourceType();
p << " into " << op.getDestVectorType();
p << " to " << op.getVectorType();
}

static LogicalResult verify(BroadcastOp op) {
VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
VectorType dstVectorType = op.getDestVectorType();
VectorType dstVectorType = op.getVectorType();
// Scalar to vector broadcast is always valid. A vector
// to vector broadcast needs some additional checking.
if (srcVectorType) {
Expand All @@ -397,16 +397,14 @@ static LogicalResult verify(BroadcastOp op) {

static ParseResult parseBroadcastOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType source, dest;
OpAsmParser::OperandType source;
Type sourceType;
VectorType destType;
return failure(parser.parseOperand(source) || parser.parseComma() ||
parser.parseOperand(dest) ||
VectorType vectorType;
return failure(parser.parseOperand(source) ||
parser.parseColonType(sourceType) ||
parser.parseKeywordType("into", destType) ||
parser.parseKeywordType("to", vectorType) ||
parser.resolveOperand(source, sourceType, result.operands) ||
parser.resolveOperand(dest, destType, result.operands) ||
parser.addTypeToList(destType, result.types));
parser.addTypeToList(vectorType, result.types));
}

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/VectorOps/invalid.mlir
Expand Up @@ -2,9 +2,9 @@

// -----

func @broadcast_rank_too_high(%arg0: vector<4x4xf32>, %arg1: vector<4xf32>) {
func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
// expected-error@+1 {{source rank higher than destination rank}}
%2 = vector.broadcast %arg0, %arg1 : vector<4x4xf32> into vector<4xf32>
%1 = vector.broadcast %arg0 : vector<4x4xf32> to vector<4xf32>
}

// -----
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/VectorOps/ops.mlir
Expand Up @@ -23,12 +23,12 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
}

// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>) {
// CHECK: vector.broadcast %{{.*}}, %{{.*}} : f32 into vector<16xf32>
%0 = vector.broadcast %a, %b : f32 into vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}}, %{{.*}} : vector<16xf32> into vector<8x16xf32>
%1 = vector.broadcast %b, %c : vector<16xf32> into vector<8x16xf32>
return
func @vector_broadcast(%a: f32, %b: vector<16xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
%0 = vector.broadcast %a : f32 to vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
%1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
}

// CHECK-LABEL: @extractelement
Expand Down

0 comments on commit e2232fb

Please sign in to comment.