Skip to content

Commit

Permalink
[mlir][Vector] Introduce the MaskingOpInterface
Browse files Browse the repository at this point in the history
This MaskingOpInterface provides masking cababilitites to those
operations that implement it. For only is only implemented by the `vector.mask`
operation and it's used to break the dependency between the Vector
dialect (where the `vector.mask` op lives) and operations implementing
the MaskableOpInterface.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D136734
  • Loading branch information
dcaballe committed Oct 27, 2022
1 parent 5e4eec9 commit b1bc1a1
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 88 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Expand Up @@ -13,7 +13,8 @@
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Expand Up @@ -13,7 +13,8 @@
#ifndef VECTOR_OPS
#define VECTOR_OPS

include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
Expand Down Expand Up @@ -2140,13 +2141,15 @@ def Vector_CreateMaskOp :
}

def Vector_MaskOp : Vector_Op<"mask", [
SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveMemoryEffects,
NoRegionArguments
SingleBlockImplicitTerminator<"vector::YieldOp">,
DeclareOpInterfaceMethods<MaskingOpInterface>,
RecursiveMemoryEffects, NoRegionArguments
]> {
let summary = "Predicates a maskable vector operation";
let description = [{
The `vector.mask` operation predicates the execution of another operation.
It takes an `i1` vector mask and an optional pass-thru vector as arguments.
The `vector.mask` is a `MaskingOpInterface` operation that predicates the
execution of another operation. It takes an `i1` vector mask and an
optional passthru vector as arguments.
A `vector.yield`-terminated region encloses the operation to be masked.
Values used within the region are captured from above. Only one *maskable*
operation can be masked with a `vector.mask` operation at a time. An
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt
@@ -1 +1,2 @@
add_mlir_interface(MaskingInterfaces)
add_mlir_interface(MaskableOpInterface)
add_mlir_interface(MaskingOpInterface)
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h
@@ -0,0 +1,23 @@
//===- MaskableOpInterface.h ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the MaskableOpInterface.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_

#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h.inc"

#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_
72 changes: 72 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
@@ -0,0 +1,72 @@
//===- MaskableOpInterfaces.td - Masking Interfaces Decls -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the MaskableOpInterface.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD

include "mlir/IR/OpBase.td"

def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
let description = [{
The 'MaskableOpInterface' defines an operation that can be masked using a
MaskingOpInterface (e.g., `vector.mask`) and provides information about its
masking constraints and semantics.
}];
let cppNamespace = "::mlir::vector";
let methods = [
InterfaceMethod<
/*desc=*/"Returns true if the operation is masked by a "
"MaskingOpInterface.",
/*retTy=*/"bool",
/*methodName=*/"isMasked",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return mlir::isa<mlir::vector::MaskingOpInterface>($_op->getParentOp());
}]>,
InterfaceMethod<
/*desc=*/"Returns the MaskingOpInterface masking this operation.",
/*retTy=*/"mlir::vector::MaskingOpInterface",
/*methodName=*/"getMaskingOp",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return mlir::cast<mlir::vector::MaskingOpInterface>(
$_op->getParentOp());
}]>,
InterfaceMethod<
/*desc=*/"Returns true if the operation can have a passthru argument when"
" masked.",
/*retTy=*/"bool",
/*methodName=*/"supportsPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]>,
InterfaceMethod<
/*desc=*/"Returns the mask type expected by this operation. It requires "
"the operation to be vectorized.",
/*retTy=*/"mlir::VectorType",
/*methodName=*/"getExpectedMaskType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Default implementation is only aimed for operations that implement the
// `getVectorType()` method.
return $_op.getVectorType().cloneWith(/*shape=*/llvm::None,
IntegerType::get($_op.getContext(), /*width=*/1));
}]>,
];
}

#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD
52 changes: 0 additions & 52 deletions mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td

This file was deleted.

@@ -1,22 +1,22 @@
//===- MaskingInterfaces.h - Masking interfaces ---------------------------===//
//===- MaskingOpInterface.h -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the interfaces for masking operations.
// This file implements the MaskingOpInterface.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

/// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc"
// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h.inc"

#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_
58 changes: 58 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td
@@ -0,0 +1,58 @@
//===- MaskingOpInterfaces.td - MaskingOpInterface Decls = -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the MaskingOpInterface.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD

include "mlir/IR/OpBase.td"

def MaskingOpInterface : OpInterface<"MaskingOpInterface"> {
let description = [{
The 'MaskingOpInterface' defines an vector operation that can apply masking
to its own or other vector operations.
}];
let cppNamespace = "::mlir::vector";
let methods = [
InterfaceMethod<
/*desc=*/"Returns the mask value of this masking operation.",
/*retTy=*/"mlir::Value",
/*methodName=*/"getMask",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
InterfaceMethod<
/*desc=*/"Returns the operation masked by this masking operation.",
// TODO: Return a MaskableOpInterface when interface infra can handle
// dependences between interfaces.
/*retTy=*/"Operation *",
/*methodName=*/"getMaskableOp",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
InterfaceMethod<
/*desc=*/"Returns true if the masking operation has a passthru value.",
/*retTy=*/"bool",
/*methodName=*/"hasPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
InterfaceMethod<
/*desc=*/"Returns the passthru value of this masking operation.",
/*retTy=*/"mlir::Value",
/*methodName=*/"getPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
];
}

#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Vector/IR/CMakeLists.txt
Expand Up @@ -15,7 +15,8 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
MLIRMaskingInterfaces
MLIRMaskableOpInterface
MLIRMaskingOpInterface
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -5003,6 +5003,14 @@ LogicalResult MaskOp::verify() {
return success();
}

// MaskingOpInterface definitions.

/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }

/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
MaskingInterfaces.cpp
MaskableOpInterface.cpp
MaskingOpInterface.cpp
)

function(add_mlir_interface_library name)
Expand All @@ -17,5 +18,5 @@ function(add_mlir_interface_library name)
)
endfunction(add_mlir_interface_library)

add_mlir_interface_library(MaskingInterfaces)

add_mlir_interface_library(MaskableOpInterface)
add_mlir_interface_library(MaskingOpInterface)
@@ -1,16 +1,18 @@
//===- MaskingInterfaces.cpp - Masking interfaces ----------====-*- C++ -*-===//
//===- MaskableOpInterfaces.cpp - MaskableOpInterface Defs -====-*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Masking Interfaces
// MaskableOpInterface Defs
//===----------------------------------------------------------------------===//

/// Include the definitions of the masking interfaces.
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.cpp.inc"
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Vector/Interfaces/MaskingOpInterface.cpp
@@ -0,0 +1,18 @@
//===- MaskingOpInterface.cpp - MaskingOpInterface Defs -----====-*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// MaskingOpInterface Defs
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.cpp.inc"

0 comments on commit b1bc1a1

Please sign in to comment.