1,070 changes: 1,070 additions & 0 deletions mlir/docs/AttributesAndTypes.md

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions mlir/docs/LangRef.md
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,7 @@ the lighter syntax: `!foo.something<a%%123^^^>>>` because it contains characters
that are not allowed in the lighter syntax, as well as unbalanced `<>`
characters.

See [here](Tutorials/DefiningAttributesAndTypes.md) to learn how to define
dialect types.
See [here](AttributesAndTypes.md) to learn how to define dialect types.

### Builtin Types

Expand Down Expand Up @@ -840,8 +839,7 @@ valid in the lighter syntax: `#foo.something<a%%123^^^>>>` because it contains
characters that are not allowed in the lighter syntax, as well as unbalanced
`<>` characters.

See [here](Tutorials/DefiningAttributesAndTypes.md) on how to define dialect
attribute values.
See [here](AttributesAndTypes.md) on how to define dialect attribute values.

### Builtin Attribute Values

Expand Down
338 changes: 0 additions & 338 deletions mlir/docs/OpDefinitions.md
Original file line number Diff line number Diff line change
Expand Up @@ -1494,344 +1494,6 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
}
```

## Type Definitions

MLIR defines the `TypeDef` class hierarchy to enable generation of data types from
their specifications. A type is defined by specializing the `TypeDef` class with
concrete contents for all the fields it requires. For example, an integer type
could be defined as:

```tablegen
// All of the types will extend this class.
class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
// An alternate int type.
def IntegerType : Test_Type<"TestInteger"> {
let mnemonic = "int";
let summary = "An integer type with special semantics";
let description = [{
An alternate integer type. This type differentiates itself from the
standard integer type by not having a SignednessSemantics parameter, just
a width.
}];
let parameters = (ins "unsigned":$width);
// We define the printer inline.
let printer = [{
$_printer << "int<" << getImpl()->width << ">";
}];
// The parser is defined here also.
let parser = [{
if ($_parser.parseLess())
return Type();
int width;
if ($_parser.parseInteger(width))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, width);
}];
}
```

### Type name

The name of the C++ class which gets generated defaults to
`<classParamName>Type` (e.g. `TestIntegerType` in the above example). This can
be overridden via the `cppClassName` field. The field `mnemonic` is to specify
the asm name for parsing. It is optional and not specifying it will imply that
no parser or printer methods are attached to this class.

### Type documentation

The `summary` and `description` fields exist and are to be used the same way as
in Operations. Namely, the summary should be a one-liner and `description`
should be a longer explanation.

### Type parameters

The `parameters` field is a list of the type's parameters. If no parameters are
specified (the default), this type is considered a singleton type. Parameters
are in the `"c++Type":$paramName` format. To use C++ types as parameters which
need allocation in the storage constructor, there are two options:

- Set `hasCustomStorageConstructor` to generate the TypeStorage class with a
constructor which is just declared -- no definition -- so you can write it
yourself.
- Use the `TypeParameter` tablegen class instead of the "c++Type" string.

### TypeParameter tablegen class

This is used to further specify attributes about each of the types parameters.
It includes documentation (`summary` and `syntax`), the C++ type to use, a
custom allocator to use in the storage constructor method, and a custom
comparator to decide if two instances of the parameter type are equal.

```tablegen
// DO NOT DO THIS!
let parameters = (ins "ArrayRef<int>":$dims);
```

The default storage constructor blindly copies fields by value. It does not know
anything about the types. In this case, the ArrayRef<int> requires allocation
with `dims = allocator.copyInto(dims)`.

You can specify the necessary constructor by specializing the `TypeParameter`
tblgen class:

```tablegen
class ArrayRefIntParam :
TypeParameter<"::llvm::ArrayRef<int>", "Array of ints"> {
let allocator = "$_dst = $_allocator.copyInto($_self);";
}
...
let parameters = (ins ArrayRefIntParam:$dims);
```

The `allocator` code block has the following substitutions:

- `$_allocator` is the TypeStorageAllocator in which to allocate objects.
- `$_dst` is the variable in which to place the allocated data.

The `comparator` code block has the following substitutions:

- `$_lhs` is an instance of the parameter type.
- `$_rhs` is an instance of the parameter type.

MLIR includes several specialized classes for common situations:

- `StringRefParameter<descriptionOfParam>` for StringRefs.
- `ArrayRefParameter<arrayOf, descriptionOfParam>` for ArrayRefs of value
types
- `SelfAllocationParameter<descriptionOfParam>` for C++ classes which contain
a method called `allocateInto(StorageAllocator &allocator)` to allocate
itself into `allocator`.
- `ArrayRefOfSelfAllocationParameter<arrayOf, descriptionOfParam>` for arrays
of objects which self-allocate as per the last specialization.

If we were to use one of these included specializations:

```tablegen
let parameters = (ins
ArrayRefParameter<"int", "The dimensions">:$dims
);
```

### Parsing and printing

If a mnemonic is specified, the `printer` and `parser` code fields are active.
The rules for both are:

- If null, generate just the declaration.
- If non-null and non-empty, use the code in the definition. The `$_printer`
or `$_parser` substitutions are valid and should be used.
- It is an error to have an empty code block.

For each dialect, two "dispatch" functions will be created: one for parsing and
one for printing. You should add calls to these in your `Dialect::printType` and
`Dialect::parseType` methods. They are static functions placed alongside the
type class definitions and have the following function signatures:

```c++
static Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser, StringRef mnemonic);
LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
```
The mnemonic, parser, and printer fields are optional. If they're not defined,
the generated code will not include any parsing or printing code and omit the
type from the dispatch functions above. In this case, the dialect author is
responsible for parsing/printing the types in `Dialect::printType` and
`Dialect::parseType`.
### Other fields
- If the `genStorageClass` field is set to 1 (the default) a storage class is
generated with member variables corresponding to each of the specified
`parameters`.
- If the `genAccessors` field is 1 (the default) accessor methods will be
generated on the Type class (e.g. `int getWidth() const` in the example
above).
- If the `genVerifyDecl` field is set, a declaration for a method `static
LogicalResult verify(emitErrorFn, parameters...)` is added to the class as
well as a `getChecked(emitErrorFn, parameters...)` method which checks the
result of `verify` before calling `get`.
- The `storageClass` field can be used to set the name of the storage class.
- The `storageNamespace` field is used to set the namespace where the storage
class should sit. Defaults to "detail".
- The `extraClassDeclaration` field is used to include extra code in the class
declaration.
### Type builder methods
For each type, there are a few builders(`get`/`getChecked`) automatically
generated based on the parameters of the type. For example, given the following
type definition:
```tablegen
def MyType : ... {
let parameters = (ins "int":$intParam);
}
```

The following builders are generated:

```c++
// Type builders are named `get`, and return a new instance of a type for a
// given set of parameters.
static MyType get(MLIRContext *context, int intParam);

