Skip to content

Commit

Permalink
[mlir] Support verification order (2/3)
Browse files Browse the repository at this point in the history
    This change gives explicit order of verifier execution and adds
    `hasRegionVerifier` and `verifyWithRegions` to increase the granularity
    of verifier classification. The orders are as below,

    1. InternalOpTrait will be verified first, they can be run independently.
    2. `verifyInvariants` which is constructed by ODS, it verifies the type,
       attributes, .etc.
    3. Other Traits/Interfaces that have marked their verifier as
       `verifyTrait` or `verifyWithRegions=0`.
    4. Custom verifier which is defined in the op and has marked
       `hasVerifier=1`

    If an operation has regions, then it may have the second phase,

    5. Traits/Interfaces that have marked their verifier as
       `verifyRegionTrait` or
       `verifyWithRegions=1`. This implies the verifier needs to access the
       operations in its regions.
    6. Custom verifier which is defined in the op and has marked
       `hasRegionVerifier=1`

    Note that the second phase will be run after the operations in the
    region are verified. Based on the verification order, you will be able to
    avoid verifying duplicate things.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116789
  • Loading branch information
ChiaHungDuan committed Feb 25, 2022
1 parent d04d922 commit 9445b39
Show file tree
Hide file tree
Showing 28 changed files with 333 additions and 115 deletions.
37 changes: 33 additions & 4 deletions mlir/docs/OpDefinitions.md
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,39 @@ _additional_ verification, you can use
let hasVerifier = 1;
```

This will generate a `LogicalResult verify()` method declaration on the op class
that can be defined with any additional verification constraints. This method
will be invoked after the auto-generated verification code. The order of trait
verification excluding those of `hasVerifier` should not be relied upon.
or

```tablegen
let hasRegionVerifier = 1;
```

This will generate either `LogicalResult verify()` or
`LogicalResult verifyRegions()` method declaration on the op class
that can be defined with any additional verification constraints. These method
will be invoked on its verification order.

#### Verification Ordering

The verification of an operation involves several steps,

1. StructuralOpTrait will be verified first, they can be run independently.
1. `verifyInvariants` which is constructed by ODS, it verifies the type,
attributes, .etc.
1. Other Traits/Interfaces that have marked their verifier as `verifyTrait` or
`verifyWithRegions=0`.
1. Custom verifier which is defined in the op and has marked `hasVerifier=1`

If an operation has regions, then it may have the second phase,

1. Traits/Interfaces that have marked their verifier as `verifyRegionTrait` or
`verifyWithRegions=1`. This implies the verifier needs to access the
operations in its regions.
1. Custom verifier which is defined in the op and has marked
`hasRegionVerifier=1`

Note that the second phase will be run after the operations in the region are
verified. Verifiers further down the order can rely on certain invariants being
verified by a previous verifier and do not need to re-verify them.

### Declarative Assembly Format

Expand Down
14 changes: 9 additions & 5 deletions mlir/docs/Traits.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ class MyTrait : public TraitBase<ConcreteType, MyTrait> {
};
```
Operation traits may also provide a `verifyTrait` hook, that is called when
verifying the concrete operation. The trait verifiers will currently always be
invoked before the main `Op::verify`.
Operation traits may also provide a `verifyTrait` or `verifyRegionTrait` hook
that is called when verifying the concrete operation. The difference between
these two is that whether the verifier needs to access the regions, if so, the
operations in the regions will be verified before the verification of this
trait. The [verification order](OpDefinitions.md/#verification-ordering)
determines when a verifier will be invoked.
```c++
template <typename ConcreteType>
Expand All @@ -53,8 +56,9 @@ public:
```

Note: It is generally good practice to define the implementation of the
`verifyTrait` hook out-of-line as a free function when possible to avoid
instantiating the implementation for every concrete operation type.
`verifyTrait` or `verifyRegionTrait` hook out-of-line as a free function when
possible to avoid instantiating the implementation for every concrete operation
type.

Operation traits may also provide a `foldTrait` hook that is called when folding
the concrete operation. The trait folders will only be invoked if the concrete
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ bool isTopLevelValue(Value value);
class AffineDmaStartOp
: public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
OpTrait::VariadicOperands, OpTrait::ZeroResult,
AffineMapAccessInterface::Trait> {
OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
Expand Down Expand Up @@ -227,7 +227,8 @@ class AffineDmaStartOp
static StringRef getOperationName() { return "affine.dma_start"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verifyInvariants();
LogicalResult verifyInvariantsImpl();
LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);

Expand Down Expand Up @@ -268,7 +269,7 @@ class AffineDmaStartOp
class AffineDmaWaitOp
: public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
OpTrait::VariadicOperands, OpTrait::ZeroResult,
AffineMapAccessInterface::Trait> {
OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
Expand Down Expand Up @@ -315,7 +316,8 @@ class AffineDmaWaitOp
static StringRef getTagMapAttrName() { return "tag_map"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verifyInvariants();
LogicalResult verifyInvariantsImpl();
LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
};
Expand Down
48 changes: 37 additions & 11 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,10 @@ class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
// OpTrait definitions
//===----------------------------------------------------------------------===//

// A trait that describes the structure of operation will be marked with
// `StructuralOpTrait` and they will be verified first.
class StructuralOpTrait;

// These classes are used to define operation specific traits.
class NativeOpTrait<string name, list<Trait> traits = []>
: NativeTrait<name, "Op"> {
Expand Down Expand Up @@ -2053,7 +2057,8 @@ class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
def AutomaticAllocationScope :
NativeOpTrait<"AutomaticAllocationScope">;
// Op supports operand broadcast behavior.
def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
Expand All @@ -2074,9 +2079,11 @@ def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
// Op has same shape for all operands.
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
def SameOperandsAndResultShape :
NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
def SameOperandsElementType :
NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
Expand Down Expand Up @@ -2104,21 +2111,23 @@ def ElementwiseMappable : TraitList<[
]>;

// Op's regions have a single block.
def SingleBlock : NativeOpTrait<"SingleBlock">;
def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;

// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>
: ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
: ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>,
StructuralOpTrait;

// Op's regions don't have terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">;
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;

// Op's parent operation is the provided one.
class HasParent<string op>
: ParamNativeOpTrait<"HasParent", op>;
: ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;

class ParentOneOf<list<string> ops>
: ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>;
: ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
StructuralOpTrait;

// Op result type is derived from the first attribute. If the attribute is an
// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
Expand Down Expand Up @@ -2147,13 +2156,15 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
// vector that has the same number of elements as the number of ODS declared
// operands. That means even if some operands are non-variadic, the attribute
// still need to have an element for its size, which is always 1.
def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">;
def AttrSizedOperandSegments :
NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
// Similar to AttrSizedOperandSegments, but used for results. The attribute
// should be named as `result_segment_sizes`.
def AttrSizedResultSegments : NativeOpTrait<"AttrSizedResultSegments">;
def AttrSizedResultSegments :
NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;

// Op attached regions have no arguments
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">;
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;

//===----------------------------------------------------------------------===//
// OpInterface definitions
Expand Down Expand Up @@ -2191,6 +2202,11 @@ class OpInterfaceTrait<string name, code verifyBody = [{}],
// the operation being verified.
code verify = verifyBody;

// A bit indicating if the verifier needs to access the ops in the regions. If
// it set to `1`, the region ops will be verified before invoking this
// verifier.
bit verifyWithRegions = 0;

// Specify the list of traits that need to be verified before the verification
// of this OpInterfaceTrait.
list<Trait> dependentTraits = traits;
Expand Down Expand Up @@ -2467,6 +2483,16 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
// operation class. The operation should implement this method and verify the
// additional necessary invariants.
bit hasVerifier = 0;

// A bit indicating if the operation has additional invariants that need to
// verified and which associate with regions (aside from those verified by the
// traits). If set to `1`, an additional `LogicalResult verifyRegions()`
// declaration will be generated on the operation class. The operation should
// implement this method and verify the additional necessary invariants
// associated with regions. Note that this method is invoked after all the
// region ops are verified.
bit hasRegionVerifier = 0;

// A custom code block corresponding to the extra verification code of the
// operation.
// NOTE: This field is deprecated in favor of `hasVerifier` and is slated for
Expand Down
63 changes: 60 additions & 3 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ class OpState {
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
/// back to this one which accepts everything.
LogicalResult verifyInvariants() { return success(); }
LogicalResult verify() { return success(); }
LogicalResult verifyRegions() { return success(); }

/// Parse the custom form of an operation. Unless overridden, this method will
/// first try to get an operation parser from the op's dialect. Otherwise the
Expand Down Expand Up @@ -376,6 +377,18 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
};
} // namespace detail

/// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc.
/// It should be run after core traits and before any other user defined traits.
/// In order to run it in the correct order, wrap it with OpInvariants trait so
/// that tblgen will be able to put it in the right order.
template <typename ConcreteType>
class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
public:
static LogicalResult verifyTrait(Operation *op) {
return cast<ConcreteType>(op).verifyInvariantsImpl();
}
};

/// This class provides the API for ops that are known to have no
/// SSA operand.
template <typename ConcreteType>
Expand Down Expand Up @@ -1572,6 +1585,14 @@ using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;

/// Trait to check if T provides a `verifyTrait` method.
template <typename T, typename... Args>
using has_verify_region_trait =
decltype(T::verifyRegionTrait(std::declval<Operation *>()));
template <typename T>
using detect_has_verify_region_trait =
llvm::is_detected<has_verify_region_trait, T>;

/// The internal implementation of `verifyTraits` below that returns the result
/// of verifying the current operation with all of the provided trait types
/// `Ts`.
Expand All @@ -1589,6 +1610,26 @@ template <typename TraitTupleT>
static LogicalResult verifyTraits(Operation *op) {
return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
}

/// The internal implementation of `verifyRegionTraits` below that returns the
/// result of verifying the current operation with all of the provided trait
/// types `Ts`.
template <typename... Ts>
static LogicalResult verifyRegionTraitsImpl(Operation *op,
std::tuple<Ts...> *) {
LogicalResult result = success();
(void)std::initializer_list<int>{
(result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
0)...};
return result;
}

/// Given a tuple type containing a set of traits that contain a
/// `verifyTrait` method, return the result of verifying the given operation.
template <typename TraitTupleT>
static LogicalResult verifyRegionTraits(Operation *op) {
return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
}
} // namespace op_definition_impl

//===----------------------------------------------------------------------===//
Expand All @@ -1603,7 +1644,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
public:
/// Inherit getOperation from `OpState`.
using OpState::getOperation;
using OpState::verifyInvariants;
using OpState::verify;
using OpState::verifyRegions;

/// Return if this operation contains the provided trait.
template <template <typename T> class Trait>
Expand Down Expand Up @@ -1704,6 +1746,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
using VerifiableTraitsTupleT =
typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
Traits<ConcreteType>...>::type;
/// A tuple type containing the region traits that have a verify function.
using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
op_definition_impl::detect_has_verify_region_trait,
Traits<ConcreteType>...>::type;

/// Returns an interface map containing the interfaces registered to this
/// operation.
Expand Down Expand Up @@ -1839,11 +1885,22 @@ class Op : public OpState, public Traits<ConcreteType>... {
"Op class shouldn't define new data members");
return failure(
failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
failed(cast<ConcreteType>(op).verifyInvariants()));
failed(cast<ConcreteType>(op).verify()));
}
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
}
/// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
static LogicalResult verifyRegionInvariants(Operation *op) {
static_assert(hasNoDataMembers(),
"Op class shouldn't define new data members");
return failure(failed(op_definition_impl::verifyRegionTraits<
VerifiableRegionTraitsTupleT>(op)) ||
failed(cast<ConcreteType>(op).verifyRegions()));
}
static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
}

