diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 776208bd3e693..b1c8a68df5d69 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -49,9 +49,12 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, raw_ostream &O) { unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; const unsigned NumVarOps = MI->getNumOperands() - StartIndex; + const unsigned Opcode = MI->getOpcode(); - assert((NumVarOps == 1 || NumVarOps == 2) && - "Unsupported number of bits for literal variable"); + // We support up to 1024 bits for integers, and 64 bits for floats + assert(((NumVarOps <= 32 && Opcode == SPIRV::OpConstantI) || + (NumVarOps <= 2 && Opcode == SPIRV::OpConstantF)) && + "Unsupported number of operands for constant"); O << ' '; @@ -62,6 +65,18 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32); } + // Handle arbitrary number of 32-bit words for integer literals + if (Opcode == SPIRV::OpConstantI && NumVarOps > 2) { + const unsigned TotalBits = NumVarOps * 32; + APInt Val(TotalBits, 0); + for (unsigned i = 0; i < NumVarOps; ++i) { + uint64_t Word = MI->getOperand(StartIndex + i).getImm(); + Val |= APInt(TotalBits, Word) << (i * 32); + } + O << Val; + return; + } + // Format and print float values. if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) { APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat()) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 115766ce886c7..f01d8a34c2ca4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { } unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const { - if (Width > 64) + if (Width > 1024) report_fatal_error("Unsupported integer width!"); const SPIRVSubtarget &ST = cast(CurMF->getSubtarget()); if (ST.canUseExtension( @@ -343,6 +343,20 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF, return Res; } +Register SPIRVGlobalRegistry::getOrCreateConstInt(const APInt Val, + MachineInstr &I, + SPIRVType *SpvType, + const SPIRVInstrInfo &TII, + bool ZeroAsNull) { + const IntegerType *Ty = cast(getTypeForSPIRVType(SpvType)); + Constant *const CA = ConstantInt::get(const_cast(Ty), Val); + const MachineInstr *MI = findMI(CA, CurMF); + if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || + MI->getOpcode() == SPIRV::OpConstantI)) + return MI->getOperand(0).getReg(); + return createConstInt(CA, I, SpvType, TII, ZeroAsNull); +} + Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -356,7 +370,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, return createConstInt(CI, I, SpvType, TII, ZeroAsNull); } -Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, +Register SPIRVGlobalRegistry::createConstInt(const Constant *CA, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -374,15 +388,20 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, MachineInstrBuilder MIB; if (BitWidth == 1) { MIB = MIRBuilder - .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse - : SPIRV::OpConstantTrue) + .buildInstr(CA->isZeroValue() ? SPIRV::OpConstantFalse + : SPIRV::OpConstantTrue) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - } else if (!CI->isZero() || !ZeroAsNull) { + } else if (!CA->isZeroValue() || !ZeroAsNull) { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB); + if (BitWidth <= 64) { + const ConstantInt *CI = dyn_cast(CA); + addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB); + } else { + addNumImm(CA->getUniqueInteger(), MIB); + } } else { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) .addDef(Res) @@ -394,7 +413,7 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, *ST.getRegBankInfo()); return MIB; }); - add(CI, NewType); + add(CA, NewType); return Res; } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index a648defa0a888..60085cd0f337b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -515,10 +515,13 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, bool ZeroAsNull = true); + Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType, + const SPIRVInstrInfo &TII, + bool ZeroAsNull = true); Register getOrCreateConstInt(uint64_t Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull = true); - Register createConstInt(const ConstantInt *CI, MachineInstr &I, + Register createConstInt(const Constant *CI, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull); Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 1aadd9df189a8..3c3374e5066b4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2939,8 +2939,13 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I, - ResType, TII, !STI.isShader()); + if (GR.getScalarOrVectorBitWidth(ResType) <= 64) { + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I, + ResType, TII, !STI.isShader()); + } else { + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, + ResType, TII, !STI.isShader()); + } } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I); } diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 820e56b362edc..e409234a83568 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) { if (Bitwidth == 16) MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16); return; - } else if (Bitwidth <= 64) { - uint64_t FullImm = Imm.getZExtValue(); - uint32_t LowBits = FullImm & 0xffffffff; - uint32_t HighBits = (FullImm >> 32) & 0xffffffff; - MIB.addImm(LowBits).addImm(HighBits); + } else if (Bitwidth <= 1024) { + unsigned NumWords = (Bitwidth + 31) / 32; + for (unsigned i = 0; i < NumWords; ++i) { + uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue(); + MIB.addImm(Word); + } return; } report_fatal_error("Unsupported constant bitwidth"); diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll index 41d4b58ed1157..23681d660dd20 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll @@ -8,6 +8,14 @@ define i13 @getConstantI13() { ret i13 42 } +define i96 @getConstantI96() { + ret i96 18446744073709551620 +} + +define i160 @getConstantI160() { + ret i160 3363637389930338837376336738763689377839373638 +} + ;; Capabilities: ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers" ; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL @@ -17,14 +25,20 @@ define i13 @getConstantI13() { ;; Names: ; CHECK-DAG: OpName %[[#GET_I6:]] "getConstantI6" ; CHECK-DAG: OpName %[[#GET_I13:]] "getConstantI13" +; CHECK-DAG: OpName %[[#GET_I96:]] "getConstantI96" +; CHECK-DAG: OpName %[[#GET_I160:]] "getConstantI160" ; CHECK-NOT: DAG-FENCE ;; Types and Constants: ; CHECK-DAG: %[[#I6:]] = OpTypeInt 6 0 ; CHECK-DAG: %[[#I13:]] = OpTypeInt 13 0 +; CHECK-DAG: %[[#I96:]] = OpTypeInt 96 0 +; CHECK-DAG: %[[#I160:]] = OpTypeInt 160 0 ; CHECK-DAG: %[[#CST_I6:]] = OpConstant %[[#I6]] 2 ; CHECK-DAG: %[[#CST_I13:]] = OpConstant %[[#I13]] 42 +; CHECK-DAG: %[[#CST_I96:]] = OpConstant %[[#I96]] 18446744073709551620 +; CHECK-DAG: %[[#CST_I160:]] = OpConstant %[[#I160]] 3363637389930338837376336738763689377839373638 ; CHECK: %[[#GET_I6]] = OpFunction %[[#I6]] ; CHECK: OpReturnValue %[[#CST_I6]] @@ -33,3 +47,11 @@ define i13 @getConstantI13() { ; CHECK: %[[#GET_I13]] = OpFunction %[[#I13]] ; CHECK: OpReturnValue %[[#CST_I13]] ; CHECK: OpFunctionEnd + +; CHECK: %[[#GET_I96]] = OpFunction %[[#I96]] +; CHECK: OpReturnValue %[[#CST_I96]] +; CHECK: OpFunctionEnd + +; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]] +; CHECK: OpReturnValue %[[#CST_I160]] +; CHECK: OpFunctionEnd