// If `genVerifyDecl` is set to 1, the following method is also generated.
static MyType getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, int intParam);
```
If these autogenerated methods are not desired, such as when they conflict with
a custom builder method, a type can set `skipDefaultBuilders` to 1 to signal
that they should not be generated.
#### Custom type builder methods
The default build methods may cover a majority of the simple cases related to
type construction, but when they cannot satisfy a type's needs, you can define
additional convenience 'get' methods in the `builders` field as follows:
```tablegen
def MyType : ... {
let parameters = (ins "int":$intParam);
let builders = [
TypeBuilder<(ins "int":$intParam)>,
TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
// Write the body of the `get` builder inline here.
return Base::get($_ctxt, intParam);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
// This builder states that it can infer an MLIRContext instance from
// its arguments.
return Base::get(typeParam.getContext(), ...);
}]>,
];
}
```

The `builders` field is a list of custom builders that are added to the type
class. In this example, we provide several different convenience builders that
are useful in different scenarios. The `ins` prefix is common to many function
declarations in ODS, which use a TableGen [`dag`](#tablegen-syntax). What
follows is a comma-separated list of types (quoted string or `CArg`) and names
prefixed with the `$` sign. The use of `CArg` allows for providing a default
value to that argument. Let's take a look at each of these builders individually

The first builder will generate the declaration of a builder method that looks
like:

```tablegen
let builders = [
TypeBuilder<(ins "int":$intParam)>,
];
```

```c++
class MyType : /*...*/ {
/*...*/
static MyType get(::mlir::MLIRContext *context, int intParam);
};
```
This builder is identical to the one that will be automatically generated for
`MyType`. The `context` parameter is implicitly added by the generator, and is
used when building the Type instance (with `Base::get`). The distinction
here is that we can provide the implementation of this `get` method. With this
style of builder definition only the declaration is generated, the implementor
of `MyType` will need to provide a definition of `MyType::get`.
The second builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilder<(ins CArg<"int", "0">:$intParam)>,
];
```

```c++
class MyType : /*...*/ {
/*...*/
static MyType get(::mlir::MLIRContext *context, int intParam = 0);
};
```
The constraints here are identical to the first builder example except for the
fact that `intParam` now has a default value attached.
The third builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilder<(ins CArg<"int", "0">:$intParam), [{
// Write the body of the `get` builder inline here.
return Base::get($_ctxt, intParam);
}]>,
];
```

```c++
class MyType : /*...*/ {
/*...*/
static MyType get(::mlir::MLIRContext *context, int intParam = 0);
};

MyType MyType::get(::mlir::MLIRContext *context, int intParam) {
// Write the body of the `get` builder inline here.
return Base::get(context, intParam);
}
```
This is identical to the second builder example. The difference is that now, a
definition for the builder method will be generated automatically using the
provided code block as the body. When specifying the body inline, `$_ctxt` may
be used to access the `MLIRContext *` parameter.
The fourth builder will generate the declaration of a builder method that looks
like:
```tablegen
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$typeParam), [{
// This builder states that it can infer an MLIRContext instance from
// its arguments.
return Base::get(typeParam.getContext(), ...);
}]>,
];
```

```c++
class MyType : /*...*/ {
/*...*/
static MyType get(Type typeParam);
};

