diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index b7cfa4f6f2ac1..269e59d9ed82b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -25,9 +25,64 @@ namespace llvm { class SPIRVSubtarget; +/// @deprecated Use SPIRVTypeInst instead +/// SPIRVType is supposed to represent a MachineInstr that defines a SPIRV Type +/// (e.g. an OpTypeInt intruction). It is misused in several places and we're +/// getting rid of it. using SPIRVType = const MachineInstr; + using StructOffsetDecorator = std::function; +class SPIRVTypeInst { + const MachineInstr *MI; + +public: + static bool definesATypeRegister(const MachineInstr &MI) { + const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + return MRI.getRegClass(MI.getOperand(0).getReg()) == &SPIRV::TYPERegClass; + } + + SPIRVTypeInst(const MachineInstr &MI) : SPIRVTypeInst(&MI) {} + SPIRVTypeInst(const MachineInstr *MI) : MI(MI) { + // A SPIRV Type whose result is not a type is invalid. + assert(!MI || definesATypeRegister(*MI)); + } + + // No need to verify the register since it's already verified by the copied + // object. + SPIRVTypeInst(const SPIRVTypeInst &Other) : MI(Other.MI) {} + + SPIRVTypeInst &operator=(const SPIRVTypeInst &Other) { + MI = Other.MI; + return *this; + } + + const MachineInstr &operator*() const { return *MI; } + const MachineInstr *operator->() const { return MI; } + operator const MachineInstr *() const { return MI; } + + bool operator==(const SPIRVTypeInst &Other) const { return MI == Other.MI; } + bool operator!=(const SPIRVTypeInst &Other) const { return MI != Other.MI; } + + bool operator==(const MachineInstr *Other) const { return MI == Other; } + bool operator!=(const MachineInstr *Other) const { return MI != Other; } + + operator bool() const { return MI; } + + unsigned getHashValue() const { + return DenseMapInfo::getHashValue(MI); + } +}; + +template <> struct DenseMapInfo { + static SPIRVTypeInst getEmptyKey() { return SPIRVTypeInst(nullptr); } + static SPIRVTypeInst getTombstoneKey() { return SPIRVTypeInst(nullptr); } + static unsigned getHashValue(SPIRVTypeInst Ty) { return Ty.getHashValue(); } + static bool isEqual(SPIRVTypeInst Ty1, SPIRVTypeInst Ty2) { + return Ty1 == Ty2; + } +}; + class SPIRVGlobalRegistry : public SPIRVIRMapping { // Registers holding values which have types associated with them. // Initialized upon VReg definition in IRTranslator. diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 7d3711a583e50..b16dae012b10d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -39,7 +39,8 @@ SPIRVTargetLowering::SPIRVTargetLowering(const TargetMachine &TM, // Returns true of the types logically match, as defined in // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical. -static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2, +static bool typesLogicallyMatch(const SPIRVTypeInst Ty1, + const SPIRVTypeInst Ty2, SPIRVGlobalRegistry &GR) { if (Ty1->getOpcode() != Ty2->getOpcode()) return false; @@ -52,17 +53,19 @@ static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2, if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg()) return false; - SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg()); - SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg()); + SPIRVTypeInst ElemType1 = + GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg()); + SPIRVTypeInst ElemType2 = + GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg()); return ElemType1 == ElemType2 || typesLogicallyMatch(ElemType1, ElemType2, GR); } if (Ty1->getOpcode() == SPIRV::OpTypeStruct) { for (unsigned I = 1; I < Ty1->getNumOperands(); I++) { - SPIRVType *ElemType1 = + SPIRVTypeInst ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg()); - SPIRVType *ElemType2 = + SPIRVTypeInst ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg()); if (ElemType1 != ElemType2 && !typesLogicallyMatch(ElemType1, ElemType2, GR))