219 changes: 5 additions & 214 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,119 +29,15 @@ class TypeRange;
namespace detail {

struct BaseMemRefTypeStorage;
struct ComplexTypeStorage;
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct MemRefTypeStorage;
struct OpaqueTypeStorage;
struct RankedTensorTypeStorage;
struct ShapedTypeStorage;
struct TupleTypeStorage;
struct UnrankedMemRefTypeStorage;
struct UnrankedTensorTypeStorage;
struct VectorTypeStorage;

} // namespace detail

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// The 'complex' type represents a complex number with a parameterized element
/// type, which is composed of a real and imaginary value of that element type.
///
/// The element must be a floating point or integer scalar type.
///
class ComplexType
: public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
public:
using Base::Base;

/// Get or create a ComplexType with the provided element type.
static ComplexType get(Type elementType);

/// Get or create a ComplexType with the provided element type. This emits
/// and error at the specified location and returns null if the element type
/// isn't supported.
static ComplexType getChecked(Location location, Type elementType);

/// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType);

Type getElementType();
};

//===----------------------------------------------------------------------===//
// IntegerType
//===----------------------------------------------------------------------===//

/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType
: public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
public:
using Base::Base;

/// Signedness semantics.
enum SignednessSemantics : uint32_t {
Signless, /// No signedness semantics
Signed, /// Signed integer
Unsigned, /// Unsigned integer
};

/// Get or create a new IntegerType of the given width within the context.
/// The created IntegerType is signless (i.e., no signedness semantics).
/// Assume the width is within the allowed range and assert on failures. Use
/// getChecked to handle failures gracefully.
static IntegerType get(MLIRContext *context, unsigned width);

/// Get or create a new IntegerType of the given width within the context.
/// The created IntegerType has signedness semantics as indicated via
/// `signedness`. Assume the width is within the allowed range and assert on
/// failures. Use getChecked to handle failures gracefully.
static IntegerType get(MLIRContext *context, unsigned width,
SignednessSemantics signedness);

/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created
/// IntegerType is signless (i.e., no signedness semantics). If the width is
/// outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(Location location, unsigned width);

/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created
/// IntegerType has signedness semantics as indicated via `signedness`. If the
/// width is outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(Location location, unsigned width,
SignednessSemantics signedness);

/// Verify the construction of an integer type.
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned width,
SignednessSemantics signedness);

/// Return the bitwidth of this integer type.
unsigned getWidth() const;

/// Return the signedness semantics of this integer type.
SignednessSemantics getSignedness() const;

/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }

/// Get or create a new IntegerType with the same signedness as `this` and a
/// bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);

/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = (1 << 24) - 1; // Aligned with LLVM
};

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -170,68 +66,6 @@ class FloatType : public Type {
const llvm::fltSemantics &getFloatSemantics();
};

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

/// Function types map from a list of inputs to a list of results.
class FunctionType
: public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
public:
using Base::Base;

static FunctionType get(MLIRContext *context, TypeRange inputs,
TypeRange results);

/// Input types.
unsigned getNumInputs() const;
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const;

/// Result types.
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type> getResults() const;

/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices);
};

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Opaque types represent types of non-registered dialects. These are types
/// represented in their raw string form, and can only usefully be tested for
/// type equality.
class OpaqueType
: public Type::TypeBase<OpaqueType, Type, detail::OpaqueTypeStorage> {
public:
using Base::Base;

/// Get or create a new OpaqueType with the provided dialect and string data.
static OpaqueType get(MLIRContext *context, Identifier dialect,
StringRef typeData);

/// Get or create a new OpaqueType with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
/// null type is returned.
static OpaqueType getChecked(Location location, Identifier dialect,
StringRef typeData);

/// Returns the dialect namespace of the opaque type.
Identifier getDialectNamespace() const;

/// Returns the raw type data of the opaque type.
StringRef getTypeData() const;

/// Verify the construction of an opaque type.
static LogicalResult verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef typeData);
};

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -444,9 +278,7 @@ class BaseMemRefType : public ShapedType {
using ShapedType::ShapedType;

/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
}
static bool isValidElementType(Type type);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
Expand Down Expand Up @@ -584,51 +416,6 @@ class UnrankedMemRefType

ArrayRef<int64_t> getShape() const { return llvm::None; }
};

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Tuple types represent a collection of other types. Note: This type merely
/// provides a common mechanism for representing tuples in MLIR. It is up to
/// dialect authors to provides operations for manipulating them, e.g.
/// extract_tuple_element. When possible, users should prefer multi-result
/// operations in the place of tuples.
class TupleType
: public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
public:
using Base::Base;

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
static TupleType get(MLIRContext *context, TypeRange elementTypes);

/// Get or create an empty tuple type.
static TupleType get(MLIRContext *context);

/// Return the elements types for this tuple.
ArrayRef<Type> getTypes() const;

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void getFlattenedTypes(SmallVectorImpl<Type> &types);

/// Return the number of held types.
size_t size() const;

/// Iterate over the held elements.
using iterator = ArrayRef<Type>::iterator;
iterator begin() const { return getTypes().begin(); }
iterator end() const { return getTypes().end(); }

/// Return the element type at index 'index'.
Type getType(size_t index) const {
assert(index < size() && "invalid index for tuple type");
return getTypes()[index];
}
};
} // end namespace mlir

//===----------------------------------------------------------------------===//
Expand All @@ -647,6 +434,10 @@ inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}

inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
}

inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
}
Expand Down
240 changes: 240 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,42 @@ class Builtin_Type<string name> : TypeDef<Builtin_Dialect, name> {
let mnemonic = ?;
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

def Builtin_Complex : Builtin_Type<"Complex"> {
let summary = "Complex number with a parameterized element type";
let description = [{
Syntax:

```
complex-type ::= `complex` `<` type `>`
```

The value of `complex` type represents a complex number with a parameterized
element type, which is composed of a real and imaginary value of that
element type. The element must be a floating point or integer scalar type.

Examples:

```mlir
complex<f32>
complex<i32>
```
}];
let parameters = (ins "Type":$elementType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
}], [{
return Base::getChecked($_loc, elementType);
}]>
];
let skipDefaultBuilders = 1;
let genVerifyInvariantsDecl = 1;
}

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -65,6 +101,48 @@ def Builtin_Float64 : Builtin_FloatType<"Float64"> {
let summary = "64-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

def Builtin_Function : Builtin_Type<"Function"> {
let summary = "Map from a list of inputs to a list of results";
let description = [{
Syntax:

```
// Function types may have multiple results.
function-result-type ::= type-list-parens | non-function-type
function-type ::= type-list-parens `->` function-result-type
```

The function type can be thought of as a function signature. It consists of
a list of formal parameter types and a list of formal result types.
```
}];
let parameters = (ins "ArrayRef<Type>":$inputs, "ArrayRef<Type>":$results);
let builders = [
TypeBuilder<(ins CArg<"TypeRange">:$inputs, CArg<"TypeRange">:$results), [{
return Base::get($_ctxt, inputs, results);
}]>
];
let skipDefaultBuilders = 1;
let genStorageClass = 0;
let extraClassDeclaration = [{
/// Input types.
unsigned getNumInputs() const;
Type getInput(unsigned i) const { return getInputs()[i]; }

/// Result types.
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }

/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices);
}];
}

//===----------------------------------------------------------------------===//
// IndexType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -96,6 +174,70 @@ def Builtin_Index : Builtin_Type<"Index"> {
}];
}

//===----------------------------------------------------------------------===//
// IntegerType
//===----------------------------------------------------------------------===//

def Builtin_Integer : Builtin_Type<"Integer"> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:

```
// Sized integers like i1, i4, i8, i16, i32.
signed-integer-type ::= `si` [1-9][0-9]*
unsigned-integer-type ::= `ui` [1-9][0-9]*
signless-integer-type ::= `i` [1-9][0-9]*
integer-type ::= signed-integer-type |
unsigned-integer-type |
signless-integer-type
```

Integer types have a designated bit width and may optionally have signedness
semantics.

**Rationale:** low precision integers (like `i2`, `i4` etc) are useful for
low-precision inference chips, and arbitrary precision integers are useful
for hardware synthesis (where a 13 bit multiplier is a lot cheaper/smaller
than a 16 bit one).
}];
let parameters = (ins "unsigned":$width, "SignednessSemantics":$signedness);
let builders = [
TypeBuilder<(ins "unsigned":$width,
CArg<"SignednessSemantics", "Signless">:$signedness)>
];

// IntegerType uses a special storage class that compacts parameters to save
// memory.
let genStorageClass = 0;
let skipDefaultBuilders = 1;
let genVerifyInvariantsDecl = 1;
let extraClassDeclaration = [{
/// Signedness semantics.
enum SignednessSemantics : uint32_t {
Signless, /// No signedness semantics
Signed, /// Signed integer
Unsigned, /// Unsigned integer
};

/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }

/// Get or create a new IntegerType with the same signedness as `this` and a
/// bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);

/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
}];
}

//===----------------------------------------------------------------------===//
// NoneType
//===----------------------------------------------------------------------===//
Expand All @@ -111,4 +253,102 @@ def Builtin_None : Builtin_Type<"None"> {
}];
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

def Builtin_Opaque : Builtin_Type<"Opaque"> {
let summary = "Type of a non-registered dialect";
let description = [{
Syntax:

```
opaque-type ::= `opaque` `<` type `>`
```

Opaque types represent types of non-registered dialects. These are types
represented in their raw string form, and can only usefully be tested for
type equality.

Examples:

```mlir
opaque<"llvm", "struct<(i32, float)>">
opaque<"pdl", "value">
```
}];
let parameters = (ins
"Identifier":$dialectNamespace,
StringRefParameter<"">:$typeData
);
let genVerifyInvariantsDecl = 1;
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

def Builtin_Tuple : Builtin_Type<"Tuple"> {
let summary = "Fixed-sized collection of other types";
let description = [{
Syntax:

```
tuple-type ::= `tuple` `<` (type ( `,` type)*)? `>`
```

The value of `tuple` type represents a fixed-size collection of elements,
where each element may be of a different type.

**Rationale:** Though this type is first class in the type system, MLIR
provides no standard operations for operating on `tuple` types
([rationale](Rationale/Rationale.md#tuple-types)).

Examples:

```mlir
// Empty tuple.
tuple<>

// Single element
tuple<f32>

// Many elements.
tuple<i32, f32, tensor<i1>, i5>
```
}];
let parameters = (ins "ArrayRef<Type>":$types);
let builders = [
TypeBuilder<(ins "TypeRange":$elementTypes), [{
return Base::get($_ctxt, elementTypes);
}]>,
TypeBuilder<(ins), [{
return Base::get($_ctxt, TypeRange());
}]>
];
let skipDefaultBuilders = 1;
let genStorageClass = 0;
let extraClassDeclaration = [{
/// Accumulate the types contained in this tuple and tuples nested within
/// it. Note that this only flattens nested tuples, not any other container
/// type, e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is
/// flattened to (i32, tensor<i32>, f32, i64)
void getFlattenedTypes(SmallVectorImpl<Type> &types);

/// Return the number of held types.
size_t size() const;

/// Iterate over the held elements.
using iterator = ArrayRef<Type>::iterator;
iterator begin() const { return getTypes().begin(); }
iterator end() const { return getTypes().end(); }

/// Return the element type at index 'index'.
Type getType(size_t index) const {
assert(index < size() && "invalid index for tuple type");
return getTypes()[index];
}
}];
}

#endif // BUILTIN_TYPES
82 changes: 82 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -2430,6 +2430,73 @@ def replaceWithValue;
// Data type generation
//===----------------------------------------------------------------------===//

// Class for defining a custom type getter.
//
// TableGen generates several generic getter methods for each type by default,
// corresponding to the specified dag parameters. If the default generated ones
// cannot cover some use case, custom getters can be defined using instances of
// this class.
//
// The signature of the `get` is always either:
//
// ```c++
// static <Type-Name> get(MLIRContext *context, <other-parameters>...) {
// <body>...
// }
// ```
//
// or:
//
// ```c++
// static <TypeName> get(MLIRContext *context, <parameters>...);
// ```
//
// To define a custom getter, the parameter list and body should be passed
// in as separate template arguments to this class. The parameter list is a
// TableGen DAG with `ins` operation with named arguments, which has either:
// - string initializers ("Type":$name) to represent a typed parameter, or
// - CArg-typed initializers (CArg<"Type", "default">:$name) to represent a
// typed parameter that may have a default value.
// The type string is used verbatim to produce code and, therefore, must be a
// valid C++ type. It is used inside the C++ namespace of the parent Type's
// dialect; explicit namespace qualification like `::mlir` may be necessary if
// Types are not placed inside the `mlir` namespace. The default value string is
// used verbatim to produce code and must be a valid C++ initializer the given
// type. For example, the following signature specification
//
// ```
// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
// ```
//
// has an integer parameter and a float parameter with a default value.
//
// If an empty string is passed in for `body`, then *only* the builder
// declaration will be generated; this provides a way to define complicated
// builders entirely in C++.
//
// `checkedBody` is similar to `body`, but is the code block used when
// generating a `getChecked` method.
class TypeBuilder<dag parameters, code bodyCode = "",
code checkedBodyCode = ""> {
dag dagParams = parameters;
code body = bodyCode;
code checkedBody = checkedBodyCode;

// The context parameter can be inferred from one of the other parameters and
// is not implicitly added to the parameter list.
bit hasInferredContextParam = 0;
}

// A class of TypeBuilder that is able to infer the MLIRContext parameter from
// one of the other builder parameters. Instances of this builder do not have
// `MLIRContext *` implicitly added to the parameter list.
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
code checkedBodyCode = "">
: TypeBuilder<parameters, bodyCode> {
code checkedBody = checkedBodyCode;
let hasInferredContextParam = 1;
}

// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name,
Expand Down Expand Up @@ -2475,6 +2542,18 @@ class TypeDef<Dialect dialect, string name,
// for re-allocating ArrayRefs. It is defined below.)
dag parameters = (ins);

// Custom type builder methods.
// In addition to the custom builders provided here, and unless
// skipDefaultBuilders is set, a default builder is generated with the
// following signature:
//
// ```c++
// static <TypeName> get(MLIRContext *, <parameters>);
// ```
//
// Note that builders should only be provided when a type has parameters.
list<TypeBuilder> builders = ?;

// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
// printer/parser for this type.
Expand All @@ -2488,6 +2567,9 @@ class TypeDef<Dialect dialect, string name,

// If set, generate accessors for each Type parameter.
bit genAccessors = 1;
// Avoid generating default get/getChecked functions. Custom get methods must
// be provided.
bit skipDefaultBuilders = 0;
// Generate the verifyConstructionInvariants declaration and getChecked
// method.
bit genVerifyInvariantsDecl = 0;
Expand Down
85 changes: 85 additions & 0 deletions mlir/include/mlir/TableGen/Builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===- Builder.h - Builder classes ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Builder wrapper to simplify using TableGen Record for building
// operations/types/etc.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TABLEGEN_BUILDER_H_
#define MLIR_TABLEGEN_BUILDER_H_

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

namespace llvm {
class Init;
class Record;
class SMLoc;
} // end namespace llvm

namespace mlir {
namespace tblgen {

/// Wrapper class with helper methods for accessing Builders defined in
/// TableGen.
class Builder {
public:
/// This class represents a single parameter to a builder method.
class Parameter {
public:
/// Return a string containing the C++ type of this parameter.
StringRef getCppType() const;

/// Return an optional string containing the name of this parameter. If
/// None, no name was specified for this parameter by the user.
Optional<StringRef> getName() const { return name; }

/// Return an optional string containing the default value to use for this
/// parameter.
Optional<StringRef> getDefaultValue() const;

private:
Parameter(Optional<StringRef> name, const llvm::Init *def)
: name(name), def(def) {}

/// The optional name of the parameter.
Optional<StringRef> name;

/// The tablegen definition of the parameter. This is either a StringInit,
/// or a CArg DefInit.
const llvm::Init *def;

// Allow access to the constructor.
friend Builder;
};

/// Construct a builder from the given Record instance.
Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc);

/// Return a list of parameters used in this build method.
ArrayRef<Parameter> getParameters() const { return parameters; }

/// Return an optional string containing the body of the builder.
Optional<StringRef> getBody() const;

protected:
/// The TableGen definition of this builder.
const llvm::Record *def;

private:
/// A collection of parameters to the builder.
SmallVector<Parameter> parameters;
};

} // end namespace tblgen
} // end namespace mlir

#endif // MLIR_TABLEGEN_BUILDER_H_
7 changes: 7 additions & 0 deletions mlir/include/mlir/TableGen/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Region.h"
Expand Down Expand Up @@ -287,6 +288,9 @@ class Operator {
// Returns the OperandOrAttribute corresponding to the index.
OperandOrAttribute getArgToOperandOrAttribute(int index) const;

// Returns the builders of this operation.
ArrayRef<Builder> getBuilders() const { return builders; }

private:
// Populates the vectors containing operands, attributes, results and traits.
void populateOpStructure();
Expand Down Expand Up @@ -332,6 +336,9 @@ class Operator {
// Map from argument to attribute or operand number.
SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;

// The builders of this operator.
SmallVector<Builder> builders;

// The number of native attributes stored in the leading positions of
// `attributes`.
int numNativeAttributes;
Expand Down
43 changes: 39 additions & 4 deletions mlir/include/mlir/TableGen/TypeDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,45 @@
#define MLIR_TABLEGEN_TYPEDEF_H

#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/Builder.h"

namespace llvm {
class Record;
class DagInit;
class Record;
class SMLoc;
} // namespace llvm

namespace mlir {
namespace tblgen {

class Dialect;
class TypeParameter;

//===----------------------------------------------------------------------===//
// TypeBuilder
//===----------------------------------------------------------------------===//

/// Wrapper class that represents a Tablegen TypeBuilder.
class TypeBuilder : public Builder {
public:
using Builder::Builder;

/// Return an optional code body used for the `getChecked` variant of this
/// builder.
Optional<StringRef> getCheckedBody() const;

/// Returns true if this builder is able to infer the MLIRContext parameter.
bool hasInferredContextParameter() const;
};

//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//

/// Wrapper class that contains a TableGen TypeDef's record and provides helper
/// methods for accessing them.
class TypeDef {
public:
explicit TypeDef(const llvm::Record *def) : def(def) {}
explicit TypeDef(const llvm::Record *def);

// Get the dialect for which this type belongs.
Dialect getDialect() const;
Expand Down Expand Up @@ -95,6 +116,13 @@ class TypeDef {
// Get the code location (for error printing).
ArrayRef<llvm::SMLoc> getLoc() const;

// Returns true if the default get/getChecked methods should be skipped during
// generation.
bool skipDefaultBuilders() const;

// Returns the builders of this type.
ArrayRef<TypeBuilder> getBuilders() const { return builders; }

// Returns whether two TypeDefs are equal by checking the equality of the
// underlying record.
bool operator==(const TypeDef &other) const;
Expand All @@ -107,8 +135,15 @@ class TypeDef {

private:
const llvm::Record *def;

// The builders of this type definition.
SmallVector<TypeBuilder> builders;
};

//===----------------------------------------------------------------------===//
// TypeParameter
//===----------------------------------------------------------------------===//

// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
// to parameterize them.
class TypeParameter {
Expand Down
46 changes: 4 additions & 42 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@ using namespace mlir::detail;
/// ComplexType
//===----------------------------------------------------------------------===//

ComplexType ComplexType::get(Type elementType) {
return Base::get(elementType.getContext(), elementType);
}

ComplexType ComplexType::getChecked(Location location, Type elementType) {
return Base::getChecked(location, elementType);
}

/// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
Type elementType) {
Expand All @@ -47,8 +39,6 @@ LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
return success();
}

Type ComplexType::getElementType() { return getImpl()->elementType; }

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -126,11 +116,6 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
// FunctionType
//===----------------------------------------------------------------------===//

FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs,
TypeRange results) {
return Base::get(context, inputs, results);
}

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
Expand Down Expand Up @@ -189,24 +174,6 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
// OpaqueType
//===----------------------------------------------------------------------===//

OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect,
StringRef typeData) {
return Base::get(context, dialect, typeData);
}

OpaqueType OpaqueType::getChecked(Location location, Identifier dialect,
StringRef typeData) {
return Base::getChecked(location, dialect, typeData);
}

/// Returns the dialect namespace of the opaque type.
Identifier OpaqueType::getDialectNamespace() const {
return getImpl()->dialectNamespace;
}

/// Returns the raw type data of the opaque type.
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
Identifier dialect,
Expand Down Expand Up @@ -693,15 +660,6 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
/// TupleType
//===----------------------------------------------------------------------===//

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) {
return Base::get(context, elementTypes);
}

/// Get or create an empty tuple type.
TupleType TupleType::get(MLIRContext *context) { return get(context, {}); }

/// Return the elements types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

Expand All @@ -721,6 +679,10 @@ void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
int64_t offset,
MLIRContext *context) {
Expand Down
8 changes: 0 additions & 8 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -772,21 +772,13 @@ getCachedIntegerType(unsigned width,
}
}

IntegerType IntegerType::get(MLIRContext *context, unsigned width) {
return get(context, width, IntegerType::Signless);
}

IntegerType IntegerType::get(MLIRContext *context, unsigned width,
IntegerType::SignednessSemantics signedness) {
if (auto cached = getCachedIntegerType(width, signedness, context))
return cached;
return Base::get(context, width, signedness);
}

IntegerType IntegerType::getChecked(Location location, unsigned width) {
return getChecked(location, width, IntegerType::Signless);
}

IntegerType IntegerType::getChecked(Location location, unsigned width,
SignednessSemantics signedness) {
if (auto cached =
Expand Down
43 changes: 0 additions & 43 deletions mlir/lib/IR/TypeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,6 @@ class MLIRContext;

namespace detail {

/// Opaque Type Storage and Uniquing.
struct OpaqueTypeStorage : public TypeStorage {
OpaqueTypeStorage(Identifier dialectNamespace, StringRef typeData)
: dialectNamespace(dialectNamespace), typeData(typeData) {}

/// The hash key used for uniquing.
using KeyTy = std::pair<Identifier, StringRef>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(dialectNamespace, typeData);
}

static OpaqueTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
StringRef tyData = allocator.copyInto(key.second);
return new (allocator.allocate<OpaqueTypeStorage>())
OpaqueTypeStorage(key.first, tyData);
}

// The dialect namespace.
Identifier dialectNamespace;

// The parser type data for this opaque type.
StringRef typeData;
};

/// Integer Type Storage and Uniquing.
struct IntegerTypeStorage : public TypeStorage {
IntegerTypeStorage(unsigned width,
Expand Down Expand Up @@ -290,24 +265,6 @@ struct UnrankedMemRefTypeStorage : public BaseMemRefTypeStorage {
}
};

/// Complex Type Storage.
struct ComplexTypeStorage : public TypeStorage {
ComplexTypeStorage(Type elementType) : elementType(elementType) {}

/// The hash key used for uniquing.
using KeyTy = Type;
bool operator==(const KeyTy &key) const { return key == elementType; }

/// Construction.
static ComplexTypeStorage *construct(TypeStorageAllocator &allocator,
Type elementType) {
return new (allocator.allocate<ComplexTypeStorage>())
ComplexTypeStorage(elementType);
}

Type elementType;
};

/// A type representing a collection of other types.
struct TupleTypeStorage final
: public TypeStorage,
Expand Down
74 changes: 74 additions & 0 deletions mlir/lib/TableGen/Builder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//===- Builder.cpp - Builder definitions ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/Builder.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

//===----------------------------------------------------------------------===//
// Builder::Parameter
//===----------------------------------------------------------------------===//

/// Return a string containing the C++ type of this parameter.
StringRef Builder::Parameter::getCppType() const {
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
return stringInit->getValue();
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
return record->getValueAsString("type");
}

/// Return an optional string containing the default value to use for this
/// parameter.
Optional<StringRef> Builder::Parameter::getDefaultValue() const {
if (isa<llvm::StringInit>(def))
return llvm::None;
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
Optional<StringRef> value = record->getValueAsOptionalString("defaultValue");
return value && !value->empty() ? value : llvm::None;
}

//===----------------------------------------------------------------------===//
// Builder
//===----------------------------------------------------------------------===//

Builder::Builder(const llvm::Record *record, ArrayRef<llvm::SMLoc> loc)
: def(record) {
// Initialize the parameters of the builder.
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
if (!defInit || !defInit->getDef()->getName().equals("ins"))
PrintFatalError(def->getLoc(), "expected 'ins' in builders");

bool seenDefaultValue = false;
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
const llvm::StringInit *paramName = dag->getArgName(i);
const llvm::Init *paramValue = dag->getArg(i);
Parameter param(paramName ? paramName->getValue() : Optional<StringRef>(),
paramValue);

// Similarly to C++, once an argument with a default value is detected, the
// following arguments must have default values as well.
if (param.getDefaultValue()) {
seenDefaultValue = true;
} else if (seenDefaultValue) {
PrintFatalError(loc,
"expected an argument with default value after other "
"arguments with default values");
}
parameters.emplace_back(param);
}
}

/// Return an optional string containing the body of the builder.
Optional<StringRef> Builder::getBody() const {
Optional<StringRef> body = def->getValueAsOptionalString("body");
return body && !body->empty() ? body : llvm::None;
}
1 change: 1 addition & 0 deletions mlir/lib/TableGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
llvm_add_library(MLIRTableGen STATIC
Argument.cpp
Attribute.cpp
Builder.cpp
Constraint.cpp
Dialect.cpp
Format.cpp
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,18 @@ void Operator::populateOpStructure() {
regions.push_back({name, region});
}

// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues())
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
} else if (skipDefaultBuilders()) {
PrintFatalError(
def.getLoc(),
"default builders are skipped and no custom builders provided");
}

LLVM_DEBUG(print(llvm::dbgs()));
}

Expand Down
53 changes: 53 additions & 0 deletions mlir/lib/TableGen/TypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,34 @@
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/TypeDef.h"
#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

//===----------------------------------------------------------------------===//
// TypeBuilder
//===----------------------------------------------------------------------===//

/// Return an optional code body used for the `getChecked` variant of this
/// builder.
Optional<StringRef> TypeBuilder::getCheckedBody() const {
Optional<StringRef> body = def->getValueAsOptionalString("checkedBody");
return body && !body->empty() ? body : llvm::None;
}

/// Returns true if this builder is able to infer the MLIRContext parameter.
bool TypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}

//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//

Dialect TypeDef::getDialect() const {
auto *dialectDef =
dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
Expand Down Expand Up @@ -98,6 +119,11 @@ llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }

bool TypeDef::skipDefaultBuilders() const {
return def->getValueAsBit("skipDefaultBuilders");
}

bool TypeDef::operator==(const TypeDef &other) const {
return def == other.def;
}
Expand All @@ -106,6 +132,33 @@ bool TypeDef::operator<(const TypeDef &other) const {
return getName() < other.getName();
}

//===----------------------------------------------------------------------===//
// TypeParameter
//===----------------------------------------------------------------------===//

TypeDef::TypeDef(const llvm::Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (llvm::Init *init : builderList->getValues()) {
TypeBuilder builder(cast<llvm::DefInit>(init)->getDef(), def->getLoc());

// Ensure that all parameters have names.
for (const TypeBuilder::Parameter &param : builder.getParameters()) {
if (!param.getName())
PrintFatalError(def->getLoc(),
"type builder parameters must have a name");
}
builders.emplace_back(builder);
}
} else if (skipDefaultBuilders()) {
PrintFatalError(
def->getLoc(),
"default builders are skipped and no custom builders provided");
}
}

StringRef TypeParameter::getName() const {
return def->getArgName(num)->getValue();
}
Expand Down
20 changes: 14 additions & 6 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def IntegerType : Test_Type<"TestInteger"> {
let genVerifyInvariantsDecl = 1;
let parameters = (
ins
"unsigned":$width,
// SignednessSemantics is defined below.
"::mlir::test::TestIntegerType::SignednessSemantics":$signedness,
"unsigned":$width
"::mlir::test::TestIntegerType::SignednessSemantics":$signedness
);

// We define the printer inline.
Expand All @@ -63,6 +63,17 @@ def IntegerType : Test_Type<"TestInteger"> {
$_printer << ", " << getImpl()->width << ">";
}];

// Define custom builder methods.
let builders = [
TypeBuilder<(ins "unsigned":$width,
CArg<"SignednessSemantics", "Signless">:$signedness), [{
return Base::get($_ctxt, width, signedness);
}], [{
return Base::getChecked($_loc, width, signedness);
}]>
];
let skipDefaultBuilders = 1;

// The parser is defined here also.
let parser = [{
if (parser.parseLess()) return Type();
Expand All @@ -73,7 +84,7 @@ def IntegerType : Test_Type<"TestInteger"> {
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, signedness, width);
return getChecked(loc, width, signedness);
}];

// Any extra code one wants in the type's class declaration.
Expand All @@ -85,9 +96,6 @@ def IntegerType : Test_Type<"TestInteger"> {
Unsigned, /// Unsigned integer
};

/// This extra function is necessary since it doesn't include signedness
static IntegerType getChecked(unsigned width, Location location);

/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT

// Example type validity checker.
LogicalResult TestIntegerType::verifyConstructionInvariants(
Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
if (width > 8)
return failure();
return success();
Expand Down
24 changes: 13 additions & 11 deletions mlir/test/mlir-tblgen/typedefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ include "mlir/IR/OpBase.td"
// DEF: ::mlir::test::SingleParameterType,
// DEF: ::mlir::test::IntegerType

// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic)
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser);
// DEF return ::mlir::Type();

def Test_Dialect: Dialect {
Expand All @@ -33,7 +33,7 @@ def Test_Dialect: Dialect {
class TestType<string name> : TypeDef<Test_Dialect, name> { }

def A_SimpleTypeA : TestType<"SimpleA"> {
// DECL: class SimpleAType: public ::mlir::Type
// DECL: class SimpleAType : public ::mlir::Type
}

def RTLValueType : Type<CPred<"isRTLValueType($_self)">, "Type"> {
Expand All @@ -56,12 +56,13 @@ def B_CompoundTypeA : TestType<"CompoundA"> {

let genVerifyInvariantsDecl = 1;

// DECL-LABEL: class CompoundAType: public ::mlir::Type
// DECL-LABEL: class CompoundAType : public ::mlir::Type
// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::Type getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser);
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
// DECL: SomeCppStruct getExampleCppType() const;
Expand All @@ -75,10 +76,11 @@ def C_IndexType : TestType<"Index"> {
StringRefParameter<"Label for index">:$label
);

// DECL-LABEL: class IndexType: public ::mlir::Type
// DECL-LABEL: class IndexType : public ::mlir::Type
// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser);
// DECL: void print(::mlir::DialectAsmPrinter &printer) const;
}

def D_SingleParameterType : TestType<"SingleParameter"> {
Expand All @@ -100,7 +102,7 @@ def E_IntegerType : TestType<"Integer"> {
TypeParameter<"unsigned", "Bitwidth of integer">:$width
);

// DECL-LABEL: IntegerType: public ::mlir::Type
// DECL-LABEL: IntegerType : public ::mlir::Type

let extraClassDeclaration = [{
/// Signedness semantics.
Expand Down
98 changes: 30 additions & 68 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ static cl::opt<std::string> opExcFilter(

static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const builder = "odsBuilder";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";

// The logic to calculate the actual value range for a declared operand/result
Expand Down Expand Up @@ -1326,96 +1326,58 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
body << " }\n";
}

/// Returns a signature of the builder as defined by a dag-typed initializer.
/// Updates the context `fctx` to enable replacement of $_builder and $_state
/// in the body. Reports errors at `loc`.
static std::string builderSignatureFromDAG(const DagInit *init,
ArrayRef<llvm::SMLoc> loc) {
auto *defInit = dyn_cast<DefInit>(init->getOperator());
if (!defInit || !defInit->getDef()->getName().equals("ins"))
PrintFatalError(loc, "expected 'ins' in builders");
/// Returns a signature of the builder. Updates the context `fctx` to enable
/// replacement of $_builder and $_state in the body.
static std::string getBuilderSignature(const Builder &builder) {
ArrayRef<Builder::Parameter> params(builder.getParameters());

// Inject builder and state arguments.
llvm::SmallVector<std::string, 8> arguments;
arguments.reserve(init->getNumArgs() + 2);
arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str());
arguments.reserve(params.size() + 2);
arguments.push_back(
llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
arguments.push_back(
llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());

// Accept either a StringInit or a DefInit with two string values as dag
// arguments. The former corresponds to the type, the latter to the type and
// the default value. Similarly to C++, once an argument with a default value
// is detected, the following arguments must have default values as well.
bool seenDefaultValue = false;
for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) {
for (unsigned i = 0, e = params.size(); i < e; ++i) {
// If no name is provided, generate one.
StringInit *argName = init->getArgName(i);
Optional<StringRef> paramName = params[i].getName();
std::string name =
argName ? argName->getValue().str() : "odsArg" + std::to_string(i);
paramName ? paramName->str() : "odsArg" + std::to_string(i);

Init *argInit = init->getArg(i);
StringRef type;
std::string defaultValue;
if (StringInit *strType = dyn_cast<StringInit>(argInit)) {
type = strType->getValue();
} else {
const Record *typeAndDefaultValue = cast<DefInit>(argInit)->getDef();
type = typeAndDefaultValue->getValueAsString("type");
StringRef defaultValueRef =
typeAndDefaultValue->getValueAsString("defaultValue");
if (!defaultValueRef.empty()) {
seenDefaultValue = true;
defaultValue = llvm::formatv(" = {0}", defaultValueRef).str();
}
}
if (seenDefaultValue && defaultValue.empty())
PrintFatalError(loc,
"expected an argument with default value after other "
"arguments with default values");
if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
arguments.push_back(
llvm::formatv("{0} {1}{2}", type, name, defaultValue).str());
llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
.str());
}

return llvm::join(arguments, ", ");
}

void OpEmitter::genBuilder() {
// Handle custom builders if provided.
// TODO: Create wrapper class for OpBuilder to hide the native
// TableGen API calls here.
{
auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
if (listInit) {
for (Init *init : listInit->getValues()) {
Record *builderDef = cast<DefInit>(init)->getDef();
std::string paramStr = builderSignatureFromDAG(
builderDef->getValueAsDag("dagParams"), op.getLoc());

StringRef body = builderDef->getValueAsString("body");
bool hasBody = !body.empty();
OpMethod::Property properties =
hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
auto *method =
opClass.addMethodAndPrune("void", "build", properties, paramStr);
for (const Builder &builder : op.getBuilders()) {
std::string paramStr = getBuilderSignature(builder);

FmtContext fctx;
fctx.withBuilder(builder);
fctx.addSubst("_state", builderOpState);
if (hasBody)
method->body() << tgfmt(body, &fctx);
}
}
if (op.skipDefaultBuilders()) {
if (!listInit || listInit->empty())
PrintFatalError(
op.getLoc(),
"default builders are skipped and no custom builders provided");
return;
}
Optional<StringRef> body = builder.getBody();
OpMethod::Property properties =
body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
auto *method =
opClass.addMethodAndPrune("void", "build", properties, paramStr);

FmtContext fctx;
fctx.withBuilder(odsBuilder);
fctx.addSubst("_state", builderOpState);
if (body)
method->body() << tgfmt(*body, &fctx);
}

// Generate default builders that requires all result type, operands, and
// attributes as parameters.
if (op.skipDefaultBuilders())
return;

// We generate three classes of builders here:
// 1. one having a stand-alone parameter for each operand / attribute, and
Expand Down
245 changes: 173 additions & 72 deletions mlir/tools/mlir-tblgen/TypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class DialectAsmPrinter;
/// {0}: The name of the typeDef class.
/// {1}: The name of the type base class.
static const char *const typeDefDeclSingletonBeginStr = R"(
class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
Expand All @@ -166,9 +166,9 @@ static const char *const typeDefDeclSingletonBeginStr = R"(
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {2} {
struct {3};
}
class {0}: public ::mlir::Type::TypeBase<{0}, {1},
{2}::{3}> {{
} // end namespace {2}
class {0} : public ::mlir::Type::TypeBase<{0}, {1},
{2}::{3}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
Expand All @@ -177,18 +177,68 @@ static const char *const typeDefDeclParametricBeginStr = R"(

/// The snippet for print/parse.
static const char *const typeDefParsePrint = R"(
static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
void print(::mlir::DialectAsmPrinter& printer) const;
static ::mlir::Type parse(::mlir::MLIRContext *context,
::mlir::DialectAsmParser &parser);
void print(::mlir::DialectAsmPrinter &printer) const;
)";

/// The code block for the verifyConstructionInvariants and getChecked.
///
/// {0}: List of parameters, parameters style.
/// {0}: The name of the typeDef class.
/// {1}: List of parameters, parameters style.
static const char *const typeDefDeclVerifyStr = R"(
static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{0});
static ::mlir::Type getChecked(::mlir::Location loc{0});
static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
)";

/// Emit the builders for the given type.
static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
TypeParamCommaFormatter &paramTypes) {
StringRef typeClass = typeDef.getCppClassName();
bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
" static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
paramTypes);
if (genCheckedMethods) {
os << llvm::formatv(
" static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
paramTypes);
}
}

// Generate the builders specified by the user.
for (const TypeBuilder &builder : typeDef.getBuilders()) {
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(
builder.getParameters(), paramOS,
[&](const TypeBuilder::Parameter &param) {
// Note: TypeBuilder parameters are guaranteed to have names.
paramOS << param.getCppType() << " " << *param.getName();
if (Optional<StringRef> defaultParamValue = param.getDefaultValue())
paramOS << " = " << *defaultParamValue;
});
paramOS.flush();

// Generate the `get` variant of the builder.
os << " static " << typeClass << " get(";
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
}
os << paramStr << ");\n";

// Generate the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << " static " << typeClass << " getChecked(::mlir::Location loc";
if (!paramStr.empty())
os << ", " << paramStr;
os << ");\n";
}
}
}

/// Generate the declaration for the given typeDef class.
static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> params;
Expand All @@ -212,21 +262,22 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
if (!params.empty()) {
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
}
emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);

// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
emitTypeNamePairsAfterComma);
}

// Emit the mnenomic, if specified.
if (auto mnenomic = typeDef.getMnemonic()) {
os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
<< "\"; }\n";

// If mnemonic specified, emit print/parse declarations.
os << typeDefParsePrint;
if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty())
os << typeDefParsePrint;
}

