diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 4733bfca8be21..7e9a80e7d73a1 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -203,25 +203,13 @@ Type CompositeType::getElementType(unsigned index) const { } unsigned CompositeType::getNumElements() const { - if (auto arrayType = llvm::dyn_cast(*this)) - return arrayType.getNumElements(); - if (auto matrixType = llvm::dyn_cast(*this)) - return matrixType.getNumColumns(); - if (auto structType = llvm::dyn_cast(*this)) - return structType.getNumElements(); - if (auto vectorType = llvm::dyn_cast(*this)) - return vectorType.getNumElements(); - if (auto tensorArmType = dyn_cast(*this)) - return tensorArmType.getNumElements(); - if (llvm::isa(*this)) { - llvm_unreachable( - "invalid to query number of elements of spirv Cooperative Matrix type"); - } - if (llvm::isa(*this)) { - llvm_unreachable( - "invalid to query number of elements of spirv::RuntimeArray type"); - } - llvm_unreachable("invalid composite type"); + return TypeSwitch(*this) + .Case( + [](auto type) { return type.getNumElements(); }) + .Case([](MatrixType type) { return type.getNumColumns(); }) + .Default([](SPIRVType) -> unsigned { + llvm_unreachable("Invalid type for number of elements query"); + }); } bool CompositeType::hasCompileTimeKnownNumElements() const {