MyType MyType::get(Type typeParam) {
// This builder states that it can infer an MLIRContext instance from its
// arguments.
return Base::get(typeParam.getContext(), ...);
}
```
In this builder example, the main difference from the third builder example
there is that the `MLIRContext` parameter is no longer added. This is because
the type builder used `TypeBuilderWithInferredContext` implies that the context
parameter is not necessary as it can be inferred from the arguments to the
builder.
## Debugging Tips

### Run `mlir-tblgen` to see the generated content
Expand Down
694 changes: 0 additions & 694 deletions mlir/docs/Tutorials/DefiningAttributesAndTypes.md

This file was deleted.

1 change: 1 addition & 0 deletions mlir/examples/toy/Ch3/mlir/ToyCombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef TOY_COMBINE
#define TOY_COMBINE

include "mlir/IR/PatternBase.td"
include "toy/Ops.td"

/// Note: The DRR definition used for defining patterns is shown below:
Expand Down
1 change: 1 addition & 0 deletions mlir/examples/toy/Ch4/mlir/ToyCombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef TOY_COMBINE
#define TOY_COMBINE

include "mlir/IR/PatternBase.td"
include "toy/Ops.td"

/// Note: The DRR definition used for defining patterns is shown below:
Expand Down
1 change: 1 addition & 0 deletions mlir/examples/toy/Ch5/mlir/ToyCombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef TOY_COMBINE
#define TOY_COMBINE

include "mlir/IR/PatternBase.td"
include "toy/Ops.td"

/// Note: The DRR definition used for defining patterns is shown below:
Expand Down
1 change: 1 addition & 0 deletions mlir/examples/toy/Ch6/mlir/ToyCombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef TOY_COMBINE
#define TOY_COMBINE

include "mlir/IR/PatternBase.td"
include "toy/Ops.td"

/// Note: The DRR definition used for defining patterns is shown below:
Expand Down
1 change: 1 addition & 0 deletions mlir/examples/toy/Ch7/mlir/ToyCombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef TOY_COMBINE
#define TOY_COMBINE

include "mlir/IR/PatternBase.td"
include "toy/Ops.td"

/// Note: The DRR definition used for defining patterns is shown below:
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES
#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/Async/IR/AsyncDialect.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -46,6 +47,7 @@ def Async_ValueType : Async_Type<"Value", "value"> {
return $_get(valueType.getContext(), valueType);
}]>
];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES
#define MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/EmitC/IR/EmitCBase.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -40,6 +41,8 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
}];

let parameters = (ins StringRefParameter<"the opaque value">:$value);

let hasCustomAssemblyFormat = 1;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_EMITC_IR_EMITCTYPES
#define MLIR_DIALECT_EMITC_IR_EMITCTYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/EmitC/IR/EmitCBase.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -41,6 +42,7 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
}];

let parameters = (ins StringRefParameter<"the opaque value">:$value);
let hasCustomAssemblyFormat = 1;
}

def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#ifndef LLVMIR_ATTRDEFS
#define LLVMIR_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"


// All of the attributes will extend this class.
class LLVM_Attr<string name> : AttrDef<LLVM_Dialect, name>;

Expand All @@ -23,6 +23,7 @@ def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
let parameters = (ins
"FastmathFlags":$flags
);
let hasCustomAssemblyFormat = 1;
}

// Attribute definition for the LLVM Linkage enum.
Expand All @@ -31,6 +32,7 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
let parameters = (ins
"linkage::Linkage":$linkage
);
let hasCustomAssemblyFormat = 1;
}

def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
Expand Down Expand Up @@ -63,6 +65,7 @@ def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
AttrBuilder<(ins "ArrayRef<std::pair<LoopOptionCase, int64_t>>":$sortedOptions)>,
AttrBuilder<(ins "LoopOptionsAttrBuilder &":$optionBuilders)>
];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_PDL_IR_PDLTYPES
#define MLIR_DIALECT_PDL_IR_PDLTYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/PDL/IR/PDLDialect.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -66,6 +67,7 @@ def PDL_Range : PDL_Type<"Range", "range"> {
}]>,
];
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef SPARSETENSOR_ATTRDEFS
#define SPARSETENSOR_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/IR/TensorEncoding.td"

Expand Down Expand Up @@ -79,6 +80,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
);

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
// Dimension level types that define sparse tensors:
Expand Down
387 changes: 387 additions & 0 deletions mlir/include/mlir/IR/AttrTypeBase.td

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef BUILTIN_ATTRIBUTES
#define BUILTIN_ATTRIBUTES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/BuiltinLocationAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef BUILTIN_LOCATION_ATTRIBUTES_TD
#define BUILTIN_LOCATION_ATTRIBUTES_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"

// Base class for Builtin dialect location attributes.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef BUILTIN_TYPES
#define BUILTIN_TYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"
Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/IR/EnumAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//

#ifndef ENUM_ATTR
#define ENUM_ATTR
#ifndef ENUMATTR_TD
#define ENUMATTR_TD

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"

// A C++ enum as an attribute parameter. The parameter implements a parser and
// printer for the enum by dispatching calls to `stringToSymbol` and
Expand Down Expand Up @@ -96,4 +96,4 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
let assemblyFormat = "$value";
}

#endif // ENUM_ATTR
#endif // ENUMATTR_TD
563 changes: 0 additions & 563 deletions mlir/include/mlir/IR/OpBase.td

Large diffs are not rendered by default.

221 changes: 221 additions & 0 deletions mlir/include/mlir/IR/PatternBase.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
//===-- PatternBase.td - Base pattern definition file ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This files contains all of the base constructs for defining DRR patterns.
//
//===----------------------------------------------------------------------===//

#ifndef PATTERNBASE_TD
#define PATTERNBASE_TD

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Pattern definitions
//===----------------------------------------------------------------------===//

// Marker used to identify the delta value added to the default benefit value.
def addBenefit;

// Base class for op+ -> op+ rewrite rules. These allow declaratively
// specifying rewrite rules.
//
// A rewrite rule contains two components: a source pattern and one or more
// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
// in the form of `(node arg0, arg1, ...)`.
//
// The `node` are normally MLIR ops, but it can also be one of the directives
// listed later in this section.
//
// ## Symbol binding
//
// In the source pattern, `argN` can be used to specify matchers (e.g., using
// type/attribute type constraints, etc.) and bound to a name for later use.
// We can also bind names to op instances to reference them later in
// multi-entity constraints. Operands in the source pattern can have
// the same name. This bounds one operand to the name while verifying
// the rest are all equal.
//
//
// In the result pattern, `argN` can be used to refer to a previously bound
// name, with potential transformations (e.g., using tAttr, etc.). `argN` can
// itself be nested DAG node. We can also bound names to ops to reference
// them later in other result patterns.
//
// For example,
//
// ```
// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1, $arg0),
// [(OneResultOp2:$op2 $arg0, $arg1),
// (OneResultOp3 $op2 (OneResultOp4))],
// [(HasStaticShapePred $op1)]>;
// ```
//
// First `$arg0` and '$arg1' are bound to the `OneResultOp1`'s first
// and second arguments and used later to build `OneResultOp2`. Second '$arg0'
// is verified to be equal to the first '$arg0' operand.
// `$op1` is bound to `OneResultOp1` and used to check whether the result's
// shape is static. `$op2` is bound to `OneResultOp2` and used to
// build `OneResultOp3`.
//
// ## Multi-result op
//
// To create multi-result ops in result pattern, you can use a syntax similar
// to uni-result op, and it will act as a value pack for all results:
//
// ```
// def : Pattern<(ThreeResultOp ...),
// [(TwoResultOp ...), (OneResultOp ...)]>;
// ```
//
// Then `TwoResultOp` will replace the first two values of `ThreeResultOp`.
//
// You can also use `$<name>__N` to explicitly access the N-th result.
// ```
// def : Pattern<(FiveResultOp ...),
// [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0),
// (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>;
// ```
//
// Then the values generated by `FiveResultOp` will be replaced by
//
// * `FiveResultOp`#0: `TwoResultOp1`#1
// * `FiveResultOp`#1: `TwoResultOp1`#0
// * `FiveResultOp`#2: `TwoResultOp2`#0
// * `FiveResultOp`#3: `TwoResultOp2`#1
// * `FiveResultOp`#4: `TwoResultOp2`#1
class Pattern<dag source, list<dag> results, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
// Result patterns. Each result pattern is expected to replace one result
// of the root op in the source pattern. In the case of more result patterns
// than needed to replace the source op, only the last N results generated
// by the last N result pattern is used to replace a N-result source op.
// So that the beginning result patterns can be used to generate additional
// ops to aid building the results used for replacement.
list<dag> resultPatterns = results;
// Multi-entity constraints. Each constraint here involves multiple entities
// matched in source pattern and places further constraints on them as a
// whole.
list<dag> constraints = preds;
// The delta value added to the default benefit value. The default value is
// the number of ops in the source pattern. The rule with the highest final
// benefit value will be applied first if there are multiple rules matches.
// This delta value can be either positive or negative.
dag benefitDelta = benefitAdded;
}