if (typeDef.genAccessors()) {
Expand Down Expand Up @@ -330,17 +381,6 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
}
)";

/// The code block for the getChecked definition.
///
/// {0}: List of parameters, parameters style.
/// {1}: C++ type class name.
/// {2}: Comma separated list of parameter names.
static const char *const typeDefDefGetCheckeStr = R"(
::mlir::Type {1}::getChecked(Location loc{0}) {{
return Base::getChecked(loc{2});
}
)";

/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
Expand Down Expand Up @@ -403,13 +443,13 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
parameters, /* prependComma */ false));

// 3) Emit the construct method.
if (typeDef.hasStorageCustomConstructor())
if (typeDef.hasStorageCustomConstructor()) {
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
os << " static " << typeDef.getStorageClassName()
<< " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
"&key);\n";
else {
} else {
// If not, autogenerate one.

// First, unbox the parameters.
Expand Down Expand Up @@ -445,7 +485,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
// Both the mnenomic and printerCode must be defined (for parity with
// parserCode).
os << "void " << typeDef.getCppClassName()
<< "::print(::mlir::DialectAsmPrinter& printer) const {\n";
<< "::print(::mlir::DialectAsmPrinter &printer) const {\n";
if (*printerCode == "") {
// If no code specified, emit error.
PrintFatalError(typeDef.getLoc(),
Expand All @@ -460,7 +500,7 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
if (auto parserCode = typeDef.getParserCode()) {
// The mnenomic must be defined so the dispatcher knows how to dispatch.
os << "::mlir::Type " << typeDef.getCppClassName()
<< "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
<< "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &"
"parser) "
"{\n";
if (*parserCode == "") {
Expand All @@ -470,51 +510,112 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
": parser (if specified) must have non-empty code");
}
auto fmtCtxt =
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context");
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
}
}

/// Print all the typedef-specific definition code.
static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
NamespaceEmitter ns(os, typeDef.getDialect());
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);

// Emit the storage class, if requested and necessary.
if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
emitStorageClass(typeDef, os);

if (!parameters.empty()) {
/// Emit the builders for the given type.
static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
ArrayRef<TypeParameter> typeDefParams) {
bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
StringRef typeClass = typeDef.getCppClassName();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
" return Base::get(ctxt{2});\n}\n",
typeDef.getCppClassName(),
"{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n"
" return Base::get(context{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters));
typeDefParams));
if (genCheckedMethods) {
os << llvm::formatv(
"{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
" return Base::getChecked(loc{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
typeDefParams),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams));
}
}

