Skip to content

Conversation

YixingZhang007
Copy link
Contributor

@YixingZhang007 YixingZhang007 commented Sep 29, 2025

This PR extends SPIR-V code generation in LLVM to support arbitrary precision integer up to 1024 bits, enabled by the SPV_INTEL_arbitrary_precision_integers extension. More specifically, the following changes are made.

  1. Both getOrCreateConstInt and createConstInt functions in the SPIRVGlobalRegistry pass now accept an APInt Val parameter instead of uint64_t Val. All relevant call sites in SPIRVInstructionSelector.cpp are updated corresponding.
  2. When emitting SPIR-V instructions for integer constants, the value is broken down into 32-bit words, and a loop is used to serialize each chunk, ensuring correct handling for integers up to 1024 bits.

@YixingZhang007 YixingZhang007 marked this pull request as draft September 29, 2025 20:34
@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2025

@llvm/pr-subscribers-backend-spir-v

Author: None (YixingZhang007)

Changes

spirv-backend


Full diff: https://github.com/llvm/llvm-project/pull/161270.diff

6 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp (+12-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+9-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+2-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+19-20)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+6-5)
  • (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll (+4)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 776208bd3e693..dff9f699ebd6f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
   unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
   const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
 
-  assert((NumVarOps == 1 || NumVarOps == 2) &&
+  // we support integer up to 1024 bits
+  assert((NumVarOps <= 1024) &&
          "Unsupported number of bits for literal variable");
 
   O << ' ';
 
-  uint64_t Imm = MI->getOperand(StartIndex).getImm();
-
-  // Handle 64 bit literals.
-  if (NumVarOps == 2) {
-    Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
+  // Handle arbitrary number of 32-bit words for the literal value.
+  if (MI->getOpcode() == SPIRV::OpConstantI){
+    APInt Val(NumVarOps * 32, 0);
+    for (unsigned i = 0; i < NumVarOps; ++i) {
+      Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+    }
+    O << Val;
+    return;
   }
 
+  uint64_t Imm = MI->getOperand(StartIndex).getImm();
+
   // Format and print float values.
   if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
     APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 115766ce886c7..05b3371e97cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
 }
 
 unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
-  if (Width > 64)
+  if (Width > 1024)
     report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
   return Res;
 }
 
-Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
                                                   SPIRVType *SpvType,
                                                   const SPIRVInstrInfo &TII,
                                                   bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
   if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
              MI->getOpcode() == SPIRV::OpConstantI))
     return MI->getOperand(0).getReg();
-  return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
+  return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
 }
 
-Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
+                                             APInt Val,
                                              MachineInstr &I,
                                              SPIRVType *SpvType,
                                              const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
         MachineInstrBuilder MIB;
         if (BitWidth == 1) {
           MIB = MIRBuilder
-                    .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
+                    .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
                                              : SPIRV::OpConstantTrue)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-        } else if (!CI->isZero() || !ZeroAsNull) {
+        } else if (!Val.isZero() || !ZeroAsNull) {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
                     .addDef(Res)
                     .addUse(getSPIRVTypeID(SpvType));
-          addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+          addNumImm(Val, MIB);
         } else {
           MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
                     .addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
   }
   assert(Type->getOpcode() == SPIRV::OpTypeInt);
   SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
-  return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
+  return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
                              SpvBaseType, TII, ZeroAsNull);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..ee217f81fb416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType, bool EmitIR,
                             bool ZeroAsNull = true);
-  Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+  Register getOrCreateConstInt(APInt Val, MachineInstr &I,
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
-  Register createConstInt(const ConstantInt *CI, MachineInstr &I,
+  Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
                           SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                           bool ZeroAsNull);
   Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9df189a8..3e5566945ec0b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(AElt)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(X)
-            .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(BElt)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(Y)
-            .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
             .addDef(MaskMul)
             .addUse(GR.getSPIRVTypeID(ResType))
             .addUse(Mul)
-            .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
-            .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
+            .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
             .constrainAllUses(TII, TRI, RBI);
 
     // Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
   auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
                  .addDef(ResVReg)
                  .addUse(GR.getSPIRVTypeID(ResType))
-                 .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+                 .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
                                                 IntTy, TII, !STI.isShader()));
 
   for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
                     TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
                 .addDef(ResVReg)
                 .addUse(GR.getSPIRVTypeID(ResType))
-                .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+                .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
                                                TII, !STI.isShader()))
                 .addImm(SPIRV::GroupOperation::Reduce)
                 .addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
                                      !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
-      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+      .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
                                      !STI.isShader()))
       .addImm(SPIRV::GroupOperation::Reduce)
       .addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   bool ZeroAsNull = !STI.isShader();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
-  return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
       AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
     return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
-  return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
+  return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
 }
 
 bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
     Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
                                 ResType, TII, !STI.isShader());
   } else {
-    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
-                                 ResType, TII, !STI.isShader());
+    Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
   }
   return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
 }
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
     bool ZeroAsNull = !STI.isShader();
     Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
     Register ConstIntLastIdx = GR.getOrCreateConstInt(
-        ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
+        APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
 
     if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
                           SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
   SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
   bool ZeroAsNull = !STI.isShader();
   Register ConstIntZero =
-      GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
   Register ConstIntOne =
-      GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+      GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
 
   // SPIRV doesn't support vectors with more than 4 components. Since the
   // algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
 
   if (IsScalarRes) {
     NegOneReg =
-        GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
-    Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
-    Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+        GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
+    Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+    Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
     SelectOp = SPIRV::OpSelectSISCond;
     AddOp = SPIRV::OpIAddS;
   } else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b362edc..e409234a83568 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
     if (Bitwidth == 16)
       MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
     return;
-  } else if (Bitwidth <= 64) {
-    uint64_t FullImm = Imm.getZExtValue();
-    uint32_t LowBits = FullImm & 0xffffffff;
-    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
-    MIB.addImm(LowBits).addImm(HighBits);
+  } else if (Bitwidth <= 1024) {
+    unsigned NumWords = (Bitwidth + 31) / 32;
+    for (unsigned i = 0; i < NumWords; ++i) {
+      uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
+      MIB.addImm(Word);
+    }
     return;
   }
   report_fatal_error("Unsupported constant bitwidth");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 41d4b58ed1157..17ba9b044842c 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
   ret i13 42
 }
 
+define i96 @getConstantI96() {
+  ret i96 18446744073709551620
+} 
+
 ;; Capabilities:
 ; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
 ; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL

Copy link

github-actions bot commented Sep 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@YixingZhang007 YixingZhang007 changed the title Add support for arbitrary integer with bitwidth larger than 64 bits in [SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend Sep 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants