Skip to content

Commit

Permalink
[mlir][IR][NFC] Move the definitions of Complex/Function/Integer/Opaq…
Browse files Browse the repository at this point in the history
…ue/TupleType to ODS

The type tablegen backend now has enough support to represent these types well enough, so we can now move them to be declaratively defined.

Differential Revision: https://reviews.llvm.org/D94275
  • Loading branch information
River707 committed Jan 11, 2021
1 parent 948be58 commit d79642b
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 307 deletions.
219 changes: 5 additions & 214 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,119 +29,15 @@ class TypeRange;
namespace detail {

struct BaseMemRefTypeStorage;
struct ComplexTypeStorage;
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct MemRefTypeStorage;
struct OpaqueTypeStorage;
struct RankedTensorTypeStorage;
struct ShapedTypeStorage;
struct TupleTypeStorage;
struct UnrankedMemRefTypeStorage;
struct UnrankedTensorTypeStorage;
struct VectorTypeStorage;

} // namespace detail

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// The 'complex' type represents a complex number with a parameterized element
/// type, which is composed of a real and imaginary value of that element type.
///
/// The element must be a floating point or integer scalar type.
///
class ComplexType
: public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
public:
using Base::Base;

/// Get or create a ComplexType with the provided element type.
static ComplexType get(Type elementType);

/// Get or create a ComplexType with the provided element type. This emits
/// and error at the specified location and returns null if the element type
/// isn't supported.
static ComplexType getChecked(Location location, Type elementType);

/// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType);

Type getElementType();
};

//===----------------------------------------------------------------------===//
// IntegerType
//===----------------------------------------------------------------------===//

/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType
: public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
public:
using Base::Base;

/// Signedness semantics.
enum SignednessSemantics : uint32_t {
Signless, /// No signedness semantics
Signed, /// Signed integer
Unsigned, /// Unsigned integer
};

/// Get or create a new IntegerType of the given width within the context.
/// The created IntegerType is signless (i.e., no signedness semantics).
/// Assume the width is within the allowed range and assert on failures. Use
/// getChecked to handle failures gracefully.
static IntegerType get(MLIRContext *context, unsigned width);

/// Get or create a new IntegerType of the given width within the context.
/// The created IntegerType has signedness semantics as indicated via
/// `signedness`. Assume the width is within the allowed range and assert on
/// failures. Use getChecked to handle failures gracefully.
static IntegerType get(MLIRContext *context, unsigned width,
SignednessSemantics signedness);

/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created
/// IntegerType is signless (i.e., no signedness semantics). If the width is
/// outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(Location location, unsigned width);

/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created
/// IntegerType has signedness semantics as indicated via `signedness`. If the
/// width is outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(Location location, unsigned width,
SignednessSemantics signedness);

/// Verify the construction of an integer type.
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned width,
SignednessSemantics signedness);

/// Return the bitwidth of this integer type.
unsigned getWidth() const;

/// Return the signedness semantics of this integer type.
SignednessSemantics getSignedness() const;

/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }

/// Get or create a new IntegerType with the same signedness as `this` and a
/// bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);

/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = (1 << 24) - 1; // Aligned with LLVM
};

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -170,68 +66,6 @@ class FloatType : public Type {
const llvm::fltSemantics &getFloatSemantics();
};

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

/// Function types map from a list of inputs to a list of results.
class FunctionType
: public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
public:
using Base::Base;

static FunctionType get(MLIRContext *context, TypeRange inputs,
TypeRange results);

/// Input types.
unsigned getNumInputs() const;
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const;

/// Result types.
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type> getResults() const;

/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices);
};

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Opaque types represent types of non-registered dialects. These are types
/// represented in their raw string form, and can only usefully be tested for
/// type equality.
class OpaqueType
: public Type::TypeBase<OpaqueType, Type, detail::OpaqueTypeStorage> {
public:
using Base::Base;

/// Get or create a new OpaqueType with the provided dialect and string data.
static OpaqueType get(MLIRContext *context, Identifier dialect,
StringRef typeData);

/// Get or create a new OpaqueType with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
/// null type is returned.
static OpaqueType getChecked(Location location, Identifier dialect,
StringRef typeData);

/// Returns the dialect namespace of the opaque type.
Identifier getDialectNamespace() const;

/// Returns the raw type data of the opaque type.
StringRef getTypeData() const;

/// Verify the construction of an opaque type.
static LogicalResult verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef typeData);
};

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -444,9 +278,7 @@ class BaseMemRefType : public ShapedType {
using ShapedType::ShapedType;

/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
}
static bool isValidElementType(Type type);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
Expand Down Expand Up @@ -584,51 +416,6 @@ class UnrankedMemRefType

ArrayRef<int64_t> getShape() const { return llvm::None; }
};

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Tuple types represent a collection of other types. Note: This type merely
/// provides a common mechanism for representing tuples in MLIR. It is up to
/// dialect authors to provides operations for manipulating them, e.g.
/// extract_tuple_element. When possible, users should prefer multi-result
/// operations in the place of tuples.
class TupleType
: public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
public:
using Base::Base;

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
static TupleType get(MLIRContext *context, TypeRange elementTypes);

/// Get or create an empty tuple type.
static TupleType get(MLIRContext *context);

/// Return the elements types for this tuple.
ArrayRef<Type> getTypes() const;

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void getFlattenedTypes(SmallVectorImpl<Type> &types);

/// Return the number of held types.
size_t size() const;

/// Iterate over the held elements.
using iterator = ArrayRef<Type>::iterator;
iterator begin() const { return getTypes().begin(); }
iterator end() const { return getTypes().end(); }

/// Return the element type at index 'index'.
Type getType(size_t index) const {
assert(index < size() && "invalid index for tuple type");
return getTypes()[index];
}
};
} // end namespace mlir

//===----------------------------------------------------------------------===//
Expand All @@ -647,6 +434,10 @@ inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}

inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
}

inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
}
Expand Down
Loading

0 comments on commit d79642b

Please sign in to comment.