Skip to content

Commit

Permalink
[mlir] Refactoring the tablegen Tensor types
Browse files Browse the repository at this point in the history
Reduces repetition in tablegen files for defining various tensor types.  In particular the goal is to reduce the repetition when defining new tensor types (e.g., D126994).

Reviewed By: aartbik, rriddle

Differential Revision: https://reviews.llvm.org/D127039
  • Loading branch information
wrengr committed Jun 8, 2022
1 parent 49ed5bf commit 0371ddf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 41 deletions.
12 changes: 2 additions & 10 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
Expand Up @@ -100,20 +100,12 @@ def IsSparseTensorPred
// `RankedTensorOf`, `AnyRankedTensor`.

class SparseTensorOf<list<Type> allowedTypes>
: ShapedContainerType<
allowedTypes,
And<[IsTensorTypePred, IsSparseTensorPred]>,
"sparse tensor",
"::mlir::TensorType">;
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;

def AnySparseTensor : SparseTensorOf<[AnyType]>;

class RankedSparseTensorOf<list<Type> allowedTypes>
: ShapedContainerType<
allowedTypes,
And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>,
"ranked sparse tensor",
"::mlir::TensorType">;
: RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;

def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;

Expand Down
71 changes: 40 additions & 31 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -669,34 +669,29 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Tensor types.

// Any tensor type whose element type is from the given `allowedTypes` list
class TensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
"::mlir::TensorType">;

class RankedTensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, And<[IsTensorTypePred, HasRankPred]>,
"ranked tensor", "::mlir::TensorType">;
// Unranked tensor type whose element type is from the given
// `allowedTypes` list.
class UnrankedTensorOf<list<Type> allowedTypes>
: ShapedContainerType<allowedTypes, IsUnrankedTensorTypePred,
"unranked.tensor", "::mlir::UnrankedTensorType">;

def AnyTensor : TensorOf<[AnyType]>;

// Unranked Memref type
class UnrankedTensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes,
IsUnrankedTensorTypePred,
"unranked.tensor", "::mlir::UnrankedTensorType">;

def AnyRankedTensor : RankedTensorOf<[AnyType]>;

// TODO: Have an easy way to add another constraint to a type.
class StaticShapeTensorOf<list<Type> allowedTypes>
: Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
"statically shaped " # TensorOf<allowedTypes>.summary,
"::mlir::TensorType">;

def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
// Any tensor type whose element type is from the given `allowedTypes`
// list, and which additionally satisfies an optional list of predicates.
//
// TODO: use `Constraint` instead of `Pred`, so we can generate a better
// default summary (a la `Confined`).
class TensorOf<
list<Type> allowedTypes,
list<Pred> preds = [],
string summary = "tensor">
: ShapedContainerType<allowedTypes,
And<!listconcat([IsTensorTypePred], preds)>,
summary, "::mlir::TensorType">;

def AnyTensor : TensorOf<[AnyType]>;

def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;
Expand All @@ -710,18 +705,34 @@ def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;

class RankedTensorOf<
list<Type> allowedTypes,
list<Pred> preds = [],
string summary = "ranked tensor">
: TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;

def AnyRankedTensor : RankedTensorOf<[AnyType]>;

// Ranked tensor type with one of the specified types and ranks.
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
TensorOf<allowedTypes>.summary, "::mlir::TensorType">;
class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
: TensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;

class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;

class StaticShapeTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [HasStaticShapePred], "statically shaped tensor">;

def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;

//===----------------------------------------------------------------------===//
// Memref type.

// Unranked Memref type
class UnrankedMemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes,
Expand All @@ -730,8 +741,6 @@ class UnrankedMemRefOf<list<Type> allowedTypes> :

def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;

// Memref type.

// Memrefs are blocks of data with fixed type and rank.
class MemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
Expand Down

0 comments on commit 0371ddf

Please sign in to comment.