// Form of a pattern which produces a single result.
class Pat<dag pattern, dag result, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> :
Pattern<pattern, [result], preds, benefitAdded>;

// Native code call wrapper. This allows invoking an arbitrary C++ expression
// to create an op operand/attribute or replace an op result.
//
// ## Placeholders
//
// If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`,
// the wrapped expression can take special placeholders listed below:
//
// * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
// * `$_self` will be replaced by the defining operation in a source pattern.
// E.g., `NativeCodeCall<"Foo($_self, &$0)> I32Attr:$attr)>`, `$_self` will be
// replaced with the defining operation of the first operand of OneArgOp.
//
// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
// then positional placeholders are also supported; placeholder `$N` in the
// wrapped C++ expression will be replaced by `<argN>`.
//
// ## Bind multiple results
//
// To bind multi-results and access the N-th result with `$<name>__N`, specify
// the number of return values in the template. Note that only `Value` type is
// supported for multiple results binding.

class NativeCodeCall<string expr, int returns = 1> {
string expression = expr;
int numReturns = returns;
}

class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>;

def ConstantLikeMatcher : NativeCodeCall<"::mlir::success("
"::mlir::matchPattern($_self->getResult(0), ::mlir::m_Constant(&$0)))">;

//===----------------------------------------------------------------------===//
// Rewrite directives
//===----------------------------------------------------------------------===//

// Directive used in result pattern to indicate that no new op are generated,
// so to replace the matched DAG with an existing SSA value.
def replaceWithValue;

// Directive used in result patterns to specify the location of the generated
// op. This directive must be used as a trailing argument to op creation or
// native code calls.
//
// Usage:
// * Create a named location: `(location "myLocation")`
// * Copy the location of a captured symbol: `(location $arg)`
// * Create a fused location: `(location "metadata", $arg0, $arg1)`

def location;

// Directive used in result patterns to specify return types for a created op.
// This allows ops to be created without relying on type inference with
// `OpTraits` or an op builder with deduction.
//
// This directive must be used as a trailing argument to op creation.
//
// Specify one return type with a string literal:
//
// ```
// (AnOp $val, (returnType "$_builder.getI32Type()"))
// ```
//
// Pass a captured value to copy its return type:
//
// ```
// (AnOp $val, (returnType $val));
// ```
//
// Pass a native code call inside a DAG to create a new type with arguments.
//
// ```
// (AnOp $val,
// (returnType (NativeCodeCall<"$_builder.getTupleType({$0})"> $val)));
// ```
//
// Specify multiple return types with multiple of any of the above.

def returnType;

// Directive used to specify the operands may be matched in either order. When
// two adjacents are marked with `either`, it'll try to match the operands in
// either ordering of constraints. Example:
//
// ```
// (TwoArgOp (either $firstArg, (AnOp $secondArg)))
// ```
// The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and
// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
//
// Only operand is supported with `either` and note that an operation with
// `Commutative` trait doesn't imply that it'll have the same behavior than
// `either` while pattern matching.
def either;

//===----------------------------------------------------------------------===//
// Common value constraints
//===----------------------------------------------------------------------===//

def HasNoUseOf: Constraint<
CPred<"$_self.use_empty()">, "has no use">;