// Emit the parameter accessors.
if (typeDef.genAccessors())
for (const TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
parameter.getCppType(), name, parameter.getName(),
typeDef.getCppClassName());
// Generate the builders specified by the user.
auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
auto checkedBuilderFmtCtx = FmtContext()
.addSubst("_loc", "loc")
.addSubst("_ctxt", "loc.getContext()");
for (const TypeBuilder &builder : typeDef.getBuilders()) {
Optional<StringRef> body = builder.getBody();
Optional<StringRef> checkedBody =
genCheckedMethods ? builder.getCheckedBody() : llvm::None;
if (!body && !checkedBody)
continue;
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
llvm::interleaveComma(builder.getParameters(), paramOS,
[&](const TypeBuilder::Parameter &param) {
// Note: TypeBuilder parameters are guaranteed to
// have names.
paramOS << param.getCppType() << " "
<< *param.getName();
});
paramOS.flush();

// Emit the `get` variant of the builder.
if (body) {
os << llvm::formatv("{0} {0}::get(", typeClass);
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &builderFmtCtx).str());
} else {
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body);
}
}

// Generate getChecked() method.
if (typeDef.genVerifyInvariantsDecl()) {
os << llvm::formatv(
typeDefDefGetCheckeStr,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
typeDef.getCppClassName(),
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
parameters));
// Emit the `getChecked` variant of the builder.
if (checkedBody) {
os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
typeClass);
if (!paramStr.empty())
os << ", " << paramStr;
os << llvm::formatv(") {{\n {0};\n}\n",
tgfmt(*checkedBody, &checkedBuilderFmtCtx));
}
}
}

