diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index e46b576810316..aac5ef17370b2 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -407,7 +407,7 @@ class CooperativeMatrixType /// Returns the use parameter of the cooperative matrix. CooperativeMatrixUseKHR getUse() const; - operator ShapedType() const { return llvm::cast(*this); } + operator ShapedType() const { return cast(*this); } ArrayRef getShape() const; @@ -491,7 +491,7 @@ class TensorArmType Type getElementType() const; ArrayRef getShape() const; bool hasRank() const { return !getShape().empty(); } - operator ShapedType() const { return llvm::cast(*this); } + operator ShapedType() const { return cast(*this); } }; } // namespace spirv diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 56e8fee191432..c101a95685a25 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -422,10 +422,10 @@ struct VectorReductionPattern final : OpConversionPattern { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ case vector::CombiningKind::kind: \ - if (llvm::isa(resultType)) { \ + if (isa(resultType)) { \ result = spirv::iop::create(rewriter, loc, resultType, result, next); \ } else { \ - assert(llvm::isa(resultType)); \ + assert(isa(resultType)); \ result = spirv::fop::create(rewriter, loc, resultType, result, next); \ } \ break diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp index 948d48980f2e8..7029268177128 100644 --- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp @@ -35,9 +35,9 @@ StringRef stringifyTypeName() { // Verifies an atomic update op. template static LogicalResult verifyAtomicUpdateOp(Operation *op) { - auto ptrType = llvm::cast(op->getOperand(0).getType()); + auto ptrType = cast(op->getOperand(0).getType()); auto elementType = ptrType.getPointeeType(); - if (!llvm::isa(elementType)) + if (!isa(elementType)) return op->emitOpError() << "pointer operand must point to an " << stringifyTypeName() << " value, found " << elementType; diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index fcf4eb6fbcf60..a5330dc56d48f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -87,13 +87,13 @@ LogicalResult BitcastOp::verify() { if (operandType == resultType) { return emitError("result type must be different from operand type"); } - if (llvm::isa(operandType) && - !llvm::isa(resultType)) { + if (isa(operandType) && + !isa(resultType)) { return emitError( "unhandled bit cast conversion from pointer type to non-pointer type"); } - if (!llvm::isa(operandType) && - llvm::isa(resultType)) { + if (!isa(operandType) && + isa(resultType)) { return emitError( "unhandled bit cast conversion from non-pointer type to pointer type"); } @@ -112,8 +112,8 @@ LogicalResult BitcastOp::verify() { //===----------------------------------------------------------------------===// LogicalResult ConvertPtrToUOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); + auto operandType = cast(getPointer().getType()); + auto resultType = cast(getResult().getType()); if (!resultType || !resultType.isSignlessInteger()) return emitError("result must be a scalar type of unsigned integer"); auto spirvModule = (*this)->getParentOfType(); @@ -133,8 +133,8 @@ LogicalResult ConvertPtrToUOp::verify() { //===----------------------------------------------------------------------===// LogicalResult ConvertUToPtrOp::verify() { - auto operandType = llvm::cast(getOperand().getType()); - auto resultType = llvm::cast(getResult().getType()); + auto operandType = cast(getOperand().getType()); + auto resultType = cast(getResult().getType()); if (!operandType || !operandType.isSignlessInteger()) return emitError("result must be a scalar type of unsigned integer"); auto spirvModule = (*this)->getParentOfType(); @@ -154,8 +154,8 @@ LogicalResult ConvertUToPtrOp::verify() { //===----------------------------------------------------------------------===// LogicalResult PtrCastToGenericOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); + auto operandType = cast(getPointer().getType()); + auto resultType = cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Workgroup && @@ -182,8 +182,8 @@ LogicalResult PtrCastToGenericOp::verify() { //===----------------------------------------------------------------------===// LogicalResult GenericCastToPtrOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); + auto operandType = cast(getPointer().getType()); + auto resultType = cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) @@ -210,8 +210,8 @@ LogicalResult GenericCastToPtrOp::verify() { //===----------------------------------------------------------------------===// LogicalResult GenericCastToPtrExplicitOp::verify() { - auto operandType = llvm::cast(getPointer().getType()); - auto resultType = llvm::cast(getResult().getType()); + auto operandType = cast(getPointer().getType()); + auto resultType = cast(getResult().getType()); spirv::StorageClass operandStorage = operandType.getStorageClass(); if (operandStorage != spirv::StorageClass::Generic) diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index a846d7e60024c..4d0aedca27d42 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -138,7 +138,7 @@ LogicalResult BranchConditionalOp::verify() { return emitOpError("must have exactly two branch weights"); } if (llvm::all_of(*weights, [](Attribute attr) { - return llvm::cast(attr).getValue().isZero(); + return cast(attr).getValue().isZero(); })) return emitOpError("branch weights cannot both be zero"); } @@ -504,8 +504,8 @@ LogicalResult ReturnValueOp::verify() { //===----------------------------------------------------------------------===// LogicalResult SelectOp::verify() { - if (auto conditionTy = llvm::dyn_cast(getCondition().getType())) { - auto resultVectorTy = llvm::dyn_cast(getResult().getType()); + if (auto conditionTy = dyn_cast(getCondition().getType())) { + auto resultVectorTy = dyn_cast(getResult().getType()); if (!resultVectorTy) { return emitOpError("result expected to be of vector type when " "condition is of vector type"); diff --git a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp index 01ef1bdc42515..dada8925b88e1 100644 --- a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp @@ -74,10 +74,9 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) { Type factorTy = op->getOperand(0).getType(); StringAttr packedVectorFormatAttrName = IntegerDotProductOpTy::getFormatAttrName(op->getName()); - if (auto intTy = llvm::dyn_cast(factorTy)) { - auto packedVectorFormat = - llvm::dyn_cast_or_null( - op->getAttr(packedVectorFormatAttrName)); + if (auto intTy = dyn_cast(factorTy)) { + auto packedVectorFormat = dyn_cast_or_null( + op->getAttr(packedVectorFormatAttrName)); if (!packedVectorFormat) return op->emitOpError("requires Packed Vector Format attribute for " "integer vector operands"); @@ -135,8 +134,8 @@ getIntegerDotProductCapabilities(Operation *op) { Type factorTy = op->getOperand(0).getType(); StringAttr packedVectorFormatAttrName = IntegerDotProductOpTy::getFormatAttrName(op->getName()); - if (auto intTy = llvm::dyn_cast(factorTy)) { - auto formatAttr = llvm::cast( + if (auto intTy = dyn_cast(factorTy)) { + auto formatAttr = cast( op->getAttr(packedVectorFormatAttrName)); if (formatAttr.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) @@ -145,7 +144,7 @@ getIntegerDotProductCapabilities(Operation *op) { return capabilities; } - auto vecTy = llvm::cast(factorTy); + auto vecTy = cast(factorTy); if (vecTy.getElementTypeBitWidth() == 8) { capabilities.push_back(dotProductInput4x8BitCap); return capabilities; diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp index 461d037134dae..a1bb7f89e9183 100644 --- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp @@ -65,7 +65,7 @@ LogicalResult GroupBroadcastOp::verify() { if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - if (auto localIdTy = llvm::dyn_cast(getLocalid().getType())) + if (auto localIdTy = dyn_cast(getLocalid().getType())) if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) return emitOpError("localid is a vector and can be with only " " 2 or 3 components, actual number is ") diff --git a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp index 661f3d5d9b81d..6ea07330b70cb 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp @@ -89,8 +89,8 @@ static LogicalResult verifyImageOperands(Operation *imageOp, "floating-point type scalar"); auto samplingOp = cast(imageOp); - auto sampledImageType = llvm::cast( - samplingOp.getSampledImage().getType()); + auto sampledImageType = + cast(samplingOp.getSampledImage().getType()); imageType = cast(sampledImageType.getImageType()); } else { if (!isa(operands[index].getType())) @@ -243,8 +243,7 @@ LogicalResult spirv::ImageWriteOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::ImageQuerySizeOp::verify() { - spirv::ImageType imageType = - llvm::cast(getImage().getType()); + spirv::ImageType imageType = cast(getImage().getType()); Type resultType = getResult().getType(); spirv::Dim dim = imageType.getDim(); @@ -292,7 +291,7 @@ LogicalResult spirv::ImageQuerySizeOp::verify() { componentNumber += 1; unsigned resultComponentNumber = 1; - if (auto resultVectorType = llvm::dyn_cast(resultType)) + if (auto resultVectorType = dyn_cast(resultType)) resultComponentNumber = resultVectorType.getNumElements(); if (componentNumber != resultComponentNumber) diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp index 5ae27e5d82bd7..e3187d3dc1901 100644 --- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp @@ -166,7 +166,7 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, // TODO: Check that the value type satisfies restrictions of // SPIR-V OpLoad/OpStore operations if (val.getType() != - llvm::cast(ptr.getType()).getPointeeType()) { + cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -190,7 +190,7 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { return success(); } - auto memAccess = llvm::cast(memAccessAttr); + auto memAccess = cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") @@ -234,7 +234,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { return success(); } - auto memAccess = llvm::cast(memAccessAttr); + auto memAccess = cast(memAccessAttr); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") @@ -261,7 +261,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { - auto ptrType = llvm::dyn_cast(type); + auto ptrType = dyn_cast(type); if (!ptrType) { emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " "to composite type, but provided ") @@ -274,7 +274,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { int32_t index = 0; for (auto indexSSA : indices) { - auto cType = llvm::dyn_cast(resultType); + auto cType = dyn_cast(resultType); if (!cType) { emitError( baseLoc, @@ -283,7 +283,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { return nullptr; } index = 0; - if (llvm::isa(resultType)) { + if (isa(resultType)) { Operation *op = indexSSA.getDefiningOp(); if (!op) { emitError(baseLoc, "'spirv.AccessChain' op index must be an " @@ -334,7 +334,7 @@ static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { return failure(); auto providedResultType = - llvm::dyn_cast(accessChainOp.getType()); + dyn_cast(accessChainOp.getType()); if (!providedResultType) return accessChainOp.emitOpError( "result type must be a pointer, but provided") @@ -357,7 +357,7 @@ LogicalResult AccessChainOp::verify() { void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { - auto ptrType = llvm::cast(basePtr.getType()); + auto ptrType = cast(basePtr.getType()); build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } @@ -386,7 +386,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { void LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - llvm::cast(getPtr().getType()).getStorageClass()); + cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -433,7 +433,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { void StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( - llvm::cast(getPtr().getType()).getStorageClass()); + cast(getPtr().getType()).getStorageClass()); printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); printMemoryAccessAttribute(*this, printer, elidedAttrs); @@ -458,11 +458,11 @@ void CopyMemoryOp::print(OpAsmPrinter &printer) { printer << ' '; StringRef targetStorageClass = stringifyStorageClass( - llvm::cast(getTarget().getType()).getStorageClass()); + cast(getTarget().getType()).getStorageClass()); printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; StringRef sourceStorageClass = stringifyStorageClass( - llvm::cast(getSource().getType()).getStorageClass()); + cast(getSource().getType()).getStorageClass()); printer << " \"" << sourceStorageClass << "\" " << getSource(); SmallVector elidedAttrs; @@ -474,7 +474,7 @@ void CopyMemoryOp::print(OpAsmPrinter &printer) { printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = - llvm::cast(getTarget().getType()).getPointeeType(); + cast(getTarget().getType()).getPointeeType(); printer << " : " << pointeeType; } @@ -521,10 +521,10 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) { LogicalResult CopyMemoryOp::verify() { Type targetType = - llvm::cast(getTarget().getType()).getPointeeType(); + cast(getTarget().getType()).getPointeeType(); Type sourceType = - llvm::cast(getSource().getType()).getPointeeType(); + cast(getSource().getType()).getPointeeType(); if (targetType != sourceType) return emitOpError("both operands must be pointers to the same type"); @@ -600,7 +600,7 @@ ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseType(type)) return failure(); - auto ptrType = llvm::dyn_cast(type); + auto ptrType = dyn_cast(type); if (!ptrType) return parser.emitError(loc, "expected spirv.ptr type"); result.addTypes(ptrType); @@ -640,7 +640,7 @@ LogicalResult VariableOp::verify() { "spirv.GlobalVariable for module-level variables."); } - auto pointerType = llvm::cast(getPointer().getType()); + auto pointerType = cast(getPointer().getType()); if (getStorageClass() != pointerType.getStorageClass()) return emitOpError( "storage class must match result pointer's storage class"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp index 2ba6106896c1f..f1940091ca238 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -146,20 +146,18 @@ StringRef spirv::InterfaceVarABIAttr::getKindName() { } uint32_t spirv::InterfaceVarABIAttr::getBinding() { - return llvm::cast(getImpl()->binding).getInt(); + return cast(getImpl()->binding).getInt(); } uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() { - return llvm::cast(getImpl()->descriptorSet).getInt(); + return cast(getImpl()->descriptorSet).getInt(); } std::optional spirv::InterfaceVarABIAttr::getStorageClass() { if (getImpl()->storageClass) return static_cast( - llvm::cast(getImpl()->storageClass) - .getValue() - .getZExtValue()); + cast(getImpl()->storageClass).getValue().getZExtValue()); return std::nullopt; } @@ -173,7 +171,7 @@ LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants( return emitError() << "expected 32-bit integer for binding"; if (storageClass) { - if (auto storageClassAttr = llvm::cast(storageClass)) { + if (auto storageClassAttr = cast(storageClass)) { auto storageClassValue = spirv::symbolizeStorageClass(storageClassAttr.getInt()); if (!storageClassValue) @@ -222,14 +220,14 @@ StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; } spirv::Version spirv::VerCapExtAttr::getVersion() { return static_cast( - llvm::cast(getImpl()->version).getValue().getZExtValue()); + cast(getImpl()->version).getValue().getZExtValue()); } spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it) : llvm::mapped_iterator( it, [](Attribute attr) { - return *symbolizeExtension(llvm::cast(attr).getValue()); + return *symbolizeExtension(cast(attr).getValue()); }) {} spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() { @@ -238,7 +236,7 @@ spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() { } ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() { - return llvm::cast(getImpl()->extensions); + return cast(getImpl()->extensions); } spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) @@ -246,7 +244,7 @@ spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) spirv::Capability (*)(Attribute)>( it, [](Attribute attr) { return *symbolizeCapability( - llvm::cast(attr).getValue().getZExtValue()); + cast(attr).getValue().getZExtValue()); }) {} spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() { @@ -255,7 +253,7 @@ spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() { } ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() { - return llvm::cast(getImpl()->capabilities); + return cast(getImpl()->capabilities); } LogicalResult spirv::VerCapExtAttr::verifyInvariants( @@ -265,7 +263,7 @@ LogicalResult spirv::VerCapExtAttr::verifyInvariants( return emitError() << "expected 32-bit integer for version"; if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { - if (auto intAttr = llvm::dyn_cast(attr)) + if (auto intAttr = dyn_cast(attr)) if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue())) return true; return false; @@ -273,7 +271,7 @@ LogicalResult spirv::VerCapExtAttr::verifyInvariants( return emitError() << "unknown capability in capability list"; if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { - if (auto strAttr = llvm::dyn_cast(attr)) + if (auto strAttr = dyn_cast(attr)) if (spirv::symbolizeExtension(strAttr.getValue())) return true; return false; @@ -299,7 +297,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get( StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const { - return llvm::cast(getImpl()->triple); + return cast(getImpl()->triple); } spirv::Version spirv::TargetEnvAttr::getVersion() const { @@ -339,7 +337,7 @@ uint32_t spirv::TargetEnvAttr::getDeviceID() const { } spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const { - return llvm::cast(getImpl()->limits); + return cast(getImpl()->limits); } //===----------------------------------------------------------------------===// @@ -668,11 +666,11 @@ void SPIRVDialect::printAttribute(Attribute attr, if (succeeded(generatedAttributePrinter(attr, printer))) return; - if (auto targetEnv = llvm::dyn_cast(attr)) + if (auto targetEnv = dyn_cast(attr)) print(targetEnv, printer); - else if (auto vceAttr = llvm::dyn_cast(attr)) + else if (auto vceAttr = dyn_cast(attr)) print(vceAttr, printer); - else if (auto interfaceVarABIAttr = llvm::dyn_cast(attr)) + else if (auto interfaceVarABIAttr = dyn_cast(attr)) print(interfaceVarABIAttr, printer); else llvm_unreachable("unhandled SPIR-V attribute kind"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index ccc85368c78a4..9ab3bdc6ab102 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -35,9 +35,9 @@ static std::optional getScalarOrSplatBoolAttr(Attribute attr) { if (!attr) return std::nullopt; - if (auto boolAttr = llvm::dyn_cast(attr)) + if (auto boolAttr = dyn_cast(attr)) return boolAttr.getValue(); - if (auto splatAttr = llvm::dyn_cast(attr)) + if (auto splatAttr = dyn_cast(attr)) if (splatAttr.getElementType().isInteger(1)) return splatAttr.getSplatValue(); return std::nullopt; @@ -54,12 +54,12 @@ static Attribute extractCompositeElement(Attribute composite, if (indices.empty()) return composite; - if (auto vector = llvm::dyn_cast(composite)) { + if (auto vector = dyn_cast(composite)) { assert(indices.size() == 1 && "must have exactly one index for a vector"); return vector.getValues()[indices[0]]; } - if (auto array = llvm::dyn_cast(composite)) { + if (auto array = dyn_cast(composite)) { assert(!indices.empty() && "must have at least one index for an array"); return extractCompositeElement(array.getValue()[indices[0]], indices.drop_front()); @@ -370,7 +370,7 @@ struct UModSimplification final : OpRewritePattern { void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.insert(context); + patterns.add(context); } //===----------------------------------------------------------------------===// @@ -412,10 +412,10 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) { if (auto constructOp = compositeOp.getDefiningOp()) { - auto type = llvm::cast(constructOp.getType()); + auto type = cast(constructOp.getType()); if (getIndices().size() == 1 && constructOp.getConstituents().size() == type.getNumElements()) { - auto i = llvm::cast(*getIndices().begin()); + auto i = cast(*getIndices().begin()); if (i.getValue().getSExtValue() < static_cast(constructOp.getConstituents().size())) return constructOp.getConstituents()[i.getValue().getSExtValue()]; @@ -423,7 +423,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) { } auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) { - return static_cast(llvm::cast(attr).getInt()); + return static_cast(cast(attr).getInt()); }); return extractCompositeElement(adaptor.getComposite(), indexVector); } @@ -1379,7 +1379,7 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection( // Starting with version 1.4, Result Type can additionally be a composite type // other than a vector." bool isScalarOrVector = - llvm::cast(trueBrStoreOp.getValue().getType()) + cast(trueBrStoreOp.getValue().getType()) .isScalarOrVector(); // Check that each `spirv.Store` uses the same pointer, memory access diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 24c33f9ae1b90..22b57d6c0821a 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -172,16 +172,16 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, return type; // Check other allowed types - if (auto t = llvm::dyn_cast(type)) { + if (auto t = dyn_cast(type)) { // TODO: All float types are allowed for now, but this should be fixed. - } else if (auto t = llvm::dyn_cast(type)) { + } else if (auto t = dyn_cast(type)) { if (!ScalarType::isValid(t)) { parser.emitError(typeLoc, "only 1/8/16/32/64-bit integer type allowed but found ") << type; return Type(); } - } else if (auto t = llvm::dyn_cast(type)) { + } else if (auto t = dyn_cast(type)) { if (t.getRank() != 1) { parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); @@ -215,7 +215,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, if (parser.parseType(type)) return Type(); - if (auto t = llvm::dyn_cast(type)) { + if (auto t = dyn_cast(type)) { if (t.getRank() != 1) { parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; return Type(); @@ -228,7 +228,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, return Type(); } - if (!llvm::isa(t.getElementType())) { + if (!isa(t.getElementType())) { parser.emitError(typeLoc, "matrix columns' elements must be of " "Float type, got ") << t.getElementType(); @@ -1016,12 +1016,12 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, Attribute attr = attribute.getValue(); if (symbol == spirv::getEntryPointABIAttrName()) { - if (!llvm::isa(attr)) { + if (!isa(attr)) { return op->emitError("'") << symbol << "' attribute must be an entry point ABI attribute"; } } else if (symbol == spirv::getTargetEnvAttrName()) { - if (!llvm::isa(attr)) + if (!isa(attr)) return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; } else { return op->emitError("found unsupported '") @@ -1039,7 +1039,7 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType, Attribute attr = attribute.getValue(); if (symbol == spirv::getInterfaceVarABIAttrName()) { - auto varABIAttr = llvm::dyn_cast(attr); + auto varABIAttr = dyn_cast(attr); if (!varABIAttr) return emitError(loc, "'") << symbol << "' must be a spirv::InterfaceVarABIAttr"; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index 8575487ff52cc..ba69fa75cf2b8 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -53,7 +53,7 @@ static bool isDirectInModuleLikeOp(Operation *op) { static Type getUnaryOpResultType(Type operandType) { Builder builder(operandType.getContext()); Type resultType = builder.getIntegerType(1); - if (auto vecType = llvm::dyn_cast(operandType)) + if (auto vecType = dyn_cast(operandType)) return VectorType::get(vecType.getNumElements(), resultType); return resultType; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 938952ed273cd..1962538d804a8 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -52,7 +52,7 @@ LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) { return failure(); } auto valueAttr = constOp.getValue(); - auto integerValueAttr = llvm::dyn_cast(valueAttr); + auto integerValueAttr = dyn_cast(valueAttr); if (!integerValueAttr) { return failure(); } @@ -129,7 +129,7 @@ static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(type)) return failure(); - auto fnType = llvm::dyn_cast(type); + auto fnType = dyn_cast(type); if (!fnType) { parser.emitError(loc, "expected function type"); return failure(); @@ -169,11 +169,10 @@ template static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val) { auto valType = val.getType(); - if (auto valVecTy = llvm::dyn_cast(valType)) + if (auto valVecTy = dyn_cast(valType)) valType = valVecTy.getElementType(); - if (valType != - llvm::cast(ptr.getType()).getPointeeType()) { + if (valType != cast(ptr.getType()).getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); @@ -191,7 +190,7 @@ getElementType(Type type, ArrayRef indices, } for (auto index : indices) { - if (auto cType = llvm::dyn_cast(type)) { + if (auto cType = dyn_cast(type)) { if (cType.hasCompileTimeKnownNumElements() && (index < 0 || static_cast(index) >= cType.getNumElements())) { @@ -211,7 +210,7 @@ getElementType(Type type, ArrayRef indices, static Type getElementType(Type type, Attribute indices, function_ref emitErrorFn) { - auto indicesArrayAttr = llvm::dyn_cast(indices); + auto indicesArrayAttr = dyn_cast(indices); if (!indicesArrayAttr) { emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); return nullptr; @@ -223,7 +222,7 @@ getElementType(Type type, Attribute indices, SmallVector indexVals; for (auto indexAttr : indicesArrayAttr) { - auto indexIntAttr = llvm::dyn_cast(indexAttr); + auto indexIntAttr = dyn_cast(indexAttr); if (!indexIntAttr) { emitErrorFn("expected an 32-bit integer for index, but found '") << indexAttr << "'"; @@ -251,7 +250,7 @@ static Type getElementType(Type type, Attribute indices, OpAsmParser &parser, template static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { - auto resultType = llvm::cast(op.getType()); + auto resultType = cast(op.getType()); if (resultType.getNumElements() != 2) return op.emitOpError("expected result struct type containing two members"); @@ -276,7 +275,7 @@ static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, if (parser.parseType(resultType)) return failure(); - auto structType = llvm::dyn_cast(resultType); + auto structType = dyn_cast(resultType); if (!structType || structType.getNumElements() != 2) return parser.emitError(loc, "expected spirv.struct type with two members"); @@ -361,7 +360,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { } // Case 2./3./4. -- number of constituents matches the number of elements. - auto cType = llvm::cast(getType()); + auto cType = cast(getType()); if (constituents.size() == cType.getNumElements()) { for (auto index : llvm::seq(0, constituents.size())) { if (constituents[index].getType() != cType.getElementType(index)) { @@ -374,7 +373,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { } // Case 4. -- check that all constituents add up tp the expected vector type. - auto resultType = llvm::dyn_cast(cType); + auto resultType = dyn_cast(cType); if (!resultType) return emitOpError( "expected to return a vector or cooperative matrix when the number of " @@ -382,14 +381,14 @@ LogicalResult spirv::CompositeConstructOp::verify() { SmallVector sizes; for (Value component : constituents) { - if (!llvm::isa(component.getType()) && + if (!isa(component.getType()) && !component.getType().isIntOrFloat()) return emitOpError("operand type mismatch: expected operand to have " "a scalar or vector type, but provided ") << component.getType(); Type elementType = component.getType(); - if (auto vectorType = llvm::dyn_cast(component.getType())) { + if (auto vectorType = dyn_cast(component.getType())) { sizes.push_back(vectorType.getNumElements()); elementType = vectorType.getElementType(); } else { @@ -455,7 +454,7 @@ void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { } LogicalResult spirv::CompositeExtractOp::verify() { - auto indicesArrayAttr = llvm::dyn_cast(getIndices()); + auto indicesArrayAttr = dyn_cast(getIndices()); auto resultType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!resultType) @@ -500,7 +499,7 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, } LogicalResult spirv::CompositeInsertOp::verify() { - auto indicesArrayAttr = llvm::dyn_cast(getIndices()); + auto indicesArrayAttr = dyn_cast(getIndices()); auto objectType = getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); if (!objectType) @@ -538,14 +537,14 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, return failure(); Type type = NoneType::get(parser.getContext()); - if (auto typedAttr = llvm::dyn_cast(value)) + if (auto typedAttr = dyn_cast(value)) type = typedAttr.getType(); - if (llvm::isa(type)) { + if (isa(type)) { if (parser.parseColonType(type)) return failure(); } - if (llvm::isa(type)) { + if (isa(type)) { if (parser.parseOptionalColon().succeeded()) if (parser.parseType(type)) return failure(); @@ -556,7 +555,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, void spirv::ConstantOp::print(OpAsmPrinter &printer) { printer << ' ' << getValue(); - if (llvm::isa(getType())) + if (isa(getType())) printer << " : " << getType(); } @@ -569,19 +568,19 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, "matrix constant, but found ") << denseAttr; } - if (llvm::isa(value)) { - auto valueType = llvm::cast(value).getType(); + if (isa(value)) { + auto valueType = cast(value).getType(); if (valueType != opType) return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } - if (llvm::isa(value)) { - auto valueType = llvm::cast(value).getType(); + if (isa(value)) { + auto valueType = cast(value).getType(); if (valueType == opType) return success(); - auto arrayType = llvm::dyn_cast(opType); - auto shapedType = llvm::dyn_cast(valueType); + auto arrayType = dyn_cast(opType); + auto shapedType = dyn_cast(valueType); if (!arrayType) return op.emitOpError("result or element type (") << opType << ") does not match value type (" << valueType @@ -589,7 +588,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); - while (auto t = llvm::dyn_cast(opElemType)) { + while (auto t = dyn_cast(opElemType)) { numElements *= t.getNumElements(); opElemType = t.getElementType(); } @@ -610,8 +609,8 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, } return success(); } - if (auto arrayAttr = llvm::dyn_cast(value)) { - auto arrayType = llvm::dyn_cast(opType); + if (auto arrayAttr = dyn_cast(value)) { + auto arrayType = dyn_cast(opType); if (!arrayType) return op.emitOpError( "must have spirv.array result type for array value"); @@ -635,12 +634,12 @@ LogicalResult spirv::ConstantOp::verify() { bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. - if (!llvm::isa(type)) + if (!isa(type)) return false; if (isa(type.getDialect())) { // TODO: support constant struct - return llvm::isa(type); + return isa(type); } return true; @@ -648,7 +647,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, OpBuilder &builder) { - if (auto intType = llvm::dyn_cast(type)) { + if (auto intType = dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return spirv::ConstantOp::create(builder, loc, type, @@ -656,19 +655,19 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, return spirv::ConstantOp::create( builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } - if (auto floatType = llvm::dyn_cast(type)) { + if (auto floatType = dyn_cast(type)) { return spirv::ConstantOp::create(builder, loc, type, builder.getFloatAttr(floatType, 0.0)); } - if (auto vectorType = llvm::dyn_cast(type)) { + if (auto vectorType = dyn_cast(type)) { Type elemType = vectorType.getElementType(); - if (llvm::isa(elemType)) { + if (isa(elemType)) { return spirv::ConstantOp::create( builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0).getValue())); } - if (llvm::isa(elemType)) { + if (isa(elemType)) { return spirv::ConstantOp::create( builder, loc, type, DenseFPElementsAttr::get(vectorType, @@ -681,7 +680,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, OpBuilder &builder) { - if (auto intType = llvm::dyn_cast(type)) { + if (auto intType = dyn_cast(type)) { unsigned width = intType.getWidth(); if (width == 1) return spirv::ConstantOp::create(builder, loc, type, @@ -689,19 +688,19 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, return spirv::ConstantOp::create( builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } - if (auto floatType = llvm::dyn_cast(type)) { + if (auto floatType = dyn_cast(type)) { return spirv::ConstantOp::create(builder, loc, type, builder.getFloatAttr(floatType, 1.0)); } - if (auto vectorType = llvm::dyn_cast(type)) { + if (auto vectorType = dyn_cast(type)) { Type elemType = vectorType.getElementType(); - if (llvm::isa(elemType)) { + if (isa(elemType)) { return spirv::ConstantOp::create( builder, loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1).getValue())); } - if (llvm::isa(elemType)) { + if (isa(elemType)) { return spirv::ConstantOp::create( builder, loc, type, DenseFPElementsAttr::get(vectorType, @@ -720,9 +719,9 @@ void mlir::spirv::ConstantOp::getAsmResultNames( llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << "cst"; - IntegerType intTy = llvm::dyn_cast(type); + IntegerType intTy = dyn_cast(type); - if (IntegerAttr intCst = llvm::dyn_cast(getValue())) { + if (IntegerAttr intCst = dyn_cast(getValue())) { assert(intTy); if (intTy.getWidth() == 1) { @@ -738,18 +737,17 @@ void mlir::spirv::ConstantOp::getAsmResultNames( } } - if (intTy || llvm::isa(type)) { + if (intTy || isa(type)) { specialName << '_' << type; } - if (auto vecType = llvm::dyn_cast(type)) { + if (auto vecType = dyn_cast(type)) { specialName << "_vec_"; specialName << vecType.getDimSize(0); Type elementType = vecType.getElementType(); - if (llvm::isa(elementType) || - llvm::isa(elementType)) { + if (isa(elementType) || isa(elementType)) { specialName << "x" << elementType; } } @@ -903,7 +901,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); } - values.push_back(llvm::cast(value).getInt()); + values.push_back(cast(value).getInt()); } StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name); @@ -1005,7 +1003,7 @@ LogicalResult spirv::FuncOp::verifyType() { auto hasDecorationAttr = [&](spirv::Decoration decoration, unsigned argIndex) { - auto func = llvm::cast(getOperation()); + auto func = cast(getOperation()); for (auto argAttr : cast(func).getArgAttrs(argIndex)) { if (argAttr.getName() != spirv::DecorationAttr::name) continue; @@ -1224,7 +1222,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, if (parser.parseColonType(type)) { return failure(); } - if (!llvm::isa(type)) { + if (!isa(type)) { return parser.emitError(loc, "expected spirv.ptr type"); } result.addAttribute(typeAttrName, TypeAttr::get(type)); @@ -1257,7 +1255,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { } LogicalResult spirv::GlobalVariableOp::verify() { - if (!llvm::isa(getType())) + if (!isa(getType())) return emitOpError("result must be of a !spv.ptr type"); // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the @@ -1325,7 +1323,7 @@ ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser, } auto ptrType = spirv::PointerType::get(elementType, storageClass); - if (auto valVecTy = llvm::dyn_cast(elementType)) + if (auto valVecTy = dyn_cast(elementType)) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, @@ -1539,7 +1537,7 @@ LogicalResult spirv::ModuleOp::verifyRegions() { } if (auto interface = entryPointOp.getInterface()) { for (Attribute varRef : interface) { - auto varSymRef = llvm::dyn_cast(varRef); + auto varSymRef = dyn_cast(varRef); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " @@ -1660,9 +1658,9 @@ LogicalResult spirv::SpecConstantOp::verify() { return emitOpError("SpecId cannot be negative"); auto value = getDefaultValue(); - if (llvm::isa(value)) { + if (isa(value)) { // Make sure bitwidth is allowed. - if (!llvm::isa(value.getType())) + if (!isa(value.getType())) return emitOpError("default value bitwidth disallowed"); return success(); } @@ -1675,7 +1673,7 @@ LogicalResult spirv::SpecConstantOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::VectorShuffleOp::verify() { - VectorType resultType = llvm::cast(getType()); + VectorType resultType = cast(getType()); size_t numResultElements = resultType.getNumElements(); if (numResultElements != getComponents().size()) @@ -1685,8 +1683,8 @@ LogicalResult spirv::VectorShuffleOp::verify() { << getComponents().size() << ")"; size_t totalSrcElements = - llvm::cast(getVector1().getType()).getNumElements() + - llvm::cast(getVector2().getType()).getNumElements(); + cast(getVector1().getType()).getNumElements() + + cast(getVector2().getType()).getNumElements(); for (const auto &selector : getComponents().getAsValueRange()) { uint32_t index = selector.getZExtValue(); @@ -1725,8 +1723,8 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::TransposeOp::verify() { - auto inputMatrix = llvm::cast(getMatrix().getType()); - auto resultMatrix = llvm::cast(getResult().getType()); + auto inputMatrix = cast(getMatrix().getType()); + auto resultMatrix = cast(getResult().getType()); // Verify that the input and output matrices have correct shapes. if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) @@ -1750,9 +1748,9 @@ LogicalResult spirv::TransposeOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesVectorOp::verify() { - auto matrixType = llvm::cast(getMatrix().getType()); - auto vectorType = llvm::cast(getVector().getType()); - auto resultType = llvm::cast(getType()); + auto matrixType = cast(getMatrix().getType()); + auto vectorType = cast(getVector().getType()); + auto resultType = cast(getType()); if (matrixType.getNumColumns() != vectorType.getNumElements()) return emitOpError("matrix columns (") @@ -1775,9 +1773,9 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::VectorTimesMatrixOp::verify() { - auto vectorType = llvm::cast(getVector().getType()); - auto matrixType = llvm::cast(getMatrix().getType()); - auto resultType = llvm::cast(getType()); + auto vectorType = cast(getVector().getType()); + auto matrixType = cast(getMatrix().getType()); + auto resultType = cast(getType()); if (matrixType.getNumRows() != vectorType.getNumElements()) return emitOpError("number of components in vector must equal the number " @@ -1799,9 +1797,9 @@ LogicalResult spirv::VectorTimesMatrixOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesMatrixOp::verify() { - auto leftMatrix = llvm::cast(getLeftmatrix().getType()); - auto rightMatrix = llvm::cast(getRightmatrix().getType()); - auto resultMatrix = llvm::cast(getResult().getType()); + auto leftMatrix = cast(getLeftmatrix().getType()); + auto rightMatrix = cast(getRightmatrix().getType()); + auto resultMatrix = cast(getResult().getType()); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) @@ -1886,14 +1884,14 @@ void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { } LogicalResult spirv::SpecConstantCompositeOp::verify() { - auto cType = llvm::dyn_cast(getType()); + auto cType = dyn_cast(getType()); auto constituents = this->getConstituents().getValue(); if (!cType) return emitError("result type must be a composite type, but provided ") << getType(); - if (llvm::isa(cType)) + if (isa(cType)) return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") @@ -1901,7 +1899,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() { << constituents.size(); for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = llvm::cast(constituents[index]); + auto constituent = cast(constituents[index]); auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( @@ -2042,19 +2040,19 @@ LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { LogicalResult spirv::GLFrexpStructOp::verify() { spirv::StructType structTy = - llvm::dyn_cast(getResult().getType()); + dyn_cast(getResult().getType()); if (structTy.getNumElements() != 2) return emitError("result type must be a struct type with two memebers"); Type significandTy = structTy.getElementType(0); Type exponentTy = structTy.getElementType(1); - VectorType exponentVecTy = llvm::dyn_cast(exponentTy); - IntegerType exponentIntTy = llvm::dyn_cast(exponentTy); + VectorType exponentVecTy = dyn_cast(exponentTy); + IntegerType exponentIntTy = dyn_cast(exponentTy); Type operandTy = getOperand().getType(); - VectorType operandVecTy = llvm::dyn_cast(operandTy); - FloatType operandFTy = llvm::dyn_cast(operandTy); + VectorType operandVecTy = dyn_cast(operandTy); + FloatType operandFTy = dyn_cast(operandTy); if (significandTy != operandTy) return emitError("member zero of the resulting struct type must be the " @@ -2062,7 +2060,7 @@ LogicalResult spirv::GLFrexpStructOp::verify() { if (exponentVecTy) { IntegerType componentIntTy = - llvm::dyn_cast(exponentVecTy.getElementType()); + dyn_cast(exponentVecTy.getElementType()); if (!componentIntTy || componentIntTy.getWidth() != 32) return emitError("member one of the resulting struct type must" "be a scalar or vector of 32 bit integer type"); @@ -2091,12 +2089,11 @@ LogicalResult spirv::GLLdexpOp::verify() { Type significandType = getX().getType(); Type exponentType = getExp().getType(); - if (llvm::isa(significandType) != - llvm::isa(exponentType)) + if (isa(significandType) != isa(exponentType)) return emitOpError("operands must both be scalars or vectors"); auto getNumElements = [](Type type) -> unsigned { - if (auto vectorType = llvm::dyn_cast(type)) + if (auto vectorType = dyn_cast(type)) return vectorType.getNumElements(); return 1; }; @@ -2138,7 +2135,7 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() { LogicalResult spirv::VectorTimesScalarOp::verify() { if (getVector().getType() != getType()) return emitOpError("vector operand and result type mismatch"); - auto scalarType = llvm::cast(getType()).getElementType(); + auto scalarType = cast(getType()).getElementType(); if (getScalar().getType() != scalarType) return emitOpError("scalar operand and result element type match"); return success(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h index f28d386f8874d..772219c8db654 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h @@ -75,11 +75,11 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), attrName, attr)) return failure(); - if (!llvm::isa(attrVal)) + if (!isa(attrVal)) return parser.emitError(loc, "expected ") << attrName << " attribute specified as string"; - auto attrOptional = spirv::symbolizeEnum( - llvm::cast(attrVal).getValue()); + auto attrOptional = + spirv::symbolizeEnum(cast(attrVal).getValue()); if (!attrOptional) return parser.emitError(loc, "invalid ") << attrName << " attribute specification: " << attrVal; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index d1e275d590f78..53a48abe5ad02 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -178,17 +178,17 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; } //===----------------------------------------------------------------------===// bool CompositeType::classof(Type type) { - if (auto vectorType = llvm::dyn_cast(type)) + if (auto vectorType = dyn_cast(type)) return isValid(vectorType); - return llvm::isa(type); + return isa( + type); } bool CompositeType::isValid(VectorType type) { return type.getRank() == 1 && llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && - llvm::isa(type.getElementType()); + isa(type.getElementType()); } Type CompositeType::getElementType(unsigned index) const { @@ -210,7 +210,7 @@ unsigned CompositeType::getNumElements() const { } bool CompositeType::hasCompileTimeKnownNumElements() const { - return !llvm::isa(*this); + return !isa(*this); } void TypeCapabilityVisitor::addConcrete(VectorType type) { @@ -529,10 +529,10 @@ void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) { //===----------------------------------------------------------------------===// bool ScalarType::classof(Type type) { - if (auto floatType = llvm::dyn_cast(type)) { + if (auto floatType = dyn_cast(type)) { return isValid(floatType); } - if (auto intType = llvm::dyn_cast(type)) { + if (auto intType = dyn_cast(type)) { return isValid(intType); } return false; @@ -676,19 +676,19 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) { bool SPIRVType::classof(Type type) { // Allow SPIR-V dialect types - if (llvm::isa(type.getDialect())) + if (isa(type.getDialect())) return true; - if (llvm::isa(type)) + if (isa(type)) return true; - if (auto vectorType = llvm::dyn_cast(type)) + if (auto vectorType = dyn_cast(type)) return CompositeType::isValid(vectorType); - if (auto tensorArmType = llvm::dyn_cast(type)) - return llvm::isa(tensorArmType.getElementType()); + if (auto tensorArmType = dyn_cast(type)) + return isa(tensorArmType.getElementType()); return false; } bool SPIRVType::isScalarOrVector() { - return isIntOrFloat() || llvm::isa(*this); + return isIntOrFloat() || isa(*this); } void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -1190,7 +1190,7 @@ MatrixType::verifyInvariants(function_ref emitError, return emitError() << "matrix columns must be vectors of floats"; /// The underlying vectors (columns) must be of size 2, 3, or 4 - ArrayRef columnShape = llvm::cast(columnType).getShape(); + ArrayRef columnShape = cast(columnType).getShape(); if (columnShape.size() != 1) return emitError() << "matrix columns must be 1D vectors"; @@ -1202,8 +1202,8 @@ MatrixType::verifyInvariants(function_ref emitError, /// Returns true if the matrix elements are vectors of float elements bool MatrixType::isValidColumnType(Type columnType) { - if (auto vectorType = llvm::dyn_cast(columnType)) { - if (llvm::isa(vectorType.getElementType())) + if (auto vectorType = dyn_cast(columnType)) { + if (isa(vectorType.getElementType())) return true; } return false; @@ -1212,13 +1212,13 @@ bool MatrixType::isValidColumnType(Type columnType) { Type MatrixType::getColumnType() const { return getImpl()->columnType; } Type MatrixType::getElementType() const { - return llvm::cast(getImpl()->columnType).getElementType(); + return cast(getImpl()->columnType).getElementType(); } unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } unsigned MatrixType::getNumRows() const { - return llvm::cast(getImpl()->columnType).getShape()[0]; + return cast(getImpl()->columnType).getShape()[0]; } unsigned MatrixType::getNumElements() const { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 50883d9ed5e75..ce7e7dc4116c8 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -243,7 +243,7 @@ static LogicalResult deserializeCacheControlDecoration( auto value = opBuilder.getAttr(cacheLevel, cacheControlAttr); SmallVector attrs; if (auto attrList = - llvm::dyn_cast_or_null(decorations[words[0]].get(symbol))) + dyn_cast_or_null(decorations[words[0]].get(symbol))) llvm::append_range(attrs, attrList); attrs.push_back(value); decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs)); @@ -326,7 +326,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { static_cast<::mlir::spirv::LinkageType>(words[wordIndex++])); auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>( StringAttr::get(context, linkageName), linkageTypeAttr); - decorations[words[0]].set(symbol, llvm::dyn_cast(linkageAttr)); + decorations[words[0]].set(symbol, dyn_cast(linkageAttr)); break; } case spirv::Decoration::Aliased: @@ -1511,10 +1511,10 @@ spirv::Deserializer::processTensorARMType(ArrayRef operands) { return emitError(unknownLoc, "OpTypeTensorARM shape must come from a " "constant instruction of type OpTypeArray"); - ArrayAttr shapeArrayAttr = llvm::dyn_cast(shapeInfo->first); + ArrayAttr shapeArrayAttr = dyn_cast(shapeInfo->first); SmallVector shape; for (auto dimAttr : shapeArrayAttr.getValue()) { - auto dimIntAttr = llvm::dyn_cast(dimAttr); + auto dimIntAttr = dyn_cast(dimAttr); if (!dimIntAttr) return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid " "dimension size"); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index 6397d2c005c16..b78fac532d8c5 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -876,7 +876,7 @@ Serializer::processOp(spirv::ExecutionModeOp op) { if (values) { for (auto &intVal : values.getValue()) { operands.push_back(static_cast( - llvm::cast(intVal).getValue().getZExtValue())); + cast(intVal).getValue().getZExtValue())); } } encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index c879a2b3e0207..c29d20f755332 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -312,7 +312,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::LinkageAttributes: { // Get the value of the Linkage Attributes // e.g., LinkageAttributes=["linkageName", linkageType]. - auto linkageAttr = llvm::dyn_cast(attr); + auto linkageAttr = dyn_cast(attr); auto linkageName = linkageAttr.getLinkageName(); auto linkageType = linkageAttr.getLinkageType().getValue(); // Encode the Linkage Name (string literal to uint32_t). @@ -822,7 +822,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto tensorArmType = llvm::dyn_cast(type)) { + if (auto tensorArmType = dyn_cast(type)) { uint32_t elementTypeID = 0; uint32_t rank = 0; uint32_t shapeID = 0;