#endif // PATTERNBASE_TD
23 changes: 3 additions & 20 deletions mlir/include/mlir/TableGen/AttrOrTypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,30 +175,13 @@ class AttrOrTypeDef {
/// supposed to auto-generate them.
Optional<StringRef> getMnemonic() const;

/// Returns the code to use as the types printer method. If not specified,
/// return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getPrinterCode() const;

/// Returns the code to use as the parser method. If not specified, returns
/// None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const;
/// Returns if the attribute or type has a custom assembly format implemented
/// in C++. Corresponds to the `hasCustomAssemblyFormat` field.
bool hasCustomAssemblyFormat() const;

/// Returns the custom assembly format, if one was specified.
Optional<StringRef> getAssemblyFormat() const;

/// An attribute or type with parameters needs a parser.
bool needsParserPrinter() const { return getNumParameters() != 0; }

/// Returns true if this attribute or type has a generated parser.
bool hasGeneratedParser() const {
return getParserCode() || getAssemblyFormat();
}

/// Returns true if this attribute or type has a generated printer.
bool hasGeneratedPrinter() const {
return getPrinterCode() || getAssemblyFormat();
}

/// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_CONVERSION_GPUTONVVM_TD
#define MLIR_CONVERSION_GPUTONVVM_TD

include "mlir/IR/PatternBase.td"
include "mlir/Dialect/GPU/GPUOps.td"
include "mlir/Dialect/LLVMIR/NVVMOps.td"

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/GPUToROCDL/GPUToROCDL.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_CONVERSION_GPUTOROCDL_TD
#define MLIR_CONVERSION_GPUTOROCDL_TD

include "mlir/IR/PatternBase.td"
include "mlir/Dialect/GPU/GPUOps.td"
include "mlir/Dialect/LLVMIR/ROCDLOps.td"

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_TD
#define MLIR_CONVERSION_SHAPETOSTANDARD_TD

include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Shape/IR/ShapeOps.td"

def BroadcastableStringAttr : NativeCodeCall<[{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef ARITHMETIC_PATTERNS
#define ARITHMETIC_PATTERNS

include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td"

// Add two integer attributes and create a new one with the result.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

include "mlir/IR/PatternBase.td"
include "mlir/Dialect/SPIRV/IR/SPIRVOps.td"

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"

Expand Down
32 changes: 26 additions & 6 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,30 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
parameters.push_back(AttrOrTypeParameter(parametersDag, i));
}

// Verify the use of the mnemonic field.
bool hasCppFormat = hasCustomAssemblyFormat();
bool hasDeclarativeFormat = getAssemblyFormat().hasValue();
if (getMnemonic()) {
if (hasCppFormat && hasDeclarativeFormat) {
PrintFatalError(getLoc(), "cannot specify both 'assemblyFormat' "
"and 'hasCustomAssemblyFormat'");
}
if (!parameters.empty() && !hasCppFormat && !hasDeclarativeFormat) {
PrintFatalError(getLoc(),
"must specify either 'assemblyFormat' or "
"'hasCustomAssemblyFormat' when 'mnemonic' is set");
}
} else if (hasCppFormat || hasDeclarativeFormat) {
PrintFatalError(getLoc(),
"'assemblyFormat' or 'hasCustomAssemblyFormat' can only be "
"used when 'mnemonic' is set");
}
// Assembly format requires accessors to be generated.
if (hasDeclarativeFormat && !genAccessors()) {
PrintFatalError(getLoc(),
"'assemblyFormat' requires 'genAccessors' to be true");
}
}

Dialect AttrOrTypeDef::getDialect() const {
Expand Down Expand Up @@ -122,12 +146,8 @@ Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
return def->getValueAsOptionalString("mnemonic");
}

Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
return def->getValueAsOptionalString("printer");
}

Optional<StringRef> AttrOrTypeDef::getParserCode() const {
return def->getValueAsOptionalString("parser");
bool AttrOrTypeDef::hasCustomAssemblyFormat() const {
return def->getValueAsBit("hasCustomAssemblyFormat");
}

Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
Expand Down
7 changes: 6 additions & 1 deletion mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

// To get the test dialect definition.
include "TestDialect.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"

Expand Down Expand Up @@ -42,6 +43,7 @@ def CompoundAttrA : Test_Attr<"CompoundA"> {
"An example of an array of ints" // Parameter description.
>: $arrayOfInts
);
let hasCustomAssemblyFormat = 1;
}
def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
let mnemonic = "cmpnd_nested";
Expand All @@ -53,21 +55,22 @@ def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
let mnemonic = "attr_with_self_type_param";
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
let hasCustomAssemblyFormat = 1;
}

// An attribute testing AttributeSelfTypeParameter.
def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
let mnemonic = "attr_with_type_builder";
let parameters = (ins "::mlir::IntegerAttr":$attr);
let typeBuilder = "$_attr.getType()";
let hasCustomAssemblyFormat = 1;
}

def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;

// The definition of a singleton attribute that has a trait.
def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
let mnemonic = "attr_with_trait";
let parameters = (ins );
}

// Test support for ElementsAttrInterface.
Expand Down Expand Up @@ -105,6 +108,7 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
}
}];
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
Expand All @@ -119,6 +123,7 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
"::mlir::Attribute":$second,
"::mlir::Attribute":$third
);
let hasCustomAssemblyFormat = 1;
}

// A more complex parameterized attribute with multiple level of nesting.
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/PatternBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
Expand Down
55 changes: 5 additions & 50 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
let extraClassDeclaration = [{
struct SomeCppStruct {};
}];
let hasCustomAssemblyFormat = 1;
}

// A more complex and nested parameterized type.
Expand Down Expand Up @@ -92,12 +93,8 @@ def IntegerType : Test_Type<"TestInteger"> {
"::test::TestIntegerType::SignednessSemantics":$signedness
);

// We define the printer inline.
let printer = [{
$_printer << "<";
printSignedness($_printer, getImpl()->signedness);
$_printer << ", " << getImpl()->width << ">";
}];
// Indicate we use a custom format.
let hasCustomAssemblyFormat = 1;