/// Print all the typedef-specific definition code.
static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) {
NamespaceEmitter ns(os, typeDef.getDialect());

SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
if (!parameters.empty()) {
// Emit the storage class, if requested and necessary.
if (typeDef.genStorageClass())
emitStorageClass(typeDef, os);

// Emit the builders for this type.
emitTypeBuilderDefs(typeDef, os, parameters);

// Generate accessor definitions only if we also generate the storage class.
// Otherwise, let the user define the exact accessor definition.
if (typeDef.genAccessors() && typeDef.genStorageClass()) {
// Emit the parameter accessors.
for (const TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
parameter.getCppType(), name, parameter.getName(),
typeDef.getCppClassName());
}
}
}

// If mnemonic is specified maybe print definitions for the parser and printer
Expand All @@ -534,9 +635,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {

// The parser dispatch is just a list of if-elses, matching on the
// mnemonic and calling the class's parse function.
os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* "
"ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *"
"context, ::mlir::DialectAsmParser &parser, "
"::llvm::StringRef mnemonic) {\n";
for (const TypeDef &type : types) {
if (type.getMnemonic()) {
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
Expand All @@ -547,9 +648,9 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// If the type has no parameters and no parser code, just invoke a normal
// `get`.
if (type.getNumParameters() == 0 && !type.getParserCode())
os << "get(ctxt);\n";
os << "get(context);\n";
else
os << "parse(ctxt, parser);\n";
os << "parse(context, parser);\n";
}
}
os << " return ::mlir::Type();\n";
Expand All @@ -559,7 +660,7 @@ static void emitParsePrintDispatch(ArrayRef<TypeDef> types, raw_ostream &os) {
// printer.
os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type "
"type, "
"::mlir::DialectAsmPrinter& printer) {\n"
"::mlir::DialectAsmPrinter &printer) {\n"
<< " return ::llvm::TypeSwitch<::mlir::Type, "
"::mlir::LogicalResult>(type)\n";
for (const TypeDef &type : types) {
Expand Down Expand Up @@ -594,7 +695,7 @@ static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,

IfDefScope scope("GET_TYPEDEF_CLASSES", os);
emitParsePrintDispatch(typeDefs, os);
for (auto typeDef : typeDefs)
for (const TypeDef &typeDef : typeDefs)
emitTypeDefDef(typeDef, os);

return false;
Expand Down