From 4c1d90c01082af32ee28a04edd3a0fc58a4ac1b9 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Sat, 27 Sep 2025 08:58:37 -0700 Subject: [PATCH 01/12] Add support for arbitrary integer with bitwidth larger than 64 bits in spirv-backend --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 18 ++++++--- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 17 ++++---- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 39 +++++++++---------- llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 11 +++--- .../SPV_INTEL_arbitrary_precision_integers.ll | 4 ++ 6 files changed, 52 insertions(+), 41 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 776208bd3e693..dff9f699ebd6f 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; const unsigned NumVarOps = MI->getNumOperands() - StartIndex; - assert((NumVarOps == 1 || NumVarOps == 2) && + // we support integer up to 1024 bits + assert((NumVarOps <= 1024) && "Unsupported number of bits for literal variable"); O << ' '; - uint64_t Imm = MI->getOperand(StartIndex).getImm(); - - // Handle 64 bit literals. - if (NumVarOps == 2) { - Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32); + // Handle arbitrary number of 32-bit words for the literal value. + if (MI->getOpcode() == SPIRV::OpConstantI){ + APInt Val(NumVarOps * 32, 0); + for (unsigned i = 0; i < NumVarOps; ++i) { + Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32)); + } + O << Val; + return; } + uint64_t Imm = MI->getOperand(StartIndex).getImm(); + // Format and print float values. if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) { APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat()) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 115766ce886c7..05b3371e97cdc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { } unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const { - if (Width > 64) + if (Width > 1024) report_fatal_error("Unsupported integer width!"); const SPIRVSubtarget &ST = cast(CurMF->getSubtarget()); if (ST.canUseExtension( @@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF, return Res; } -Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, +Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull) { @@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || MI->getOpcode() == SPIRV::OpConstantI)) return MI->getOperand(0).getReg(); - return createConstInt(CI, I, SpvType, TII, ZeroAsNull); + return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull); } -Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, +Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, + APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, MachineInstrBuilder MIB; if (BitWidth == 1) { MIB = MIRBuilder - .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse + .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse : SPIRV::OpConstantTrue) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - } else if (!CI->isZero() || !ZeroAsNull) { + } else if (!Val.isZero() || !ZeroAsNull) { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB); + addNumImm(Val, MIB); } else { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) .addDef(Res) @@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister( } assert(Type->getOpcode() == SPIRV::OpTypeInt); SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); - return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I, + return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, SpvBaseType, TII, ZeroAsNull); } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index a648defa0a888..ee217f81fb416 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, bool ZeroAsNull = true); - Register getOrCreateConstInt(uint64_t Val, MachineInstr &I, + Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull = true); - Register createConstInt(const ConstantInt *CI, MachineInstr &I, + Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull); Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 1aadd9df189a8..3e5566945ec0b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( .addDef(AElt) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(X) - .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull)) .constrainAllUses(TII, TRI, RBI); // B[i] @@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( .addDef(BElt) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(Y) - .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull)) .constrainAllUses(TII, TRI, RBI); // A[i] * B[i] @@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( .addDef(MaskMul) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(Mul) - .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull)) .constrainAllUses(TII, TRI, RBI); // Acc = Acc + A[i] * B[i] @@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg, auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII, !STI.isShader())); for (unsigned J = 2; J < I.getNumOperands(); J++) { @@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII, !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(BallotReg) @@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII, !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(I.getOperand(2).getReg()) @@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII, !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(I.getOperand(2).getReg()); @@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType, bool ZeroAsNull = !STI.isShader(); if (ResType->getOpcode() == SPIRV::OpTypeVector) return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull); - return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull); + return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull); } Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType, @@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0); if (ResType->getOpcode() == SPIRV::OpTypeVector) return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII); - return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII); + return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII); } bool SPIRVInstructionSelector::selectSelect(Register ResVReg, @@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I, - ResType, TII, !STI.isShader()); + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader()); } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I); } @@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow( bool ZeroAsNull = !STI.isShader(); Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type)); Register ConstIntLastIdx = GR.getOrCreateConstInt( - ComponentCount - 1, I, BaseType, TII, ZeroAsNull); + APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull); if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx}, SPIRV::OpVectorExtractDynamic)) @@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64( SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType); bool ZeroAsNull = !STI.isShader(); Register ConstIntZero = - GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull); + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull); Register ConstIntOne = - GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull); + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull); // SPIRV doesn't support vectors with more than 4 components. Since the // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only @@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64( if (IsScalarRes) { NegOneReg = - GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull); - Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull); - Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull); + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull); + Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull); + Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull); SelectOp = SPIRV::OpSelectSISCond; AddOp = SPIRV::OpIAddS; } else { diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 820e56b362edc..e409234a83568 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) { if (Bitwidth == 16) MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16); return; - } else if (Bitwidth <= 64) { - uint64_t FullImm = Imm.getZExtValue(); - uint32_t LowBits = FullImm & 0xffffffff; - uint32_t HighBits = (FullImm >> 32) & 0xffffffff; - MIB.addImm(LowBits).addImm(HighBits); + } else if (Bitwidth <= 1024) { + unsigned NumWords = (Bitwidth + 31) / 32; + for (unsigned i = 0; i < NumWords; ++i) { + uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue(); + MIB.addImm(Word); + } return; } report_fatal_error("Unsupported constant bitwidth"); diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll index 41d4b58ed1157..17ba9b044842c 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll @@ -8,6 +8,10 @@ define i13 @getConstantI13() { ret i13 42 } +define i96 @getConstantI96() { + ret i96 18446744073709551620 +} + ;; Capabilities: ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers" ; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL From 8578a56ff0bd41fc78ed0702f62f1c9d3233968a Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 6 Oct 2025 15:22:56 -0700 Subject: [PATCH 02/12] update the test --- .../SPV_INTEL_arbitrary_precision_integers.ll | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll index 17ba9b044842c..003a900c73770 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll @@ -10,7 +10,11 @@ define i13 @getConstantI13() { define i96 @getConstantI96() { ret i96 18446744073709551620 -} +} + +define i160 @getConstantI160() { + ret i160 3363637389930338837376336738763689377839373638 +} ;; Capabilities: ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers" @@ -21,14 +25,20 @@ define i96 @getConstantI96() { ;; Names: ; CHECK-DAG: OpName %[[#GET_I6:]] "getConstantI6" ; CHECK-DAG: OpName %[[#GET_I13:]] "getConstantI13" +; CHECK-DAG: OpName %[[#GET_I96:]] "getConstantI96" +; CHECK-DAG: OpName %[[#GET_I160:]] "getConstantI160" ; CHECK-NOT: DAG-FENCE ;; Types and Constants: ; CHECK-DAG: %[[#I6:]] = OpTypeInt 6 0 ; CHECK-DAG: %[[#I13:]] = OpTypeInt 13 0 +; CHECK-DAG: %[[#I96:]] = OpTypeInt 96 0 +; CHECK-DAG: %[[#I160:]] = OpTypeInt 160 0 ; CHECK-DAG: %[[#CST_I6:]] = OpConstant %[[#I6]] 2 ; CHECK-DAG: %[[#CST_I13:]] = OpConstant %[[#I13]] 42 +; CHECK-DAG: %[[#CST_I96:]] = OpConstant %[[#I96]] 18446744073709551620 +; CHECK-DAG: %[[#CST_I160:]] = OpConstant %[[#I160]] 3363637389930338837376336738763689377839373638 ; CHECK: %[[#GET_I6]] = OpFunction %[[#I6]] ; CHECK: OpReturnValue %[[#CST_I6]] @@ -37,3 +47,11 @@ define i96 @getConstantI96() { ; CHECK: %[[#GET_I13]] = OpFunction %[[#I13]] ; CHECK: OpReturnValue %[[#CST_I13]] ; CHECK: OpFunctionEnd + +; CHECK: %[[#GET_I96]] = OpFunction %[[#I96]] +; CHECK: OpReturnValue %[[#CST_I96]] +; CHECK: OpFunctionEnd + +; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]] +; CHECK: OpReturnValue %[[#CST_I160]] +; CHECK: OpFunctionEnd \ No newline at end of file From fd4a78f7119aca99afd7ac924231b10ddfffb28c Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 6 Oct 2025 15:56:33 -0700 Subject: [PATCH 03/12] code clean up --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 9 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 8 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 117 ++++++++++-------- .../SPV_INTEL_arbitrary_precision_integers.ll | 2 +- 5 files changed, 79 insertions(+), 61 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index dff9f699ebd6f..9529e18da21c7 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -50,17 +50,18 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; const unsigned NumVarOps = MI->getNumOperands() - StartIndex; - // we support integer up to 1024 bits - assert((NumVarOps <= 1024) && + // We support integer up to 1024 bits + assert((NumVarOps <= 32) && "Unsupported number of bits for literal variable"); O << ' '; // Handle arbitrary number of 32-bit words for the literal value. - if (MI->getOpcode() == SPIRV::OpConstantI){ + if (MI->getOpcode() == SPIRV::OpConstantI) { APInt Val(NumVarOps * 32, 0); for (unsigned i = 0; i < NumVarOps; ++i) { - Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32)); + Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) + << (i * 32)); } O << Val; return; diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 05b3371e97cdc..c4f8565f39b84 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -356,8 +356,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I, return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull); } -Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, - APInt Val, +Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -492,8 +491,9 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister( } assert(Type->getOpcode() == SPIRV::OpTypeInt); SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); - return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, - SpvBaseType, TII, ZeroAsNull); + return getOrCreateConstInt( + APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, SpvBaseType, + TII, ZeroAsNull); } Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull( diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index ee217f81fb416..9cb7d982c3fc2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -515,8 +515,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, bool ZeroAsNull = true); - Register getOrCreateConstInt(APInt Val, MachineInstr &I, - SPIRVType *SpvType, const SPIRVInstrInfo &TII, + Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType, + const SPIRVInstrInfo &TII, bool ZeroAsNull = true); Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3e5566945ec0b..f82ddbc8990b6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2247,33 +2247,36 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( for (unsigned i = 0; i < 4; i++) { // A[i] Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= - BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(AElt) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(X) - .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull)) - .constrainAllUses(TII, TRI, RBI); + Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(AElt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(X) + .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, + TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, + ZeroAsNull)) + .constrainAllUses(TII, TRI, RBI); // B[i] - Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= - BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(BElt) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(Y) - .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull)) - .constrainAllUses(TII, TRI, RBI); + Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(BElt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Y) + .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, + TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, + ZeroAsNull)) + .constrainAllUses(TII, TRI, RBI); // A[i] * B[i] - Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS)) - .addDef(Mul) + Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(MaskMul) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(AElt) - .addUse(BElt) + .addUse(Mul) + .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, + ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, + ZeroAsNull)) .constrainAllUses(TII, TRI, RBI); // Discard 24 highest-bits so that stored i32 register is i8 equivalent @@ -2378,11 +2381,12 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg, MachineBasicBlock &BB = *I.getParent(); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, - IntTy, TII, !STI.isShader())); + auto BMI = + BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, + IntTy, TII, !STI.isShader())); for (unsigned J = 2; J < I.getNumOperands(); J++) { BMI.addUse(I.getOperand(J).getReg()); @@ -2401,15 +2405,16 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( SPIRV::OpGroupNonUniformBallot); MachineBasicBlock &BB = *I.getParent(); - Result &= BuildMI(BB, I, I.getDebugLoc(), - TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, - TII, !STI.isShader())) - .addImm(SPIRV::GroupOperation::Reduce) - .addUse(BallotReg) - .constrainAllUses(TII, TRI, RBI); + Result &= + BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, + IntTy, TII, !STI.isShader())) + .addImm(SPIRV::GroupOperation::Reduce) + .addUse(BallotReg) + .constrainAllUses(TII, TRI, RBI); return Result; } @@ -2436,8 +2441,8 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII, - !STI.isShader())) + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, + IntTy, TII, !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(I.getOperand(2).getReg()) .constrainAllUses(TII, TRI, RBI); @@ -2463,8 +2468,8 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII, - !STI.isShader())) + .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, + IntTy, TII, !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(I.getOperand(2).getReg()); } @@ -2689,7 +2694,8 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType, bool ZeroAsNull = !STI.isShader(); if (ResType->getOpcode() == SPIRV::OpTypeVector) return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull); - return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull); + return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), + I, ResType, TII, ZeroAsNull); } Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType, @@ -2720,7 +2726,9 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0); if (ResType->getOpcode() == SPIRV::OpTypeVector) return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII); - return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII); + return GR.getOrCreateConstInt( + APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, + ResType, TII); } bool SPIRVInstructionSelector::selectSelect(Register ResVReg, @@ -2939,7 +2947,8 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader()); + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, + ResType, TII, !STI.isShader()); } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I); } @@ -3764,7 +3773,8 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow( bool ZeroAsNull = !STI.isShader(); Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type)); Register ConstIntLastIdx = GR.getOrCreateConstInt( - APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull); + APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, + BaseType, TII, ZeroAsNull); if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx}, SPIRV::OpVectorExtractDynamic)) @@ -3793,9 +3803,11 @@ bool SPIRVInstructionSelector::selectFirstBitSet64( SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType); bool ZeroAsNull = !STI.isShader(); Register ConstIntZero = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull); + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), + I, BaseType, TII, ZeroAsNull); Register ConstIntOne = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull); + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), + I, BaseType, TII, ZeroAsNull); // SPIRV doesn't support vectors with more than 4 components. Since the // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only @@ -3879,10 +3891,15 @@ bool SPIRVInstructionSelector::selectFirstBitSet64( unsigned AddOp; if (IsScalarRes) { - NegOneReg = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull); - Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull); - Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull); + NegOneReg = GR.getOrCreateConstInt( + APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, + TII, ZeroAsNull); + Reg0 = + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), + I, ResType, TII, ZeroAsNull); + Reg32 = + GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), + I, ResType, TII, ZeroAsNull); SelectOp = SPIRV::OpSelectSISCond; AddOp = SPIRV::OpIAddS; } else { diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll index 003a900c73770..23681d660dd20 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll @@ -54,4 +54,4 @@ define i160 @getConstantI160() { ; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]] ; CHECK: OpReturnValue %[[#CST_I160]] -; CHECK: OpFunctionEnd \ No newline at end of file +; CHECK: OpFunctionEnd From 6547337a900df2921158a5b0c1ec05a0b5193089 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 6 Oct 2025 16:07:21 -0700 Subject: [PATCH 04/12] code clean up --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index f82ddbc8990b6..7004d400bd450 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2258,6 +2258,7 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( .constrainAllUses(TII, TRI, RBI); // B[i] + Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) .addDef(BElt) .addUse(GR.getSPIRVTypeID(ResType)) @@ -2281,14 +2282,15 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( // Discard 24 highest-bits so that stored i32 register is i8 equivalent Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= - BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(MaskMul) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(Mul) - .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull)) - .constrainAllUses(TII, TRI, RBI); + Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(MaskMul) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Mul) + .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, + ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, + ZeroAsNull)) + .constrainAllUses(TII, TRI, RBI); // Acc = Acc + A[i] * B[i] Register Sum = From 5dcff453f0c9a57edbe853eea4e7cf3861ef8664 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Mon, 6 Oct 2025 16:26:33 -0700 Subject: [PATCH 05/12] code clean up --- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 7004d400bd450..911e43e57edee 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2270,14 +2270,13 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( .constrainAllUses(TII, TRI, RBI); // A[i] * B[i] - Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(MaskMul) + Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass); + Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS)) + .addDef(Mul) .addUse(GR.getSPIRVTypeID(ResType)) .addUse(Mul) - .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, - ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, - ZeroAsNull)) + .addUse(AElt) + .addUse(BElt) .constrainAllUses(TII, TRI, RBI); // Discard 24 highest-bits so that stored i32 register is i8 equivalent From 3fada99efe130362bd04cd04df55740ece2b9a04 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 7 Oct 2025 14:37:48 -0700 Subject: [PATCH 06/12] fix CI failure --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 76 ++++++++++--------- 1 file changed, 42 insertions(+), 34 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 9529e18da21c7..2e2b834c167f2 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -47,57 +47,65 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI, void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, unsigned StartIndex, raw_ostream &O) { - unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; + const bool IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; const unsigned NumVarOps = MI->getNumOperands() - StartIndex; + const unsigned Opcode = MI->getOpcode(); - // We support integer up to 1024 bits - assert((NumVarOps <= 32) && - "Unsupported number of bits for literal variable"); + // We support up to 1024 bits for integers, and 64 bits for floats + assert(((NumVarOps <= 32 && Opcode == SPIRV::OpConstantI) || + (NumVarOps <= 2 && Opcode == SPIRV::OpConstantF)) && + "Unsupported number of operands for constant"); O << ' '; - // Handle arbitrary number of 32-bit words for the literal value. - if (MI->getOpcode() == SPIRV::OpConstantI) { - APInt Val(NumVarOps * 32, 0); + // Handle arbitrary number of 32-bit words for integer literals + if (Opcode == SPIRV::OpConstantI) { + const unsigned TotalBits = NumVarOps * 32; + APInt Val(TotalBits, 0); for (unsigned i = 0; i < NumVarOps; ++i) { - Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) - << (i * 32)); + uint64_t Word = MI->getOperand(StartIndex + i).getImm(); + Val |= APInt(TotalBits, Word) << (i * 32); } O << Val; return; } + // Handle float constants (OpConstantF) uint64_t Imm = MI->getOperand(StartIndex).getImm(); + if (NumVarOps == 2) { + Imm |= static_cast(MI->getOperand(StartIndex + 1).getImm()) << 32; + } - // Format and print float values. - if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) { - APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat()) - : APFloat(APInt(64, Imm).bitsToDouble()); - - // Print infinity and NaN as hex floats. - // TODO: Make sure subnormal numbers are handled correctly as they may also - // require hex float notation. - if (FP.isInfinity()) { - if (FP.isNegative()) - O << '-'; - O << "0x1p+128"; - return; - } - if (FP.isNaN()) { - O << "0x1.8p+128"; - return; - } - - // Format val as a decimal floating point or scientific notation (whichever - // is shorter), with enough digits of precision to produce the exact value. - O << format("%.*g", std::numeric_limits::max_digits10, - FP.convertToDouble()); + // For 16-bit floats, print as integer + if (IsBitwidth16) { + O << Imm; + return; + } + // Format and print float values + const APFloat FP = (NumVarOps == 1) + ? APFloat(APInt(32, Imm).bitsToFloat()) + : APFloat(APInt(64, Imm).bitsToDouble()); + + // Print infinity and NaN as hex floats. + // TODO: Make sure subnormal numbers are handled correctly as they may also + // require hex float notation. + if (FP.isInfinity()) { + if (FP.isNegative()) + O << '-'; + O << "0x1p+128"; + return; + } + + if (FP.isNaN()) { + O << "0x1.8p+128"; return; } - // Print integer values directly. - O << Imm; + // Format val as a decimal floating point or scientific notation (whichever + // is shorter), with enough digits of precision to produce the exact value. + O << format("%.*g", std::numeric_limits::max_digits10, + FP.convertToDouble()); } void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) { From acfca88085f822c425889d39f4e676bc4a2160b0 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 7 Oct 2025 15:49:10 -0700 Subject: [PATCH 07/12] fix CI failure --- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 911e43e57edee..3e441d8c4a0cb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2274,7 +2274,6 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulS)) .addDef(Mul) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(Mul) .addUse(AElt) .addUse(BElt) .constrainAllUses(TII, TRI, RBI); @@ -2948,7 +2947,8 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, + Reg = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), + I.getOperand(1).getCImm()->getZExtValue()), I, ResType, TII, !STI.isShader()); } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I); From 42ba900770891b5ed6491e49cd79156b05d418cd Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 8 Oct 2025 13:49:56 -0700 Subject: [PATCH 08/12] code refactoring --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 9 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 12 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 6 +- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 121 ++++++++---------- 4 files changed, 65 insertions(+), 83 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 2e2b834c167f2..59206b443a312 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -47,7 +47,7 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI, void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, unsigned StartIndex, raw_ostream &O) { - const bool IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; + unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16; const unsigned NumVarOps = MI->getNumOperands() - StartIndex; const unsigned Opcode = MI->getOpcode(); @@ -73,7 +73,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, // Handle float constants (OpConstantF) uint64_t Imm = MI->getOperand(StartIndex).getImm(); if (NumVarOps == 2) { - Imm |= static_cast(MI->getOperand(StartIndex + 1).getImm()) << 32; + Imm |= (MI->getOperand(StartIndex + 1).getImm()) << 32; } // For 16-bit floats, print as integer @@ -83,9 +83,8 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, } // Format and print float values - const APFloat FP = (NumVarOps == 1) - ? APFloat(APInt(32, Imm).bitsToFloat()) - : APFloat(APInt(64, Imm).bitsToDouble()); + APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat()) + : APFloat(APInt(64, Imm).bitsToDouble()); // Print infinity and NaN as hex floats. // TODO: Make sure subnormal numbers are handled correctly as they may also diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index c4f8565f39b84..6401bfc7a979e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF, return Res; } -Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I, +Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull) { @@ -353,10 +353,10 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I, if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || MI->getOpcode() == SPIRV::OpConstantI)) return MI->getOperand(0).getReg(); - return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull); + return createConstInt(CI, I, SpvType, TII, ZeroAsNull); } -Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, APInt Val, +Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -374,15 +374,15 @@ Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, APInt Val, MachineInstrBuilder MIB; if (BitWidth == 1) { MIB = MIRBuilder - .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse + .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse : SPIRV::OpConstantTrue) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - } else if (!Val.isZero() || !ZeroAsNull) { + } else if (!CI->isZero() || !ZeroAsNull) { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - addNumImm(Val, MIB); + addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB); } else { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) .addDef(Res) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 9cb7d982c3fc2..a648defa0a888 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, bool ZeroAsNull = true); - Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType, - const SPIRVInstrInfo &TII, + Register getOrCreateConstInt(uint64_t Val, MachineInstr &I, + SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull = true); - Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I, + Register createConstInt(const ConstantInt *CI, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull); Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3e441d8c4a0cb..14bd922d6ac8b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2247,27 +2247,25 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( for (unsigned i = 0; i < 4; i++) { // A[i] Register AElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(AElt) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(X) - .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, - TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, - ZeroAsNull)) - .constrainAllUses(TII, TRI, RBI); + Result &= + BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(AElt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(X) + .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull)) + .constrainAllUses(TII, TRI, RBI); // B[i] Register BElt = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(BElt) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(Y) - .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, - TII, ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, - ZeroAsNull)) - .constrainAllUses(TII, TRI, RBI); + Result &= + BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(BElt) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Y) + .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull)) + .constrainAllUses(TII, TRI, RBI); // A[i] * B[i] Register Mul = MRI->createVirtualRegister(&SPIRV::IDRegClass); @@ -2280,15 +2278,14 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion( // Discard 24 highest-bits so that stored i32 register is i8 equivalent Register MaskMul = MRI->createVirtualRegister(&SPIRV::IDRegClass); - Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) - .addDef(MaskMul) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(Mul) - .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, - ZeroAsNull)) - .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, - ZeroAsNull)) - .constrainAllUses(TII, TRI, RBI); + Result &= + BuildMI(BB, I, I.getDebugLoc(), TII.get(ExtractOp)) + .addDef(MaskMul) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Mul) + .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull)) + .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull)) + .constrainAllUses(TII, TRI, RBI); // Acc = Acc + A[i] * B[i] Register Sum = @@ -2381,12 +2378,11 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg, MachineBasicBlock &BB = *I.getParent(); SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - auto BMI = - BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, - IntTy, TII, !STI.isShader())); + auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, + IntTy, TII, !STI.isShader())); for (unsigned J = 2; J < I.getNumOperands(); J++) { BMI.addUse(I.getOperand(J).getReg()); @@ -2405,16 +2401,15 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( SPIRV::OpGroupNonUniformBallot); MachineBasicBlock &BB = *I.getParent(); - Result &= - BuildMI(BB, I, I.getDebugLoc(), - TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, - IntTy, TII, !STI.isShader())) - .addImm(SPIRV::GroupOperation::Reduce) - .addUse(BallotReg) - .constrainAllUses(TII, TRI, RBI); + Result &= BuildMI(BB, I, I.getDebugLoc(), + TII.get(SPIRV::OpGroupNonUniformBallotBitCount)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, + TII, !STI.isShader())) + .addImm(SPIRV::GroupOperation::Reduce) + .addUse(BallotReg) + .constrainAllUses(TII, TRI, RBI); return Result; } @@ -2441,8 +2436,8 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg, return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, - IntTy, TII, !STI.isShader())) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, + !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(I.getOperand(2).getReg()) .constrainAllUses(TII, TRI, RBI); @@ -2468,8 +2463,8 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg, return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, - IntTy, TII, !STI.isShader())) + .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, + !STI.isShader())) .addImm(SPIRV::GroupOperation::Reduce) .addUse(I.getOperand(2).getReg()); } @@ -2694,8 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType, bool ZeroAsNull = !STI.isShader(); if (ResType->getOpcode() == SPIRV::OpTypeVector) return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull); - return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), - I, ResType, TII, ZeroAsNull); + return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull); } Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType, @@ -2726,9 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0); if (ResType->getOpcode() == SPIRV::OpTypeVector) return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII); - return GR.getOrCreateConstInt( - APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, - ResType, TII); + return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII); } bool SPIRVInstructionSelector::selectSelect(Register ResVReg, @@ -2947,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - Reg = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), - I.getOperand(1).getCImm()->getZExtValue()), I, + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I ResType, TII, !STI.isShader()); } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I); @@ -3774,8 +3765,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow( bool ZeroAsNull = !STI.isShader(); Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type)); Register ConstIntLastIdx = GR.getOrCreateConstInt( - APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, - BaseType, TII, ZeroAsNull); + ComponentCount - 1, I, BaseType, TII, ZeroAsNull); if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx}, SPIRV::OpVectorExtractDynamic)) @@ -3804,11 +3794,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64( SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType); bool ZeroAsNull = !STI.isShader(); Register ConstIntZero = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), - I, BaseType, TII, ZeroAsNull); + GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull); Register ConstIntOne = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), - I, BaseType, TII, ZeroAsNull); + GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull); // SPIRV doesn't support vectors with more than 4 components. Since the // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only @@ -3892,15 +3880,10 @@ bool SPIRVInstructionSelector::selectFirstBitSet64( unsigned AddOp; if (IsScalarRes) { - NegOneReg = GR.getOrCreateConstInt( - APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, - TII, ZeroAsNull); - Reg0 = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), - I, ResType, TII, ZeroAsNull); - Reg32 = - GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), - I, ResType, TII, ZeroAsNull); + NegOneReg = + GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull); + Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull); + Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull); SelectOp = SPIRV::OpSelectSISCond; AddOp = SPIRV::OpIAddS; } else { From 2c34546280ddc45a5d72104ae0f5790661bab1d5 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Fri, 10 Oct 2025 14:45:29 -0700 Subject: [PATCH 09/12] refactor the program --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 2 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 28 ++++++++++++++++--- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 5 +++- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 7 ++++- 4 files changed, 35 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 59206b443a312..41602b703f88e 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -73,7 +73,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, // Handle float constants (OpConstantF) uint64_t Imm = MI->getOperand(StartIndex).getImm(); if (NumVarOps == 2) { - Imm |= (MI->getOperand(StartIndex + 1).getImm()) << 32; + Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32); } // For 16-bit floats, print as integer diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 6401bfc7a979e..6d6ad3ef830f0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -343,6 +343,20 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF, return Res; } +Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I, + SPIRVType *SpvType, + const SPIRVInstrInfo &TII, + bool ZeroAsNull) { + const IntegerType *Ty = cast(getTypeForSPIRVType(SpvType)); + auto *const CI = ConstantInt::get(const_cast(Ty), Val); + const MachineInstr *MI = findMI(CI, CurMF); + if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || + MI->getOpcode() == SPIRV::OpConstantI)) + return MI->getOperand(0).getReg(); + return createConstInt(CI, I, SpvType, TII, ZeroAsNull); + LLVMContext &Ctx = CurMF->getFunction().getContext(); +} + Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -354,9 +368,10 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, MI->getOpcode() == SPIRV::OpConstantI)) return MI->getOperand(0).getReg(); return createConstInt(CI, I, SpvType, TII, ZeroAsNull); + LLVMContext &Ctx = CurMF->getFunction().getContext(); } -Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, +Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -374,15 +389,20 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI, MachineInstrBuilder MIB; if (BitWidth == 1) { MIB = MIRBuilder - .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse + .buildInstr(CI->isZeroValue() ? SPIRV::OpConstantFalse : SPIRV::OpConstantTrue) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - } else if (!CI->isZero() || !ZeroAsNull) { + } else if (!CI->isZeroValue() || !ZeroAsNull) { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB); + if (BitWidth <= 64) { + const ConstantInt *CII = dyn_cast(CI); + addNumImm(APInt(BitWidth, CII->getZExtValue()), MIB); + } else { + addNumImm(CI->getUniqueInteger(), MIB); + } } else { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) .addDef(Res) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index a648defa0a888..eab9abc18b840 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -515,10 +515,13 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, bool ZeroAsNull = true); + Register getOrCreateConstInt(APInt Val, MachineInstr &I, + SPIRVType *SpvType, const SPIRVInstrInfo &TII, + bool ZeroAsNull = true); Register getOrCreateConstInt(uint64_t Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull = true); - Register createConstInt(const ConstantInt *CI, MachineInstr &I, + Register createConstInt(const Constant *CI, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull); Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 14bd922d6ac8b..5761b82775c1f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2939,8 +2939,13 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I + if(GR.getScalarOrVectorBitWidth(ResType) <= 64) { + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I, ResType, TII, !STI.isShader()); + } else { + Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, + ResType, TII, !STI.isShader()); + } } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I); } From f2b798c6858b0b32c959457063e41a932d3f6ad4 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Fri, 10 Oct 2025 15:41:47 -0700 Subject: [PATCH 10/12] code clean up --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 2 +- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 26 +++++++++---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 41602b703f88e..59206b443a312 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -73,7 +73,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, // Handle float constants (OpConstantF) uint64_t Imm = MI->getOperand(StartIndex).getImm(); if (NumVarOps == 2) { - Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32); + Imm |= (MI->getOperand(StartIndex + 1).getImm()) << 32; } // For 16-bit floats, print as integer diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 6d6ad3ef830f0..fabd46dd192f7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -343,18 +343,17 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF, return Res; } -Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I, +Register SPIRVGlobalRegistry::getOrCreateConstInt(const APInt Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull) { const IntegerType *Ty = cast(getTypeForSPIRVType(SpvType)); - auto *const CI = ConstantInt::get(const_cast(Ty), Val); - const MachineInstr *MI = findMI(CI, CurMF); + Constant *const CA = ConstantInt::get(const_cast(Ty), Val); + const MachineInstr *MI = findMI(CA, CurMF); if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || MI->getOpcode() == SPIRV::OpConstantI)) return MI->getOperand(0).getReg(); - return createConstInt(CI, I, SpvType, TII, ZeroAsNull); - LLVMContext &Ctx = CurMF->getFunction().getContext(); + return createConstInt(CA, I, SpvType, TII, ZeroAsNull); } Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, @@ -362,16 +361,15 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, const SPIRVInstrInfo &TII, bool ZeroAsNull) { const IntegerType *Ty = cast(getTypeForSPIRVType(SpvType)); - auto *const CI = ConstantInt::get(const_cast(Ty), Val); + ConstantInt *const CI = ConstantInt::get(const_cast(Ty), Val); const MachineInstr *MI = findMI(CI, CurMF); if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || MI->getOpcode() == SPIRV::OpConstantI)) return MI->getOperand(0).getReg(); return createConstInt(CI, I, SpvType, TII, ZeroAsNull); - LLVMContext &Ctx = CurMF->getFunction().getContext(); } -Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, +Register SPIRVGlobalRegistry::createConstInt(const Constant *CA, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, @@ -389,19 +387,19 @@ Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, MachineInstrBuilder MIB; if (BitWidth == 1) { MIB = MIRBuilder - .buildInstr(CI->isZeroValue() ? SPIRV::OpConstantFalse + .buildInstr(CA->isZeroValue() ? SPIRV::OpConstantFalse : SPIRV::OpConstantTrue) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); - } else if (!CI->isZeroValue() || !ZeroAsNull) { + } else if (!CA->isZeroValue() || !ZeroAsNull) { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); if (BitWidth <= 64) { - const ConstantInt *CII = dyn_cast(CI); - addNumImm(APInt(BitWidth, CII->getZExtValue()), MIB); + const ConstantInt *CI = dyn_cast(CA); + addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB); } else { - addNumImm(CI->getUniqueInteger(), MIB); + addNumImm(CA->getUniqueInteger(), MIB); } } else { MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) @@ -414,7 +412,7 @@ Register SPIRVGlobalRegistry::createConstInt(const Constant *CI, *ST.getRegBankInfo()); return MIB; }); - add(CI, NewType); + add(CA, NewType); return Res; } From 9d2df70d4dc0eceb0a16e6dfdadcccf2d08b279a Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Sat, 11 Oct 2025 12:17:24 -0700 Subject: [PATCH 11/12] fix the CI failure --- .../SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 63 ++++++++++--------- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 5 +- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 59206b443a312..5d32a8b175cbe 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -58,8 +58,15 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, O << ' '; + uint64_t Imm = MI->getOperand(StartIndex).getImm(); + + // Handle 64 bit literals. + if (NumVarOps == 2) { + Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32); + } + // Handle arbitrary number of 32-bit words for integer literals - if (Opcode == SPIRV::OpConstantI) { + if (Opcode == SPIRV::OpConstantI && NumVarOps > 2) { const unsigned TotalBits = NumVarOps * 32; APInt Val(TotalBits, 0); for (unsigned i = 0; i < NumVarOps; ++i) { @@ -70,41 +77,35 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, return; } - // Handle float constants (OpConstantF) - uint64_t Imm = MI->getOperand(StartIndex).getImm(); - if (NumVarOps == 2) { - Imm |= (MI->getOperand(StartIndex + 1).getImm()) << 32; - } + // Format and print float values. + if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) { + APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat()) + : APFloat(APInt(64, Imm).bitsToDouble()); + + // Print infinity and NaN as hex floats. + // TODO: Make sure subnormal numbers are handled correctly as they may also + // require hex float notation. + if (FP.isInfinity()) { + if (FP.isNegative()) + O << '-'; + O << "0x1p+128"; + return; + } + if (FP.isNaN()) { + O << "0x1.8p+128"; + return; + } - // For 16-bit floats, print as integer - if (IsBitwidth16) { - O << Imm; - return; - } + // Format val as a decimal floating point or scientific notation (whichever + // is shorter), with enough digits of precision to produce the exact value. + O << format("%.*g", std::numeric_limits::max_digits10, + FP.convertToDouble()); - // Format and print float values - APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat()) - : APFloat(APInt(64, Imm).bitsToDouble()); - - // Print infinity and NaN as hex floats. - // TODO: Make sure subnormal numbers are handled correctly as they may also - // require hex float notation. - if (FP.isInfinity()) { - if (FP.isNegative()) - O << '-'; - O << "0x1p+128"; - return; - } - - if (FP.isNaN()) { - O << "0x1.8p+128"; return; } - // Format val as a decimal floating point or scientific notation (whichever - // is shorter), with enough digits of precision to produce the exact value. - O << format("%.*g", std::numeric_limits::max_digits10, - FP.convertToDouble()); + // Print integer values directly. + O << Imm; } void SPIRVInstPrinter::recordOpExtInstImport(const MCInst *MI) { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index fabd46dd192f7..06608cfc07f83 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -509,9 +509,8 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister( } assert(Type->getOpcode() == SPIRV::OpTypeInt); SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); - return getOrCreateConstInt( - APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I, SpvBaseType, - TII, ZeroAsNull); + return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I, + SpvBaseType, TII, ZeroAsNull); } Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull( From 9af7da630f4153e16dae5e766128996a5a476375 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Sun, 12 Oct 2025 07:58:22 -0700 Subject: [PATCH 12/12] code clean up --- llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp | 4 ++-- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 7 ++++--- llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 ++-- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 6 +++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index 5d32a8b175cbe..b1c8a68df5d69 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -52,8 +52,8 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI, const unsigned Opcode = MI->getOpcode(); // We support up to 1024 bits for integers, and 64 bits for floats - assert(((NumVarOps <= 32 && Opcode == SPIRV::OpConstantI) || - (NumVarOps <= 2 && Opcode == SPIRV::OpConstantF)) && + assert(((NumVarOps <= 32 && Opcode == SPIRV::OpConstantI) || + (NumVarOps <= 2 && Opcode == SPIRV::OpConstantF)) && "Unsupported number of operands for constant"); O << ' '; diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 06608cfc07f83..f01d8a34c2ca4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -343,7 +343,8 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF, return Res; } -Register SPIRVGlobalRegistry::getOrCreateConstInt(const APInt Val, MachineInstr &I, +Register SPIRVGlobalRegistry::getOrCreateConstInt(const APInt Val, + MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, bool ZeroAsNull) { @@ -361,7 +362,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, const SPIRVInstrInfo &TII, bool ZeroAsNull) { const IntegerType *Ty = cast(getTypeForSPIRVType(SpvType)); - ConstantInt *const CI = ConstantInt::get(const_cast(Ty), Val); + auto *const CI = ConstantInt::get(const_cast(Ty), Val); const MachineInstr *MI = findMI(CI, CurMF); if (MI && (MI->getOpcode() == SPIRV::OpConstantNull || MI->getOpcode() == SPIRV::OpConstantI)) @@ -388,7 +389,7 @@ Register SPIRVGlobalRegistry::createConstInt(const Constant *CA, if (BitWidth == 1) { MIB = MIRBuilder .buildInstr(CA->isZeroValue() ? SPIRV::OpConstantFalse - : SPIRV::OpConstantTrue) + : SPIRV::OpConstantTrue) .addDef(Res) .addUse(getSPIRVTypeID(SpvType)); } else if (!CA->isZeroValue() || !ZeroAsNull) { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index eab9abc18b840..60085cd0f337b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -515,8 +515,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, bool ZeroAsNull = true); - Register getOrCreateConstInt(APInt Val, MachineInstr &I, - SPIRVType *SpvType, const SPIRVInstrInfo &TII, + Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType, + const SPIRVInstrInfo &TII, bool ZeroAsNull = true); Register getOrCreateConstInt(uint64_t Val, MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 5761b82775c1f..3c3374e5066b4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2939,12 +2939,12 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg, Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I, ResType, TII, !STI.isShader()); } else { - if(GR.getScalarOrVectorBitWidth(ResType) <= 64) { + if (GR.getScalarOrVectorBitWidth(ResType) <= 64) { Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I, - ResType, TII, !STI.isShader()); + ResType, TII, !STI.isShader()); } else { Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, - ResType, TII, !STI.isShader()); + ResType, TII, !STI.isShader()); } } return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);