diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index c0c3d4920a1fe..76f408d76c955 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -100,20 +100,12 @@ def IsSparseTensorPred // `RankedTensorOf`, `AnyRankedTensor`. class SparseTensorOf allowedTypes> - : ShapedContainerType< - allowedTypes, - And<[IsTensorTypePred, IsSparseTensorPred]>, - "sparse tensor", - "::mlir::TensorType">; + : TensorOf; def AnySparseTensor : SparseTensorOf<[AnyType]>; class RankedSparseTensorOf allowedTypes> - : ShapedContainerType< - allowedTypes, - And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>, - "ranked sparse tensor", - "::mlir::TensorType">; + : RankedTensorOf; def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 207ae2a1ef2b1..9f2ae6fd8b804 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -669,34 +669,29 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>; def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", "::mlir::ShapedType">; +//===----------------------------------------------------------------------===// // Tensor types. -// Any tensor type whose element type is from the given `allowedTypes` list -class TensorOf allowedTypes> : - ShapedContainerType; - -class RankedTensorOf allowedTypes> : - ShapedContainerType, - "ranked tensor", "::mlir::TensorType">; +// Unranked tensor type whose element type is from the given +// `allowedTypes` list. +class UnrankedTensorOf allowedTypes> + : ShapedContainerType; -def AnyTensor : TensorOf<[AnyType]>; - -// Unranked Memref type -class UnrankedTensorOf allowedTypes> : - ShapedContainerType; - -def AnyRankedTensor : RankedTensorOf<[AnyType]>; - -// TODO: Have an easy way to add another constraint to a type. -class StaticShapeTensorOf allowedTypes> - : Type.predicate, HasStaticShapePred]>, - "statically shaped " # TensorOf.summary, - "::mlir::TensorType">; - -def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; +// Any tensor type whose element type is from the given `allowedTypes` +// list, and which additionally satisfies an optional list of predicates. +// +// TODO: use `Constraint` instead of `Pred`, so we can generate a better +// default summary (a la `Confined`). +class TensorOf< + list allowedTypes, + list preds = [], + string summary = "tensor"> + : ShapedContainerType, + summary, "::mlir::TensorType">; + +def AnyTensor : TensorOf<[AnyType]>; def I1Tensor : TensorOf<[I1]>; def I8Tensor : TensorOf<[I8]>; @@ -710,11 +705,19 @@ def F16Tensor : TensorOf<[F16]>; def F32Tensor : TensorOf<[F32]>; def F64Tensor : TensorOf<[F64]>; +class RankedTensorOf< + list allowedTypes, + list preds = [], + string summary = "ranked tensor"> + : TensorOf; + +def AnyRankedTensor : RankedTensorOf<[AnyType]>; + // Ranked tensor type with one of the specified types and ranks. -class TensorRankOf allowedTypes, list ranks> : - Type.predicate, HasAnyRankOfPred]>, - !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # - TensorOf.summary, "::mlir::TensorType">; +class TensorRankOf allowedTypes, list ranks> + : TensorOf], + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; class 0DTensorOf allowedTypes> : TensorRankOf; class 1DTensorOf allowedTypes> : TensorRankOf; @@ -722,6 +725,14 @@ class 2DTensorOf allowedTypes> : TensorRankOf; class 3DTensorOf allowedTypes> : TensorRankOf; class 4DTensorOf allowedTypes> : TensorRankOf; +class StaticShapeTensorOf allowedTypes> + : TensorOf; + +def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; + +//===----------------------------------------------------------------------===// +// Memref type. + // Unranked Memref type class UnrankedMemRefOf allowedTypes> : ShapedContainerType allowedTypes> : def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>; -// Memref type. - // Memrefs are blocks of data with fixed type and rank. class MemRefOf allowedTypes> : ShapedContainerType