// Define custom builder methods.
let builders = [
Expand All @@ -108,19 +105,6 @@ def IntegerType : Test_Type<"TestInteger"> {
];
let skipDefaultBuilders = 1;

// The parser is defined here also.
let parser = [{
if ($_parser.parseLess()) return Type();
SignednessSemantics signedness;
if (parseSignedness($_parser, signedness)) return Type();
if ($_parser.parseComma()) return Type();
int width;
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, loc.getContext(), width, signedness);
}];

// Any extra code one wants in the type's class declaration.
let extraClassDeclaration = [{
/// Signedness semantics.
Expand Down Expand Up @@ -150,37 +134,7 @@ class FieldInfo_Type<string name> : Test_Type<name> {
"::test::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
"Models struct fields">: $fields
);

// Prints the type in this format:
// struct<[{field1Name, field1Type}, {field2Name, field2Type}]
let printer = [{
$_printer << "<";
for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
const auto& field = getImpl()->fields[i];
$_printer << "{" << field.name << "," << field.type << "}";
if (i < getImpl()->fields.size() - 1)
$_printer << ",";
}
$_printer << ">";
}];

// Parses the above format
let parser = [{
llvm::SmallVector<FieldInfo, 4> parameters;
if ($_parser.parseLess()) return Type();
while (mlir::succeeded($_parser.parseOptionalLBrace())) {
llvm::StringRef name;
if ($_parser.parseKeyword(&name)) return Type();
if ($_parser.parseComma()) return Type();
Type type;
if ($_parser.parseType(type)) return Type();
if ($_parser.parseRBrace()) return Type();
parameters.push_back(FieldInfo {name, type});
if ($_parser.parseOptionalComma()) break;
}
if ($_parser.parseGreater()) return Type();
return get($_ctxt, parameters);
}];
let hasCustomAssemblyFormat = 1;
}

def StructType : FieldInfo_Type<"Struct"> {
Expand Down Expand Up @@ -208,6 +162,7 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [

public:
}];
let hasCustomAssemblyFormat = 1;
}

def TestMemRefElementType : Test_Type<"TestMemRefElementType",
Expand Down
167 changes: 107 additions & 60 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,65 @@ static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
return llvm::hash_combine(fi.name, fi.type);
}

//===----------------------------------------------------------------------===//
// TestCustomType
//===----------------------------------------------------------------------===//

static LogicalResult parseCustomTypeA(AsmParser &parser,
FailureOr<int> &a_result) {
a_result.emplace();
return parser.parseInteger(*a_result);
}

static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }

static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
FailureOr<Optional<int>> &b_result) {
if (a < 0)
return success();
for (int i : llvm::seq(0, a))
if (failed(parser.parseInteger(i)))
return failure();
b_result.emplace(0);
return parser.parseInteger(**b_result);
}

static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
if (a < 0)
return;
printer << ' ';
for (int i : llvm::seq(0, a))
printer << i << ' ';
printer << *b;
}

static LogicalResult parseFooString(AsmParser &parser,
FailureOr<std::string> &foo) {
std::string result;
if (parser.parseString(&result))
return failure();
foo = std::move(result);
return success();
}

static void printFooString(AsmPrinter &printer, StringRef foo) {
printer << '"' << foo << '"';
}

static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
return parser.parseKeyword(foo);
}

static void printBarString(AsmPrinter &printer, StringRef foo) {
printer << ' ' << foo;
}
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"

//===----------------------------------------------------------------------===//
// CompoundAType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -129,6 +188,54 @@ TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

Type TestIntegerType::parse(AsmParser &parser) {
SignednessSemantics signedness;
int width;
if (parser.parseLess() || parseSignedness(parser, signedness) ||
parser.parseComma() || parser.parseInteger(width) ||
parser.parseGreater())
return Type();
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
return getChecked(loc, loc.getContext(), width, signedness);
}

void TestIntegerType::print(AsmPrinter &p) const {
p << "<";
printSignedness(p, getSignedness());
p << ", " << getWidth() << ">";
}

//===----------------------------------------------------------------------===//
// TestStructType
//===----------------------------------------------------------------------===//

Type StructType::parse(AsmParser &p) {
SmallVector<FieldInfo, 4> parameters;
if (p.parseLess())
return Type();
while (succeeded(p.parseOptionalLBrace())) {
Type type;
StringRef name;
if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
p.parseRBrace())
return Type();
parameters.push_back(FieldInfo{name, type});
if (p.parseOptionalComma())
break;
}
if (p.parseGreater())
return Type();
return get(p.getContext(), parameters);
}

void StructType::print(AsmPrinter &p) const {
p << "<";
llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
p << "{" << field.name << "," << field.type << "}";
});
p << ">";
}

//===----------------------------------------------------------------------===//
// TestType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -208,66 +315,6 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
return 1;
}

//===----------------------------------------------------------------------===//
// TestCustomType
//===----------------------------------------------------------------------===//

static LogicalResult parseCustomTypeA(AsmParser &parser,
FailureOr<int> &a_result) {
a_result.emplace();
return parser.parseInteger(*a_result);
}

static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }

static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
FailureOr<Optional<int>> &b_result) {
if (a < 0)
return success();
for (int i : llvm::seq(0, a))
if (failed(parser.parseInteger(i)))
return failure();
b_result.emplace(0);
return parser.parseInteger(**b_result);
}

static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
if (a < 0)
return;
printer << ' ';
for (int i : llvm::seq(0, a))
printer << i << ' ';
printer << *b;
}

static LogicalResult parseFooString(AsmParser &parser,
FailureOr<std::string> &foo) {
std::string result;
if (parser.parseString(&result))
return failure();
foo = std::move(result);
return success();
}

