Skip to content

Commit

Permalink
[mlir][vector] Support multiple result types in vector.mask
Browse files Browse the repository at this point in the history
The verifier already had support for multiple result types, but the op definition assumed a single, optional result.

Differential Revision: https://reviews.llvm.org/D141683
  • Loading branch information
matthias-springer committed Jan 13, 2023
1 parent f601039 commit f94131a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
21 changes: 14 additions & 7 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2287,10 +2287,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
The `vector.mask` is a `MaskingOpInterface` operation that predicates the
execution of another operation. It takes an `i1` vector mask and an
optional passthru vector as arguments.
A `vector.yield`-terminated region encloses the operation to be masked.
Values used within the region are captured from above. Only one *maskable*
operation can be masked with a `vector.mask` operation at a time. An
operation is *maskable* if it implements the `MaskableOpInterface`.

A implicitly `vector.yield`-terminated region encloses the operation to be
masked. Values used within the region are captured from above. Only one
*maskable* operation can be masked with a `vector.mask` operation at a time.
An operation is *maskable* if it implements the `MaskableOpInterface`. The
terminator yields all results of the maskable operation to the result of
this operation.

The vector mask argument holds a bit for each vector lane and determines
which vector lanes should execute the maskable operation and which ones
Expand Down Expand Up @@ -2321,23 +2324,27 @@ def Vector_MaskOp : Vector_Op<"mask", [
```
vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
```

```
vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
```
}];

// TODO: Support multiple results and passthru values.
let arguments = (ins VectorOf<[I1]>:$mask,
Optional<AnyType>:$passthru);
let results = (outs Optional<AnyType>:$results);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$maskRegion);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$mask,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>,
OpBuilder<(ins "Type":$resultType, "Value":$mask,
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>,
OpBuilder<(ins "Type":$resultType, "Value":$mask,
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
"Value":$passthru,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5288,20 +5288,20 @@ void MaskOp::build(
}

void MaskOp::build(
OpBuilder &builder, OperationState &result, Type resultType, Value mask,
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
build(builder, result, resultType, mask, /*passthru=*/Value(),
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
build(builder, result, resultTypes, mask, /*passthru=*/Value(),
maskRegionBuilder);
}

void MaskOp::build(
OpBuilder &builder, OperationState &result, Type resultType, Value mask,
Value passthru,
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
Value mask, Value passthru,
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
build(builder, result, mask, maskRegionBuilder);
if (passthru)
result.addOperands(passthru);
result.addTypes(resultType);
result.addTypes(resultTypes);
}

ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
Expand Down

0 comments on commit f94131a

Please sign in to comment.