-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend #161270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
YixingZhang007
wants to merge
12
commits into
llvm:main
Choose a base branch
from
YixingZhang007:add_support_arbitary_precision
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
[SPIRV] Add support for arbitrary-precision integers larger than 64 bits in SPIR-V backend #161270
YixingZhang007
wants to merge
12
commits into
llvm:main
from
YixingZhang007:add_support_arbitary_precision
+82
−17
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-backend-spir-v Author: None (YixingZhang007) Changesspirv-backend Full diff: https://github.com/llvm/llvm-project/pull/161270.diff 6 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 776208bd3e693..dff9f699ebd6f 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -50,18 +50,24 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
- assert((NumVarOps == 1 || NumVarOps == 2) &&
+ // we support integer up to 1024 bits
+ assert((NumVarOps <= 1024) &&
"Unsupported number of bits for literal variable");
O << ' ';
- uint64_t Imm = MI->getOperand(StartIndex).getImm();
-
- // Handle 64 bit literals.
- if (NumVarOps == 2) {
- Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
+ // Handle arbitrary number of 32-bit words for the literal value.
+ if (MI->getOpcode() == SPIRV::OpConstantI){
+ APInt Val(NumVarOps * 32, 0);
+ for (unsigned i = 0; i < NumVarOps; ++i) {
+ Val |= (APInt(NumVarOps * 32, MI->getOperand(StartIndex + i).getImm()) << (i * 32));
+ }
+ O << Val;
+ return;
}
+ uint64_t Imm = MI->getOperand(StartIndex).getImm();
+
// Format and print float values.
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 115766ce886c7..05b3371e97cdc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
}
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
- if (Width > 64)
+ if (Width > 1024)
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
@@ -343,7 +343,7 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
return Res;
}
-Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+Register SPIRVGlobalRegistry::getOrCreateConstInt(APInt Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
@@ -353,10 +353,11 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
MI->getOpcode() == SPIRV::OpConstantI))
return MI->getOperand(0).getReg();
- return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
+ return createConstInt(CI, Val, I, SpvType, TII, ZeroAsNull);
}
-Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
+Register SPIRVGlobalRegistry::createConstInt(const Constant *CI,
+ APInt Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
@@ -374,15 +375,15 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
MachineInstrBuilder MIB;
if (BitWidth == 1) {
MIB = MIRBuilder
- .buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
+ .buildInstr(Val.isZero() ? SPIRV::OpConstantFalse
: SPIRV::OpConstantTrue)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- } else if (!CI->isZero() || !ZeroAsNull) {
+ } else if (!Val.isZero() || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
+ addNumImm(Val, MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
@@ -491,7 +492,7 @@ Register SPIRVGlobalRegistry::getOrCreateBaseRegister(
}
assert(Type->getOpcode() == SPIRV::OpTypeInt);
SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
- return getOrCreateConstInt(Val->getUniqueInteger().getZExtValue(), I,
+ return getOrCreateConstInt(APInt(BitWidth, Val->getUniqueInteger().getZExtValue()), I,
SpvBaseType, TII, ZeroAsNull);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..ee217f81fb416 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -515,10 +515,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
bool ZeroAsNull = true);
- Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+ Register getOrCreateConstInt(APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
- Register createConstInt(const ConstantInt *CI, MachineInstr &I,
+ Register createConstInt(const Constant *CI, APInt Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull);
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 1aadd9df189a8..3e5566945ec0b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2252,8 +2252,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(AElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(X)
- .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// B[i]
@@ -2263,8 +2263,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(BElt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Y)
- .addUse(GR.getOrCreateConstInt(i * 8, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, i * 8), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// A[i] * B[i]
@@ -2283,8 +2283,8 @@ bool SPIRVInstructionSelector::selectDot4AddPackedExpansion(
.addDef(MaskMul)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Mul)
- .addUse(GR.getOrCreateConstInt(0, I, EltType, TII, ZeroAsNull))
- .addUse(GR.getOrCreateConstInt(8, I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 0), I, EltType, TII, ZeroAsNull))
+ .addUse(GR.getOrCreateConstInt(APInt(8, 8), I, EltType, TII, ZeroAsNull))
.constrainAllUses(TII, TRI, RBI);
// Acc = Acc + A[i] * B[i]
@@ -2381,7 +2381,7 @@ bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I,
IntTy, TII, !STI.isShader()));
for (unsigned J = 2; J < I.getNumOperands(); J++) {
@@ -2405,7 +2405,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy,
TII, !STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(BallotReg)
@@ -2436,7 +2436,7 @@ bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg())
@@ -2463,7 +2463,7 @@ bool SPIRVInstructionSelector::selectWaveReduceSum(Register ResVReg,
return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ .addUse(GR.getOrCreateConstInt(APInt(32, SPIRV::Scope::Subgroup), I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
.addUse(I.getOperand(2).getReg());
@@ -2689,7 +2689,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
bool ZeroAsNull = !STI.isShader();
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
- return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
}
Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
@@ -2720,7 +2720,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
if (ResType->getOpcode() == SPIRV::OpTypeVector)
return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
- return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
+ return GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), One.getZExtValue()), I, ResType, TII);
}
bool SPIRVInstructionSelector::selectSelect(Register ResVReg,
@@ -2939,8 +2939,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
ResType, TII, !STI.isShader());
} else {
- Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
- ResType, TII, !STI.isShader());
+ Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I, ResType, TII, !STI.isShader());
}
return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
}
@@ -3765,7 +3764,7 @@ bool SPIRVInstructionSelector::selectFirstBitSet64Overflow(
bool ZeroAsNull = !STI.isShader();
Register FinalElemReg = MRI->createVirtualRegister(GR.getRegClass(I64Type));
Register ConstIntLastIdx = GR.getOrCreateConstInt(
- ComponentCount - 1, I, BaseType, TII, ZeroAsNull);
+ APInt(GR.getScalarOrVectorBitWidth(BaseType), ComponentCount - 1), I, BaseType, TII, ZeroAsNull);
if (!selectOpWithSrcs(FinalElemReg, I64Type, I, {SrcReg, ConstIntLastIdx},
SPIRV::OpVectorExtractDynamic))
@@ -3794,9 +3793,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
SPIRVType *BaseType = GR.retrieveScalarOrVectorIntType(ResType);
bool ZeroAsNull = !STI.isShader();
Register ConstIntZero =
- GR.getOrCreateConstInt(0, I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 0), I, BaseType, TII, ZeroAsNull);
Register ConstIntOne =
- GR.getOrCreateConstInt(1, I, BaseType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(BaseType), 1), I, BaseType, TII, ZeroAsNull);
// SPIRV doesn't support vectors with more than 4 components. Since the
// algoritm below converts i64 -> i32x2 and i64x4 -> i32x8 it can only
@@ -3881,9 +3880,9 @@ bool SPIRVInstructionSelector::selectFirstBitSet64(
if (IsScalarRes) {
NegOneReg =
- GR.getOrCreateConstInt((unsigned)-1, I, ResType, TII, ZeroAsNull);
- Reg0 = GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
- Reg32 = GR.getOrCreateConstInt(32, I, ResType, TII, ZeroAsNull);
+ GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), (unsigned)-1), I, ResType, TII, ZeroAsNull);
+ Reg0 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 0), I, ResType, TII, ZeroAsNull);
+ Reg32 = GR.getOrCreateConstInt(APInt(GR.getScalarOrVectorBitWidth(ResType), 32), I, ResType, TII, ZeroAsNull);
SelectOp = SPIRV::OpSelectSISCond;
AddOp = SPIRV::OpIAddS;
} else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 820e56b362edc..e409234a83568 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
if (Bitwidth == 16)
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
return;
- } else if (Bitwidth <= 64) {
- uint64_t FullImm = Imm.getZExtValue();
- uint32_t LowBits = FullImm & 0xffffffff;
- uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
- MIB.addImm(LowBits).addImm(HighBits);
+ } else if (Bitwidth <= 1024) {
+ unsigned NumWords = (Bitwidth + 31) / 32;
+ for (unsigned i = 0; i < NumWords; ++i) {
+ uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
+ MIB.addImm(Word);
+ }
return;
}
report_fatal_error("Unsupported constant bitwidth");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
index 41d4b58ed1157..17ba9b044842c 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_arbitrary_precision_integers.ll
@@ -8,6 +8,10 @@ define i13 @getConstantI13() {
ret i13 42
}
+define i96 @getConstantI96() {
+ ret i96 18446744073709551620
+}
+
;; Capabilities:
; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR extends SPIR-V code generation in LLVM to support arbitrary precision integer up to 1024 bits, enabled by the
SPV_INTEL_arbitrary_precision_integers
extension. More specifically, the following changes are made.getOrCreateConstInt
andcreateConstInt
functions in theSPIRVGlobalRegistry
pass now accept anAPInt Val
parameter instead ofuint64_t Val
. All relevant call sites inSPIRVInstructionSelector.cpp
are updated corresponding.