Skip to content

Commit

Permalink
[mlir][Vector] Introduce 'vector.mask' operation and MaskableOpInterface
Browse files Browse the repository at this point in the history
This patch introduces the `vector.mask` operation and the MaskableOpInterface
as described in https://discourse.llvm.org/t/rfc-vector-masking-representation-in-mlir/64964.
The `vector.mask` operation is used to predicate the execution of operations
implementing the MaskableOpInterface. This interface will be implemented by maskable
operations and provides information about its masking constraints and semantics.

For now, only vector transfer and reduction ops implement the MaskableOpInterface
for illustration and testing purposes.

Reviewed By: nicolasvasilache, rriddle

Differential Revision: https://reviews.llvm.org/D134939
  • Loading branch information
dcaballe committed Oct 10, 2022
1 parent deb8f8a commit 2d10f81
Show file tree
Hide file tree
Showing 15 changed files with 497 additions and 5 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Transforms)
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -49,6 +50,10 @@ namespace detail {
struct BitmaskEnumStorage;
} // namespace detail

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void buildTerminatedBody(OpBuilder &builder, Location loc);

/// Return whether `srcType` can be broadcast to `dstVectorType` under the
/// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult {
Expand Down
84 changes: 83 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Expand Up @@ -13,6 +13,7 @@
#ifndef VECTOR_OPS
#define VECTOR_OPS

include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
Expand Down Expand Up @@ -283,6 +284,7 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
Expand Down Expand Up @@ -360,7 +362,7 @@ def Vector_MultiDimReductionOp :
}];
let builders = [
OpBuilder<(ins "Value":$source, "Value":$acc,
"ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>
"ArrayRef<bool>":$reductionMask, "CombiningKind":$kind)>
];
let extraClassDeclaration = [{
static StringRef getKindAttrStrName() { return "kind"; }
Expand Down Expand Up @@ -1050,6 +1052,7 @@ def Vector_TransferReadOp :
Vector_Op<"transfer_read", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
]>,
Expand Down Expand Up @@ -1246,6 +1249,12 @@ def Vector_TransferReadOp :
"ValueRange":$indices,
CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
];

let extraClassDeclaration = [{
// MaskableOpInterface methods.
bool supportsPassthru() { return true; }
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
Expand All @@ -1256,6 +1265,7 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
]>,
Expand Down Expand Up @@ -2120,6 +2130,78 @@ def Vector_CreateMaskOp :
let assemblyFormat = "$operands attr-dict `:` type(results)";
}

def Vector_MaskOp : Vector_Op<"mask", [
SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveSideEffects,
NoRegionArguments
]> {
let summary = "Predicates a maskable vector operation";
let description = [{
The `vector.mask` operation predicates the execution of another operation.
It takes an `i1` vector mask and an optional pass-thru vector as arguments.
A `vector.yield`-terminated region encloses the operation to be masked.
Values used within the region are captured from above. Only one *maskable*
operation can be masked with a `vector.mask` operation at a time. An
operation is *maskable* if it implements the `MaskableOpInterface`.

The vector mask argument holds a bit for each vector lane and determines
which vector lanes should execute the maskable operation and which ones
should not. The `vector.mask` operation returns the value produced by the
masked execution of the nested operation, if any. The masked-off lanes in
the result vector are taken from the corresponding lanes of the pass-thru
argument, if provided, or left unmodified, otherwise.

The `vector.mask` operation does not prescribe how a maskable operation
should be masked or how a masked operation should be lowered. Masking
constraints and some semantic details are provided by each maskable
operation through the `MaskableOpInterface`. Lowering of masked operations
is implementation defined. For instance, scalarizing the masked operation
or executing the operation for the masked-off lanes are valid lowerings as
long as the execution of masked-off lanes does not change the observable
behavior of the program.

Examples:

```
%0 = vector.mask %mask { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<8xi1> -> i32
```

```
%0 = vector.mask %mask, %passthru { arith.divsi %a, %b : vector<8xi32> } : vector<8xi1> -> vector<8xi32>
```

```
vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
```
}];

// TODO: Support multiple results and passthru values.
let arguments = (ins VectorOf<[I1]>:$mask,
Optional<AnyType>:$passthru);
let results = (outs Optional<AnyType>:$results);
let regions = (region SizedRegion<1>:$maskRegion);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$mask,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>,
OpBuilder<(ins "Type":$resultType, "Value":$mask,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>,
OpBuilder<(ins "Type":$resultType, "Value":$mask,
"Value":$passthru,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$maskRegion)>
];

let extraClassDeclaration = [{
static void ensureTerminator(Region &region, Builder &builder, Location loc);
}];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

def Vector_TransposeOp :
Vector_Op<"transpose", [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt
@@ -0,0 +1 @@
add_mlir_interface(MaskingInterfaces)
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h
@@ -0,0 +1,22 @@
//===- MaskingInterfaces.h - Masking interfaces ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the interfaces for masking operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

/// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc"

#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
52 changes: 52 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td
@@ -0,0 +1,52 @@
//===- MaskingInterfaces.td - Masking Interfaces Decls === -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for vector masking related interfaces.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES

include "mlir/IR/OpBase.td"

def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
let description = [{
The 'MaskableOpInterface' define an operation that can be masked using the
`vector.mask` operation and provides information about its masking
constraints and semantics.
}];
let cppNamespace = "::mlir::vector";
let methods = [
InterfaceMethod<
/*desc=*/"Returns true if the operation may have a passthru argument when"
" masked.",
/*retTy=*/"bool",
/*methodName=*/"supportsPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]>,
InterfaceMethod<
/*desc=*/"Returns the mask type expected by this operation. It requires the"
" operation to be vectorized.",
/*retTy=*/"mlir::VectorType",
/*methodName=*/"getExpectedMaskType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Default implementation is only aimed for operations that implement the
// `getVectorType()` method.
return $_op.getVectorType().cloneWith(
/*shape=*/llvm::None, IntegerType::get($_op.getContext(), /*width=*/1));
}]>,
];
}

#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Transforms)
add_subdirectory(Utils)
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/IR/CMakeLists.txt
Expand Up @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRDataLayoutInterfaces
MLIRDialectUtils
MLIRIR
MLIRMaskingInterfaces
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
Expand Down

0 comments on commit 2d10f81

Please sign in to comment.