static constexpr bool hasNoDataMembers() {
// Checking that the derived class does not define any member by comparing
Expand Down
20 changes: 14 additions & 6 deletions mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class OperationName {
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
using VerifyRegionInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;

protected:
/// This class represents a type erased version of an operation. It contains
Expand Down Expand Up @@ -112,6 +114,7 @@ class OperationName {
ParseAssemblyFn parseAssemblyFn;
PrintAssemblyFn printAssemblyFn;
VerifyInvariantsFn verifyInvariantsFn;
VerifyRegionInvariantsFn verifyRegionInvariantsFn;

/// A list of attribute names registered to this operation in StringAttr
/// form. This allows for operation classes to use StringAttr for attribute
Expand Down Expand Up @@ -238,16 +241,18 @@ class RegisteredOperationName : public OperationName {
static void insert(Dialect &dialect) {
insert(T::getOperationName(), dialect, TypeID::get<T>(),
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
T::getVerifyInvariantsFn(), T::getFoldHookFn(),
T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
T::getHasTraitFn(), T::getAttributeNames());
T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
static void
insert(StringRef name, Dialect &dialect, TypeID typeID,
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
VerifyInvariantsFn &&verifyInvariants,
VerifyRegionInvariantsFn &&verifyRegionInvariants,
FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
ArrayRef<StringRef> attrNames);
Expand All @@ -272,12 +277,15 @@ class RegisteredOperationName : public OperationName {
return impl->printAssemblyFn(op, p, defaultDialect);
}

/// This hook implements the verifier for this operation. It should emits an
/// error message and returns failure if a problem is detected, or returns
/// These hooks implement the verifiers for this operation. It should emits
/// an error message and returns failure if a problem is detected, or returns
/// success if everything is ok.
LogicalResult verifyInvariants(Operation *op) const {
return impl->verifyInvariantsFn(op);
}
LogicalResult verifyRegionInvariants(Operation *op) const {
return impl->verifyRegionInvariantsFn(op);
}

/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
Expand Down
Loading

0 comments on commit 9445b39

Please sign in to comment.