Skip to content

Commit

Permalink
[mlir][ODS] Add support for optional operands and results with a new …
Browse files Browse the repository at this point in the history
…Optional directive.

Summary: This revision adds support for specifying operands or results as "optional". This is a special case of variadic where the number of elements is either 0 or 1. Operands and results of this kind will have accessors generated using Value instead of the range types, making it more natural to interface with.

Differential Revision: https://reviews.llvm.org/D77863
  • Loading branch information
River707 committed Apr 10, 2020
1 parent 2a922da commit aba1acc
Show file tree
Hide file tree
Showing 21 changed files with 379 additions and 121 deletions.
29 changes: 23 additions & 6 deletions mlir/docs/OpDefinitions.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,28 @@ To declare a variadic operand, wrap the `TypeConstraint` for the operand with

Normally operations have no variadic operands or just one variadic operand. For
the latter case, it is easy to deduce which dynamic operands are for the static
variadic operand definition. But if an operation has more than one variadic
operands, it would be impossible to attribute dynamic operands to the
corresponding static variadic operand definitions without further information
from the operation. Therefore, the `SameVariadicOperandSize` trait is needed to
indicate that all variadic operands have the same number of dynamic values.
variadic operand definition. Though, if an operation has more than one variable
length operands (either optional or variadic), it would be impossible to
attribute dynamic operands to the corresponding static variadic operand
definitions without further information from the operation. Therefore, either
the `SameVariadicOperandSize` or `AttrSizedOperandSegments` trait is needed to
indicate that all variable length operands have the same number of dynamic
values.

#### Optional operands

To declare an optional operand, wrap the `TypeConstraint` for the operand with
`Optional<...>`.

Normally operations have no optional operands or just one optional operand. For
the latter case, it is easy to deduce which dynamic operands are for the static
operand definition. Though, if an operation has more than one variable length
operands (either optional or variadic), it would be impossible to attribute
dynamic operands to the corresponding static variadic operand definitions
without further information from the operation. Therefore, either the
`SameVariadicOperandSize` or `AttrSizedOperandSegments` trait is needed to
indicate that all variable length operands have the same number of dynamic
values.

#### Optional attributes

Expand Down Expand Up @@ -693,7 +710,7 @@ information. An optional group is defined by wrapping a set of elements within
the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic, i.e. optional, operand arguments can be used.
- Only variadic or optional operand arguments can be used.
- The operands to a type directive must be defined within the optional
group.

Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,17 @@ class DialectType<Dialect d, Pred condition, string descr = ""> :
}

// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results. An op can declare no
// more than one variadic operand/result, and that operand/result must be the
// last one in the operand/result list.
// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
Type baseType = type;
}

// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.description> {
Type baseType = type;
}

