Skip to content

Commit ee394e6

Browse files
committed
[MLIR] Add variadic isa<> for Type, Value, and Attribute
- Also adopt variadic llvm::isa<> in more places. - Fixes https://bugs.llvm.org/show_bug.cgi?id=46445 Differential Revision: https://reviews.llvm.org/D82769
1 parent 657ac8e commit ee394e6

File tree

28 files changed

+76
-68
lines changed

28 files changed

+76
-68
lines changed

mlir/docs/Tutorials/Toy/Ch-7.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
287287
return nullptr;
288288

289289
// Check that the type is either a TensorType or another StructType.
290-
if (!elementType.isa<mlir::TensorType>() &&
291-
!elementType.isa<StructType>()) {
290+
if (!elementType.isa<mlir::TensorType, StructType>()) {
292291
parser.emitError(typeLoc, "element type for a struct must either "
293292
"be a TensorType or a StructType, got: ")
294293
<< elementType;

mlir/examples/toy/Ch7/mlir/Dialect.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
510510
return nullptr;
511511

512512
// Check that the type is either a TensorType or another StructType.
513-
if (!elementType.isa<mlir::TensorType>() &&
514-
!elementType.isa<StructType>()) {
513+
if (!elementType.isa<mlir::TensorType, StructType>()) {
515514
parser.emitError(typeLoc, "element type for a struct must either "
516515
"be a TensorType or a StructType, got: ")
517516
<< elementType;

mlir/include/mlir/EDSC/Builders.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,12 @@ struct StructuredIndexed {
139139

140140
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
141141
: value(v), exprs(indexings.begin(), indexings.end()) {
142-
assert((v.getType().isa<MemRefType>() ||
143-
v.getType().isa<RankedTensorType>() ||
144-
v.getType().isa<VectorType>()) &&
142+
assert((v.getType().isa<MemRefType, RankedTensorType, VectorType>()) &&
145143
"MemRef, RankedTensor or Vector expected");
146144
}
147145
StructuredIndexed(Type t, ArrayRef<AffineExpr> indexings)
148146
: type(t), exprs(indexings.begin(), indexings.end()) {
149-
assert((t.isa<MemRefType>() || t.isa<RankedTensorType>() ||
150-
t.isa<VectorType>()) &&
147+
assert((t.isa<MemRefType, RankedTensorType, VectorType>()) &&
151148
"MemRef, RankedTensor or Vector expected");
152149
}
153150

mlir/include/mlir/IR/Attributes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ class Attribute {
8585
bool operator!() const { return impl == nullptr; }
8686

8787
template <typename U> bool isa() const;
88+
template <typename First, typename Second, typename... Rest>
89+
bool isa() const;
8890
template <typename U> U dyn_cast() const;
8991
template <typename U> U dyn_cast_or_null() const;
9092
template <typename U> U cast() const;
@@ -1630,6 +1632,12 @@ template <typename U> bool Attribute::isa() const {
16301632
assert(impl && "isa<> used on a null attribute.");
16311633
return U::classof(*this);
16321634
}
1635+
1636+
template <typename First, typename Second, typename... Rest>
1637+
bool Attribute::isa() const {
1638+
return isa<First>() || isa<Second, Rest...>();
1639+
}
1640+
16331641
template <typename U> U Attribute::dyn_cast() const {
16341642
return isa<U>() ? U(impl) : U(nullptr);
16351643
}

mlir/include/mlir/IR/Matchers.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ struct constant_int_op_binder {
9797
return false;
9898
auto type = op->getResult(0).getType();
9999

100-
if (type.isa<IntegerType>() || type.isa<IndexType>())
100+
if (type.isa<IntegerType, IndexType>())
101101
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
102-
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
102+
if (type.isa<VectorType, RankedTensorType>()) {
103103
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
104104
return attr_value_binder<IntegerAttr>(bind_value)
105105
.match(splatAttr.getSplatValue());

mlir/include/mlir/IR/StandardTypes.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ class VectorType
357357
/// Returns true of the given type can be used as an element of a vector type.
358358
/// In particular, vectors can consist of integer or float primitives.
359359
static bool isValidElementType(Type t) {
360-
return t.isa<IntegerType>() || t.isa<FloatType>();
360+
return t.isa<IntegerType, FloatType>();
361361
}
362362

363363
ArrayRef<int64_t> getShape() const;
@@ -381,9 +381,8 @@ class TensorType : public ShapedType {
381381
// Note: Non standard/builtin types are allowed to exist within tensor
382382
// types. Dialects are expected to verify that tensor types have a valid
383383
// element type within that dialect.
384-
return type.isa<ComplexType>() || type.isa<FloatType>() ||
385-
type.isa<IntegerType>() || type.isa<OpaqueType>() ||
386-
type.isa<VectorType>() || type.isa<IndexType>() ||
384+
return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
385+
IndexType>() ||
387386
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
388387
}
389388

mlir/include/mlir/IR/Types.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ class Type {
121121
bool operator!() const { return impl == nullptr; }
122122

123123
template <typename U> bool isa() const;
124+
template <typename First, typename Second, typename... Rest>
125+
bool isa() const;
124126
template <typename U> U dyn_cast() const;
125127
template <typename U> U dyn_cast_or_null() const;
126128
template <typename U> U cast() const;
@@ -271,6 +273,12 @@ template <typename U> bool Type::isa() const {
271273
assert(impl && "isa<> used on a null type.");
272274
return U::classof(*this);
273275
}
276+
277+
template <typename First, typename Second, typename... Rest>
278+
bool Type::isa() const {
279+
return isa<First>() || isa<Second, Rest...>();
280+
}
281+
274282
template <typename U> U Type::dyn_cast() const {
275283
return isa<U>() ? U(impl) : U(nullptr);
276284
}

mlir/include/mlir/IR/Value.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ class Value {
8181
assert(*this && "isa<> used on a null type.");
8282
return U::classof(*this);
8383
}
84+
85+
template <typename First, typename Second, typename... Rest>
86+
bool isa() const {
87+
return isa<First>() || isa<Second, Rest...>();
88+
}
89+
8490
template <typename U> U dyn_cast() const {
8591
return isa<U>() ? U(ownerAndKind) : U(nullptr);
8692
}

mlir/lib/Analysis/Utils.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -956,8 +956,7 @@ static Optional<int64_t> getMemoryFootprintBytes(Block &block,
956956

957957
// Walk this 'affine.for' operation to gather all memory regions.
958958
auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
959-
if (!isa<AffineReadOpInterface>(opInst) &&
960-
!isa<AffineWriteOpInterface>(opInst)) {
959+
if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
961960
// Neither load nor a store op.
962961
return WalkResult::advance();
963962
}
@@ -1017,11 +1016,9 @@ bool mlir::isLoopParallel(AffineForOp forOp) {
10171016
// Collect all load and store ops in loop nest rooted at 'forOp'.
10181017
SmallVector<Operation *, 8> loadAndStoreOpInsts;
10191018
auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
1020-
if (isa<AffineReadOpInterface>(opInst) ||
1021-
isa<AffineWriteOpInterface>(opInst))
1019+
if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
10221020
loadAndStoreOpInsts.push_back(opInst);
1023-
else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
1024-
!isa<AffineIfOp>(opInst) &&
1021+
else if (!isa<AffineForOp, AffineTerminatorOp, AffineIfOp>(opInst) &&
10251022
!MemoryEffectOpInterface::hasNoEffect(opInst))
10261023
return WalkResult::interrupt();
10271024

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
302302
auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
303303
if (!converted)
304304
return {};
305-
if (t.isa<MemRefType>() || t.isa<UnrankedMemRefType>())
305+
if (t.isa<MemRefType, UnrankedMemRefType>())
306306
converted = converted.getPointerTo();
307307
inputs.push_back(converted);
308308
}
@@ -1044,7 +1044,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
10441044
FunctionType type, SmallVectorImpl<UnsignedTypePair> &argsInfo) const {
10451045
argsInfo.reserve(type.getNumInputs());
10461046
for (auto en : llvm::enumerate(type.getInputs())) {
1047-
if (en.value().isa<MemRefType>() || en.value().isa<UnrankedMemRefType>())
1047+
if (en.value().isa<MemRefType, UnrankedMemRefType>())
10481048
argsInfo.push_back({en.index(), en.value()});
10491049
}
10501050
}

0 commit comments

Comments
 (0)