static void printFooString(AsmPrinter &printer, StringRef foo) {
printer << '"' << foo << '"';
}

static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
return parser.parseKeyword(foo);
}

static void printBarString(AsmPrinter &printer, StringRef foo) {
printer << ' ' << foo;
}

//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

def Test_Dialect : Dialect {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/attr-or-type-format.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

/// Test that attribute and type printers and parsers are correctly generated.
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/mlir-tblgen/attrdefs.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

// DECL: #ifdef GET_ATTRDEF_CLASSES
Expand Down Expand Up @@ -59,6 +60,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
);

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;

// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
// DECL: static CompoundAAttr getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
Expand Down Expand Up @@ -101,6 +103,7 @@ def C_IndexAttr : TestAttr<"Index"> {
ins
StringRefParameter<"Label for index">:$label
);
let hasCustomAssemblyFormat = 1;

// DECL-LABEL: class IndexAttr : public ::mlir::Attribute
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
Expand All @@ -126,6 +129,7 @@ def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
let mnemonic = "attr_with_type_builder";
let parameters = (ins "::mlir::IntegerAttr":$attr);
let typeBuilder = "$_attr.getType()";
let hasCustomAssemblyFormat = 1;
}

// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/default-type-attr-print-parser.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

/// Test that attribute and type printers and parsers are correctly generated.
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/expect-symbol.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include %s 2>&1 | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