// A type that can be constructed using MLIR::Builder.
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,9 @@ class OpAsmParser {
/// Parse a type.
virtual ParseResult parseType(Type &result) = 0;

/// Parse an optional type.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;

/// Parse a type of a specific type.
template <typename TypeT>
ParseResult parseType(TypeT &result) {
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/TableGen/Argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,14 @@ struct NamedAttribute {
struct NamedTypeConstraint {
// Returns true if this operand/result has constraint to be satisfied.
bool hasPredicate() const;
// Returns true if this is an optional type constraint. This is a special case
// of variadic for 0 or 1 type.
bool isOptional() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }

llvm::StringRef name;
TypeConstraint constraint;
Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/TableGen/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class Operator {
using value_iterator = NamedTypeConstraint *;
using value_range = llvm::iterator_range<value_iterator>;

// Returns true if this op has variadic operands or results.
// Returns true if this op has variable length operands or results.
bool isVariadic() const;

// Returns true if default builders should not be generated.
Expand All @@ -115,8 +115,8 @@ class Operator {
// Returns the `index`-th result's decorators.
var_decorator_range getResultDecorators(int index) const;

// Returns the number of variadic results in this operation.
unsigned getNumVariadicResults() const;
// Returns the number of variable length results in this operation.
unsigned getNumVariableLengthResults() const;

// Op attribute iterators.
using attribute_iterator = const NamedAttribute *;
Expand All @@ -142,7 +142,7 @@ class Operator {
}

// Returns the number of variadic operands in this operation.
unsigned getNumVariadicOperands() const;
unsigned getNumVariableLengthOperands() const;

// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/TableGen/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ class TypeConstraint : public Constraint {

static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }

// Returns true if this is an optional type constraint.
bool isOptional() const;

// Returns true if this is a variadic type constraint.
bool isVariadic() const;

// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }

// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> getBuilderCall() const;
Expand Down
33 changes: 33 additions & 0 deletions mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ class Parser {
ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);

/// Optionally parse a type.
OptionalParseResult parseOptionalType(Type &type);

/// Parse an arbitrary type.
Type parseType();

Expand Down Expand Up @@ -899,6 +902,31 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
// Type Parsing
//===----------------------------------------------------------------------===//

/// Optionally parse a type.
OptionalParseResult Parser::parseOptionalType(Type &type) {
// There are many different starting tokens for a type, check them here.
switch (getToken().getKind()) {
case Token::l_paren:
case Token::kw_memref:
case Token::kw_tensor:
case Token::kw_complex:
case Token::kw_tuple:
case Token::kw_vector:
case Token::inttype:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
case Token::kw_f64:
case Token::kw_index:
case Token::kw_none:
case Token::exclamation_identifier:
return failure(!(type = parseType()));

default:
return llvm::None;
}
}

/// Parse an arbitrary type.
///
/// type ::= function-type
Expand Down Expand Up @@ -4509,6 +4537,11 @@ class CustomOpAsmParser : public OpAsmParser {
return failure(!(result = parser.parseType()));
}

/// Parse an optional type.
OptionalParseResult parseOptionalType(Type &result) override {
return parser.parseOptionalType(result);
}

/// Parse an arrow followed by a type list.
ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
if (parseArrow() || parser.parseFunctionResultTypes(result))
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/TableGen/Argument.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ bool tblgen::NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}

bool tblgen::NamedTypeConstraint::isOptional() const {
return constraint.isOptional();
}

bool tblgen::NamedTypeConstraint::isVariadic() const {
return constraint.isVariadic();
}
20 changes: 8 additions & 12 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ StringRef tblgen::Operator::getExtraClassDeclaration() const {

const llvm::Record &tblgen::Operator::getDef() const { return def; }

bool tblgen::Operator::isVariadic() const {
return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
}

bool tblgen::Operator::skipDefaultBuilders() const {
return def.getValueAsBit("skipDefaultBuilders");
}
Expand Down Expand Up @@ -119,16 +115,16 @@ auto tblgen::Operator::getResultDecorators(int index) const
return *result->getValueAsListInit("decorators");
}

unsigned tblgen::Operator::getNumVariadicResults() const {
return std::count_if(
results.begin(), results.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
unsigned tblgen::Operator::getNumVariableLengthResults() const {
return llvm::count_if(results, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength();
});
}

unsigned tblgen::Operator::getNumVariadicOperands() const {
return std::count_if(
operands.begin(), operands.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
unsigned tblgen::Operator::getNumVariableLengthOperands() const {
return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
return c.constraint.isVariableLength();
});
}

tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/TableGen/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariadic()) {
if (operand->isVariableLength()) {
auto repl = formatv(fmt, name);
LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
return std::string(repl);
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/TableGen/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ TypeConstraint::TypeConstraint(const llvm::Record *record)
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
: TypeConstraint(init->getDef()) {}

bool TypeConstraint::isOptional() const {
return def->isSubClassOf("Optional");
}

bool TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic");
}
Expand All @@ -34,7 +38,7 @@ bool TypeConstraint::isVariadic() const {
// returns None otherwise.
Optional<StringRef> TypeConstraint::getBuilderCall() const {
const llvm::Record *baseType = def;
if (isVariadic())
if (isVariableLength())
baseType = baseType->getValueAsDef("baseType");

// Check to see if this type constraint has a builder call.
Expand Down
41 changes: 31 additions & 10 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1179,39 +1179,41 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
}

// Test various mixings of result type formatting.
class FormatResultBase<string name, string fmt> : TEST_Op<name> {
class FormatResultBase<string suffix, string fmt>
: TEST_Op<"format_result_" # suffix # "_op"> {
let results = (outs I64:$buildable_res, AnyMemRef:$result);
let assemblyFormat = fmt;
}
def FormatResultAOp : FormatResultBase<"format_result_a_op", [{
def FormatResultAOp : FormatResultBase<"a", [{
type($result) attr-dict
}]>;
def FormatResultBOp : FormatResultBase<"format_result_b_op", [{
def FormatResultBOp : FormatResultBase<"b", [{
type(results) attr-dict
}]>;
def FormatResultCOp : FormatResultBase<"format_result_c_op", [{
def FormatResultCOp : FormatResultBase<"c", [{
functional-type($buildable_res, $result) attr-dict
}]>;

// Test various mixings of operand type formatting.
class FormatOperandBase<string name, string fmt> : TEST_Op<name> {
class FormatOperandBase<string suffix, string fmt>
: TEST_Op<"format_operand_" # suffix # "_op"> {
let arguments = (ins I64:$buildable, AnyMemRef:$operand);
let assemblyFormat = fmt;
}

def FormatOperandAOp : FormatOperandBase<"format_operand_a_op", [{
def FormatOperandAOp : FormatOperandBase<"a", [{
operands `:` type(operands) attr-dict
}]>;
def FormatOperandBOp : FormatOperandBase<"format_operand_b_op", [{
def FormatOperandBOp : FormatOperandBase<"b", [{
operands `:` type($operand) attr-dict
}]>;
def FormatOperandCOp : FormatOperandBase<"format_operand_c_op", [{
def FormatOperandCOp : FormatOperandBase<"c", [{
$buildable `,` $operand `:` type(operands) attr-dict
}]>;
def FormatOperandDOp : FormatOperandBase<"format_operand_d_op", [{
def FormatOperandDOp : FormatOperandBase<"d", [{
$buildable `,` $operand `:` type($operand) attr-dict
}]>;
def FormatOperandEOp : FormatOperandBase<"format_operand_e_op", [{
def FormatOperandEOp : FormatOperandBase<"e", [{
$buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
}]>;

Expand All @@ -1220,6 +1222,25 @@ def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
let assemblyFormat = "$targets attr-dict";
}

// Test various mixings of optional operand and result type formatting.
class FormatOptionalOperandResultOpBase<string suffix, string fmt>
: TEST_Op<"format_optional_operand_result_" # suffix # "_op",
[AttrSizedOperandSegments]> {
let arguments = (ins Optional<I64>:$optional, Variadic<I64>:$variadic);
let results = (outs Optional<I64>:$optional_res);
let assemblyFormat = fmt;
}

def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{
`(` $optional `:` type($optional) `)` `:` type($optional_res)
(`[` $variadic^ `]`)? attr-dict
}]>;

def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
(`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res)
(`[` $variadic^ `]`)? attr-dict
}]>;

//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/mlir-tblgen/op-decl.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ def NS_DOp : NS_Op<"op_with_two_operands", []> {
// CHECK-LABEL: NS::DOp declarations
// CHECK: OpTrait::NOperands<2>::Impl

def NS_EOp : NS_Op<"op_with_optionals", []> {
let arguments = (ins Optional<I32>:$a);
let results = (outs Optional<F32>:$b);
}

// CHECK-LABEL: NS::EOp declarations
// CHECK: Value a();
// CHECK: Value b();
// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)

// Check that default builders can be suppressed.
// ---

Expand Down
13 changes: 12 additions & 1 deletion mlir/test/mlir-tblgen/op-format-spec.td
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{
def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
($attr^) attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
// CHECK: error: only variadic operands can be used within an optional group
// CHECK: error: only variable length operands can be used within an optional group
def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
($arg^) attr-dict
}]>, Arguments<(ins I64:$arg)>;
Expand Down Expand Up @@ -327,6 +327,17 @@ def ZCoverageInvalidF : TestFormat_Op<"variable_invalid_f", [{
}]> {
let successors = (successor AnySuccessor:$successor);
}
// CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred
// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format
def ZCoverageInvalidG : TestFormat_Op<"variable_invalid_g", [{
operands attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
// CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred
// CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format
def ZCoverageInvalidH : TestFormat_Op<"variable_invalid_h", [{
attr-dict
}]>, Results<(outs Optional<I64>:$result)>;

// CHECK-NOT: error
def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{
$operand type($operand) type($result) attr-dict
Expand Down
Loading

0 comments on commit aba1acc

Please sign in to comment.