Skip to content
19 changes: 17 additions & 2 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
raw_ostream &O) {
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::INST_PRINTER_WIDTH16;
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
const unsigned Opcode = MI->getOpcode();

assert((NumVarOps == 1 || NumVarOps == 2) &&
"Unsupported number of bits for literal variable");
// We support up to 1024 bits for integers, and 64 bits for floats
assert(((NumVarOps <= 32 && Opcode == SPIRV::OpConstantI) ||
(NumVarOps <= 2 && Opcode == SPIRV::OpConstantF)) &&
"Unsupported number of operands for constant");

O << ' ';

Expand All @@ -62,6 +65,18 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
Imm |= (MI->getOperand(StartIndex + 1).getImm() << 32);
}

// Handle arbitrary number of 32-bit words for integer literals
if (Opcode == SPIRV::OpConstantI && NumVarOps > 2) {
const unsigned TotalBits = NumVarOps * 32;
APInt Val(TotalBits, 0);
for (unsigned i = 0; i < NumVarOps; ++i) {
uint64_t Word = MI->getOperand(StartIndex + i).getImm();
Val |= APInt(TotalBits, Word) << (i * 32);
}
O << Val;
return;
}

// Format and print float values.
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
Expand Down
33 changes: 26 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
}

unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
if (Width > 64)
if (Width > 1024)
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
Expand Down Expand Up @@ -343,6 +343,20 @@ Register SPIRVGlobalRegistry::createConstFP(const ConstantFP *CF,
return Res;
}

Register SPIRVGlobalRegistry::getOrCreateConstInt(const APInt Val,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull) {
const IntegerType *Ty = cast<IntegerType>(getTypeForSPIRVType(SpvType));
Constant *const CA = ConstantInt::get(const_cast<IntegerType *>(Ty), Val);
const MachineInstr *MI = findMI(CA, CurMF);
if (MI && (MI->getOpcode() == SPIRV::OpConstantNull ||
MI->getOpcode() == SPIRV::OpConstantI))
return MI->getOperand(0).getReg();
return createConstInt(CA, I, SpvType, TII, ZeroAsNull);
}

Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
Expand All @@ -356,7 +370,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
return createConstInt(CI, I, SpvType, TII, ZeroAsNull);
}

Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
Register SPIRVGlobalRegistry::createConstInt(const Constant *CA,
MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
Expand All @@ -374,15 +388,20 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
MachineInstrBuilder MIB;
if (BitWidth == 1) {
MIB = MIRBuilder
.buildInstr(CI->isZero() ? SPIRV::OpConstantFalse
: SPIRV::OpConstantTrue)
.buildInstr(CA->isZeroValue() ? SPIRV::OpConstantFalse
: SPIRV::OpConstantTrue)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
} else if (!CI->isZero() || !ZeroAsNull) {
} else if (!CA->isZeroValue() || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
if (BitWidth <= 64) {
const ConstantInt *CI = dyn_cast<ConstantInt>(CA);
addNumImm(APInt(BitWidth, CI->getZExtValue()), MIB);
} else {
addNumImm(CA->getUniqueInteger(), MIB);
}
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
Expand All @@ -394,7 +413,7 @@ Register SPIRVGlobalRegistry::createConstInt(const ConstantInt *CI,
*ST.getRegBankInfo());
return MIB;
});
add(CI, NewType);
add(CA, NewType);
return Res;
}

Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,13 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
bool ZeroAsNull = true);
Register getOrCreateConstInt(APInt Val, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register createConstInt(const ConstantInt *CI, MachineInstr &I,
Register createConstInt(const Constant *CI, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull);
Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
Expand Down
9 changes: 7 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2939,8 +2939,13 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
Reg = GR.getOrCreateConstFP(I.getOperand(1).getFPImm()->getValue(), I,
ResType, TII, !STI.isShader());
} else {
Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
ResType, TII, !STI.isShader());
if (GR.getScalarOrVectorBitWidth(ResType) <= 64) {
Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getZExtValue(), I,
ResType, TII, !STI.isShader());
} else {
Reg = GR.getOrCreateConstInt(I.getOperand(1).getCImm()->getValue(), I,
ResType, TII, !STI.isShader());
}
}
return Reg == ResVReg ? true : BuildCOPY(ResVReg, Reg, I);
}
Expand Down
11 changes: 6 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
if (Bitwidth == 16)
MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16);
return;
} else if (Bitwidth <= 64) {
uint64_t FullImm = Imm.getZExtValue();
uint32_t LowBits = FullImm & 0xffffffff;
uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
MIB.addImm(LowBits).addImm(HighBits);
} else if (Bitwidth <= 1024) {
unsigned NumWords = (Bitwidth + 31) / 32;
for (unsigned i = 0; i < NumWords; ++i) {
uint32_t Word = Imm.extractBits(32, i * 32).getZExtValue();
MIB.addImm(Word);
}
return;
}
report_fatal_error("Unsupported constant bitwidth");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ define i13 @getConstantI13() {
ret i13 42
}

define i96 @getConstantI96() {
ret i96 18446744073709551620
}

define i160 @getConstantI160() {
ret i160 3363637389930338837376336738763689377839373638
}

;; Capabilities:
; CHECK-DAG: OpExtension "SPV_INTEL_arbitrary_precision_integers"
; CHECK-DAG: OpCapability ArbitraryPrecisionIntegersINTEL
Expand All @@ -17,14 +25,20 @@ define i13 @getConstantI13() {
;; Names:
; CHECK-DAG: OpName %[[#GET_I6:]] "getConstantI6"
; CHECK-DAG: OpName %[[#GET_I13:]] "getConstantI13"
; CHECK-DAG: OpName %[[#GET_I96:]] "getConstantI96"
; CHECK-DAG: OpName %[[#GET_I160:]] "getConstantI160"

; CHECK-NOT: DAG-FENCE

;; Types and Constants:
; CHECK-DAG: %[[#I6:]] = OpTypeInt 6 0
; CHECK-DAG: %[[#I13:]] = OpTypeInt 13 0
; CHECK-DAG: %[[#I96:]] = OpTypeInt 96 0
; CHECK-DAG: %[[#I160:]] = OpTypeInt 160 0
; CHECK-DAG: %[[#CST_I6:]] = OpConstant %[[#I6]] 2
; CHECK-DAG: %[[#CST_I13:]] = OpConstant %[[#I13]] 42
; CHECK-DAG: %[[#CST_I96:]] = OpConstant %[[#I96]] 18446744073709551620
; CHECK-DAG: %[[#CST_I160:]] = OpConstant %[[#I160]] 3363637389930338837376336738763689377839373638

; CHECK: %[[#GET_I6]] = OpFunction %[[#I6]]
; CHECK: OpReturnValue %[[#CST_I6]]
Expand All @@ -33,3 +47,11 @@ define i13 @getConstantI13() {
; CHECK: %[[#GET_I13]] = OpFunction %[[#I13]]
; CHECK: OpReturnValue %[[#CST_I13]]
; CHECK: OpFunctionEnd

; CHECK: %[[#GET_I96]] = OpFunction %[[#I96]]
; CHECK: OpReturnValue %[[#CST_I96]]
; CHECK: OpFunctionEnd

; CHECK: %[[#GET_I160]] = OpFunction %[[#I160]]
; CHECK: OpReturnValue %[[#CST_I160]]
; CHECK: OpFunctionEnd