def Test_Dialect : Dialect {
let name = "test";
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/op-attribute.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -print-records -I %S/../../include %s | FileCheck %s --check-prefix=RECORD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

def Test_Dialect : Dialect {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/op-decl-and-defs.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEFS

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/rewriter-errors.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR7 %s 2>&1 | FileCheck --check-prefix=ERROR7 %s

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

// Check using the dialect name as the namespace
def A_Dialect : Dialect {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/rewriter-indexing.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

def Test_Dialect : Dialect {
let name = "test";
Expand Down
1 change: 1 addition & 0 deletions mlir/test/mlir-tblgen/rewriter-static-matcher.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/PatternBase.td"

def Test_Dialect : Dialect {
let name = "test";
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/testdialect-typedefs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<
return
}

// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<none, 3>}>)
// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla}, {field2,!test.int<none, 3>}>)
func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
return
}
50 changes: 25 additions & 25 deletions mlir/test/mlir-tblgen/typedefs.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

// DECL: #ifdef GET_TYPEDEF_CLASSES
Expand Down Expand Up @@ -34,8 +35,8 @@ include "mlir/IR/OpBase.td"

def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
let name = "TestDialect";
let cppNamespace = "::test";
let name = "TestDialect";
let cppNamespace = "::test";
}

class TestType<string name> : TypeDef<Test_Dialect, name> { }
Expand All @@ -53,16 +54,16 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
let summary = "A more complex parameterized type";
let description = "This type is to test a reasonably complex type";
let mnemonic = "cmpnd_a";
let parameters = (
ins
"int":$widthOfSomething,
"::test::SimpleTypeA": $exampleTdType,
"SomeCppStruct": $exampleCppType,
ArrayRefParameter<"int", "Matrix dimensions">:$dims,
RTLValueType:$inner
let parameters = (ins
"int":$widthOfSomething,
"::test::SimpleTypeA": $exampleTdType,
"SomeCppStruct": $exampleCppType,
ArrayRefParameter<"int", "Matrix dimensions">:$dims,
RTLValueType:$inner
);

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;

// DECL-LABEL: class CompoundAType : public ::mlir::Type
// DECL: static CompoundAType getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
Expand All @@ -78,12 +79,12 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
}

def C_IndexType : TestType<"Index"> {
let mnemonic = "index";
let mnemonic = "index";

let parameters = (
ins
StringRefParameter<"Label for index">:$label
);
let parameters = (ins
StringRefParameter<"Label for index">:$label
);
let hasCustomAssemblyFormat = 1;

// DECL-LABEL: class IndexType : public ::mlir::Type
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
Expand All @@ -94,8 +95,7 @@ def C_IndexType : TestType<"Index"> {
}

def D_SingleParameterType : TestType<"SingleParameter"> {
let parameters = (
ins
let parameters = (ins
"int": $num
);
// DECL-LABEL: struct SingleParameterTypeStorage;
Expand All @@ -104,17 +104,17 @@ def D_SingleParameterType : TestType<"SingleParameter"> {
}

def E_IntegerType : TestType<"Integer"> {
let mnemonic = "int";
let genVerifyDecl = 1;
let parameters = (
ins
"SignednessSemantics":$signedness,
TypeParameter<"unsigned", "Bitwidth of integer">:$width
);
let mnemonic = "int";
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
let parameters = (ins
"SignednessSemantics":$signedness,
TypeParameter<"unsigned", "Bitwidth of integer">:$width
);

// DECL-LABEL: IntegerType : public ::mlir::Type

let extraClassDeclaration = [{
let extraClassDeclaration = [{
/// Signedness semantics.
enum SignednessSemantics {
Signless, /// No signedness semantics
Expand All @@ -131,7 +131,7 @@ def E_IntegerType : TestType<"Integer"> {
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }
}];
}];

// DECL: /// Signedness semantics.
// DECL-NEXT: enum SignednessSemantics {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/python/python_test_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef PYTHON_TEST_OPS
#define PYTHON_TEST_OPS

include "mlir/IR/AttrTypeBase.td"
include "mlir/Bindings/Python/Attributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
Expand Down
85 changes: 21 additions & 64 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ class DefGen {
/// Emit a checked custom builder.
void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);

//===--------------------------------------------------------------------===//
// Parser and Printer Emission
void emitParserPrinterBody(MethodBody &parser, MethodBody &printer);

//===--------------------------------------------------------------------===//
// Interface Method Emission

Expand Down Expand Up @@ -264,28 +260,29 @@ void DefGen::emitParserPrinter() {
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getMnemonic");
mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());

// Declare the parser and printer, if needed.
if (!def.needsParserPrinter() && !def.hasGeneratedParser() &&
!def.hasGeneratedPrinter())
bool hasAssemblyFormat = def.getAssemblyFormat().hasValue();
if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
return;

// Declare the parser.
SmallVector<MethodParameter> parserParams;
parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
if (isa<AttrDef>(&def))
parserParams.emplace_back("::mlir::Type", "odsType");
auto *parser = defCls.addMethod(
strfmt("::mlir::{0}", valueType), "parse",
def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration,
std::move(parserParams));
auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
hasAssemblyFormat ? Method::Static
: Method::StaticDeclaration,
std::move(parserParams));
// Declare the printer.
auto props =
def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration;
auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
Method *printer =
defCls.addMethod("void", "print", props,
MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
// Emit the bodies.
emitParserPrinterBody(parser->body(), printer->body());
// Emit the bodies if we are using the declarative format.
if (hasAssemblyFormat)
return generateAttrOrTypeFormat(def, parser->body(), printer->body());
}

void DefGen::emitAccessors() {
Expand Down Expand Up @@ -406,50 +403,6 @@ void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
m->body().indent().getStream().printReindented(bodyStr);
}

//===----------------------------------------------------------------------===//
// Parser and Printer Emission

void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) {
Optional<StringRef> parserCode = def.getParserCode();
Optional<StringRef> printerCode = def.getPrinterCode();
Optional<StringRef> asmFormat = def.getAssemblyFormat();
// Verify the parser-printer specification first.
if (asmFormat && (parserCode || printerCode)) {
PrintFatalError(def.getLoc(),
def.getName() + ": assembly format cannot be specified at "
"the same time as printer or parser code");
}
// Specified code cannot be empty.
if (parserCode && parserCode->empty())
PrintFatalError(def.getLoc(), def.getName() + ": parser cannot be empty");
if (printerCode && printerCode->empty())
PrintFatalError(def.getLoc(), def.getName() + ": printer cannot be empty");
// Assembly format requires accessors to be generated.
if (asmFormat && !def.genAccessors()) {
PrintFatalError(def.getLoc(),
def.getName() +
": the generated printer from 'assemblyFormat' "
"requires 'genAccessors' to be true");
}

// Generate the parser and printer bodies.
if (asmFormat)
return generateAttrOrTypeFormat(def, parser, printer);

FmtContext ctx = FmtContext({{"_parser", "odsParser"},
{"_printer", "odsPrinter"},
{"_type", "odsType"}});
if (parserCode) {
ctx.addSubst("_ctxt", "odsParser.getContext()");
parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str());
}
if (printerCode) {
ctx.addSubst("_ctxt", "odsPrinter.getContext()");
printer.indent().getStream().printReindented(
tgfmt(*printerCode, &ctx).str());
}
}

//===----------------------------------------------------------------------===//
// Interface Method Emission

Expand Down Expand Up @@ -652,8 +605,8 @@ class DefGenerator {
/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os, "Attr",
"Attribute",
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute",
/*isAttrGenerator=*/true,
/*needsDialectParserPrinter=*/
!records.getAllDerivedDefinitions("DialectAttr").empty()) {
Expand All @@ -662,8 +615,9 @@ struct AttrDefGenerator : public DefGenerator {
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os, "Type",
"Type", /*isAttrGenerator=*/false,
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type",
/*isAttrGenerator=*/false,
/*needsDialectParserPrinter=*/
!records.getAllDerivedDefinitions("DialectType").empty()) {
}
Expand Down Expand Up @@ -828,18 +782,21 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
for (auto &def : defs) {
if (!def.getMnemonic())
continue;
bool hasParserPrinterDecl =
def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
std::string defClass = strfmt(
"{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());

// If the def has no parameters or parser code, invoke a normal `get`.
std::string parseOrGet =
def.needsParserPrinter() || def.hasGeneratedParser()
hasParserPrinterDecl
? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
: "get(parser.getContext())";
parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);

// If the def has no parameters and no printer, just print the mnemonic.
StringRef printDef = "";
if (def.needsParserPrinter() || def.hasGeneratedPrinter())
if (hasParserPrinterDecl)
printDef = "\nt.print(printer);";
printer.body() << llvm::formatv(printValue, defClass, printDef);
}
Expand Down
9 changes: 4 additions & 5 deletions mlir/tools/mlir-tblgen/OpDocGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
os << "\n" << def.getSummary() << "\n";

// Emit the syntax if present.
if (def.getMnemonic() && def.getPrinterCode() == StringRef() &&
def.getParserCode() == StringRef())
if (def.getMnemonic() && !def.hasCustomAssemblyFormat())
emitAttrOrTypeDefAssemblyFormat(def, os);

// Emit the description if present.
Expand Down Expand Up @@ -337,11 +336,11 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<AttrDef> attrDefs,
static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
std::vector<Record *> typeDefs =
recordKeeper.getAllDerivedDefinitions("DialectType");
recordKeeper.getAllDerivedDefinitionsIfDefined("DialectType");
std::vector<Record *> typeDefDefs =
recordKeeper.getAllDerivedDefinitions("TypeDef");
recordKeeper.getAllDerivedDefinitionsIfDefined("TypeDef");
std::vector<Record *> attrDefDefs =
recordKeeper.getAllDerivedDefinitions("AttrDef");
recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");

std::set<Dialect> dialectsWithDocs;

Expand Down