/
InferTypeOpInterface.td
134 lines (120 loc) · 5.29 KB
/
InferTypeOpInterface.td
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// related to type inference.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INFERTYPEOPINTERFACE
#define MLIR_INFERTYPEOPINTERFACE
include "mlir/IR/OpBase.td"
// OpInterface to compute the return types of an operation. The arguments match
// those of Operation::create, except that the location is optional: if no
// location is provided, the method will not emit an error on mismatch.
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
  let description = [{
    Interface to infer the return types for an operation that could be used
    during op construction, verification or type inference.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the return types that an op would generate.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"inferReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
    >,
    StaticInterfaceMethod<
      /*desc=*/"Returns whether two arrays of types are compatible result types"
               " for an op.",
      /*retTy=*/"bool",
      /*methodName=*/"isCompatibleReturnTypes",
      /*args=*/(ins "::mlir::TypeRange":$lhs, "::mlir::TypeRange":$rhs),
      /*methodBody=*/[{
        return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
      }],
      /*defaultImplementation=*/[{
        /// By default, require the two arrays to be equal, which is the
        /// strictest possible compatibility check.
        return lhs == rhs;
      }]
    >,
  ];

  // Verification is delegated to detail::verifyInferredResultTypes (defined
  // outside this file); presumably it compares the op's actual result types
  // against those produced by inferReturnTypes — confirm against its source.
  let verify = [{
    return detail::verifyInferredResultTypes($_op);
  }];
}
// OpInterface to infer the components (element type, shape, attribute) of the
// ShapedType results of an operation, and optionally to reify the shape
// computation as IR.
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
  let description = [{
    Interface to infer the components of a ShapedType returned by an operation
    that could be used during op construction, verification or shape inference.

    The components consist of element type, shape and raw attribute.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the components of the return type of a shape container.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op.

      Unknown (e.g., unranked) shape and nullptrs for element type and attribute
      may be returned by this function while returning success. In other words,
      partial population of components is not an error condition.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"inferReturnTypeComponents",
      // Type spellings match InferTypeOpInterface::inferReturnTypes; the
      // ::llvm:: templates are the same ones MLIR re-exports under ::mlir::.
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::llvm::Optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
                      $inferredReturnShapes)
    >,
    InterfaceMethod<
      /*desc=*/[{Reify the shape computation for the operation.

      Insert operations using the given OpBuilder that compute the result
      shape.
      }],
      /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"reifyReturnTypeShapes",
      /*args=*/(ins "::mlir::OpBuilder&":$builder,
                    "::llvm::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes),
      /*methodBody=*/[{}],
      // Ops that do not provide their own implementation report failure,
      // i.e. shape reification is opt-in.
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >,
  ];
}
// Bundles the type-inference traits needed by an op whose results are
// tensors, so an op definition can pull them in as a single list.
//
// `overridenMethods` is forwarded unchanged as the method list of
// DeclareOpInterfaceMethods for InferShapedTypeOpInterface (presumably the
// interface methods with a default implementation that the op wants to
// declare and define itself — confirm against DeclareOpInterfaceMethods).
class InferTensorType<list<string> overridenMethods = []> {
  list<OpTrait> traits = [
    // The op implements InferTypeOpInterface.
    InferTypeOpInterface,

    // The op declares methods implementing InferShapedTypeOpInterface, the
    // ShapedType component inference interface.
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,

    // Native trait: the op's result type is inferred from the ShapedType
    // components combined with the knowledge that the results are tensors.
    NativeOpTrait<"InferTensorType">
  ];
}

// Variant of InferTensorType that also passes "reifyReturnTypeShapes" to
// DeclareOpInterfaceMethods, for ops that provide their own shape reification.
defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
#endif // MLIR_INFERTYPEOPINTERFACE