Skip to content

Commit 1d02de2

Browse files
authored
[SPIRV] Implement translation for llvm.modf.* intrinsics (#147556)
Based on KhronosGroup/SPIRV-LLVM-Translator#3100, I'm adding translation for `llvm.modf.*` intrinsics.
1 parent 77f0a7d commit 1d02de2

File tree

4 files changed

+173
-0
lines changed

4 files changed

+173
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
296296
bool selectImageWriteIntrinsic(MachineInstr &I) const;
297297
bool selectResourceGetPointer(Register &ResVReg, const SPIRVType *ResType,
298298
MachineInstr &I) const;
299+
bool selectModf(Register ResVReg, const SPIRVType *ResType,
300+
MachineInstr &I) const;
299301

300302
// Utilities
301303
std::pair<Register, bool>
@@ -3235,6 +3237,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
32353237
case Intrinsic::spv_discard: {
32363238
return selectDiscard(ResVReg, ResType, I);
32373239
}
3240+
case Intrinsic::modf: {
3241+
return selectModf(ResVReg, ResType, I);
3242+
}
32383243
default: {
32393244
std::string DiagMsg;
32403245
raw_string_ostream OS(DiagMsg);
@@ -4018,6 +4023,83 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
40184023
.constrainAllUses(TII, TRI, RBI);
40194024
}
40204025

4026+
bool SPIRVInstructionSelector::selectModf(Register ResVReg,
4027+
const SPIRVType *ResType,
4028+
MachineInstr &I) const {
4029+
// llvm.modf has a single arg --the number to be decomposed-- and returns a
4030+
// struct { restype, restype }, while OpenCLLIB::modf has two args --the
4031+
// number to be decomposed and a pointer--, returns the fractional part and
4032+
// the integral part is stored in the pointer argument. Therefore, we can't
4033+
// use directly the OpenCLLIB::modf intrinsic. However, we can do some
4034+
// scaffolding to make it work. The idea is to create an alloca instruction
4035+
// to get a ptr, pass this ptr to OpenCL::modf, and then load the value
4036+
// from this ptr to place it in the struct. llvm.modf returns the fractional
4037+
// part as the first element of the result, and the integral part as the
4038+
// second element of the result.
4039+
4040+
// At this point, the return type is not a struct anymore, but rather two
4041+
// independent elements of SPIRVResType. We can get each independent element
4042+
// from I.getDefs() or I.getOperands().
4043+
if (STI.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
4044+
MachineIRBuilder MIRBuilder(I);
4045+
// Get pointer type for alloca variable.
4046+
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
4047+
ResType, MIRBuilder, SPIRV::StorageClass::Function);
4048+
// Create new register for the pointer type of alloca variable.
4049+
Register PtrTyReg =
4050+
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
4051+
MIRBuilder.getMRI()->setType(
4052+
PtrTyReg,
4053+
LLT::pointer(storageClassToAddressSpace(SPIRV::StorageClass::Function),
4054+
GR.getPointerSize()));
4055+
// Assign SPIR-V type of the pointer type of the alloca variable to the
4056+
// new register.
4057+
GR.assignSPIRVTypeToVReg(PtrType, PtrTyReg, MIRBuilder.getMF());
4058+
MachineBasicBlock &EntryBB = I.getMF()->front();
4059+
MachineBasicBlock::iterator VarPos =
4060+
getFirstValidInstructionInsertPoint(EntryBB);
4061+
auto AllocaMIB =
4062+
BuildMI(EntryBB, VarPos, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
4063+
.addDef(PtrTyReg)
4064+
.addUse(GR.getSPIRVTypeID(PtrType))
4065+
.addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function));
4066+
Register Variable = AllocaMIB->getOperand(0).getReg();
4067+
// Modf must have 4 operands, the first two are the 2 parts of the result,
4068+
// the third is the operand, and the last one is the floating point value.
4069+
assert(I.getNumOperands() == 4 &&
4070+
"Expected 4 operands for modf instruction");
4071+
MachineBasicBlock &BB = *I.getParent();
4072+
// Create the OpenCLLIB::modf instruction.
4073+
auto MIB =
4074+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
4075+
.addDef(ResVReg)
4076+
.addUse(GR.getSPIRVTypeID(ResType))
4077+
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
4078+
.addImm(CL::modf)
4079+
.setMIFlags(I.getFlags())
4080+
.add(I.getOperand(3)) // Floating point value.
4081+
.addUse(Variable); // Pointer to integral part.
4082+
// Assign the integral part stored in the ptr to the second element of the
4083+
// result.
4084+
Register IntegralPartReg = I.getOperand(1).getReg();
4085+
if (IntegralPartReg.isValid()) {
4086+
// Load the value from the pointer to integral part.
4087+
auto LoadMIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
4088+
.addDef(IntegralPartReg)
4089+
.addUse(GR.getSPIRVTypeID(ResType))
4090+
.addUse(Variable);
4091+
return LoadMIB.constrainAllUses(TII, TRI, RBI);
4092+
}
4093+
4094+
return MIB.constrainAllUses(TII, TRI, RBI);
4095+
} else if (STI.canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450)) {
4096+
assert(false && "GLSL::Modf is deprecated.");
4097+
// FIXME: GL::Modf is deprecated, use Modfstruct instead.
4098+
return false;
4099+
}
4100+
return false;
4101+
}
4102+
40214103
// Generate the instructions to load 3-element vector builtin input
40224104
// IDs/Indices.
40234105
// Like: GlobalInvocationId, LocalInvocationId, etc....

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,4 +995,27 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI,
995995
return foldImm(ResType->getOperand(2), MRI);
996996
}
997997

998+
MachineBasicBlock::iterator
999+
getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) {
1000+
// Find the position to insert the OpVariable instruction.
1001+
// We will insert it after the last OpFunctionParameter, if any, or
1002+
// after OpFunction otherwise.
1003+
MachineBasicBlock::iterator VarPos = BB.begin();
1004+
while (VarPos != BB.end() && VarPos->getOpcode() != SPIRV::OpFunction) {
1005+
++VarPos;
1006+
}
1007+
// Advance VarPos to the next instruction after OpFunction, it will either
1008+
// be an OpFunctionParameter, so that we can start the next loop, or the
1009+
// position to insert the OpVariable instruction.
1010+
++VarPos;
1011+
while (VarPos != BB.end() &&
1012+
VarPos->getOpcode() == SPIRV::OpFunctionParameter) {
1013+
++VarPos;
1014+
}
1015+
// VarPos is now pointing at after the last OpFunctionParameter, if any,
1016+
// or after OpFunction, if no parameters.
1017+
return VarPos != BB.end() && VarPos->getOpcode() == SPIRV::OpLabel ? ++VarPos
1018+
: VarPos;
1019+
}
1020+
9981021
} // namespace llvm

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,8 @@ MachineInstr *getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI);
506506
int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI);
507507
unsigned getArrayComponentCount(const MachineRegisterInfo *MRI,
508508
const MachineInstr *ResType);
509+
MachineBasicBlock::iterator
510+
getFirstValidInstructionInsertPoint(MachineBasicBlock &BB);
509511

510512
} // namespace llvm
511513
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

llvm/test/CodeGen/SPIRV/llvm-intrinsics/fp-intrinsics.ll

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
23

34
; CHECK: %[[#extinst_id:]] = OpExtInstImport "OpenCL.std"
45

@@ -337,3 +338,68 @@ entry:
337338
}
338339

339340
declare float @llvm.fma.f32(float, float, float)
341+
342+
; CHECK: OpFunction
343+
; CHECK: %[[#d:]] = OpFunctionParameter %[[#]]
344+
; CHECK: %[[#fracPtr:]] = OpFunctionParameter %[[#]]
345+
; CHECK: %[[#integralPtr:]] = OpFunctionParameter %[[#]]
346+
; CHECK: %[[#varPtr:]] = OpVariable %[[#]] Function
347+
; CHECK: %[[#frac:]] = OpExtInst %[[#var2]] %[[#extinst_id]] modf %[[#d]] %[[#varPtr]]
348+
; CHECK: %[[#integral:]] = OpLoad %[[#var2]] %[[#varPtr]]
349+
; CHECK: OpStore %[[#fracPtr]] %[[#frac]]
350+
; CHECK: OpStore %[[#integralPtr]] %[[#integral]]
351+
; CHECK: OpFunctionEnd
352+
define void @TestModf(double %d, ptr addrspace(1) %frac, ptr addrspace(1) %integral) {
353+
entry:
354+
%4 = tail call { double, double } @llvm.modf.f64(double %d)
355+
%5 = extractvalue { double, double } %4, 0
356+
%6 = extractvalue { double, double } %4, 1
357+
store double %5, ptr addrspace(1) %frac, align 8
358+
store double %6, ptr addrspace(1) %integral, align 8
359+
ret void
360+
}
361+
362+
; CHECK: OpFunction
363+
; CHECK: %[[#d:]] = OpFunctionParameter %[[#]]
364+
; CHECK: %[[#fracPtr:]] = OpFunctionParameter %[[#]]
365+
; CHECK: %[[#integralPtr:]] = OpFunctionParameter %[[#]]
366+
; CHECK: %[[#entryBlock:]] = OpLabel
367+
; CHECK: %[[#varPtr:]] = OpVariable %[[#]] Function
368+
; CHECK: OpBranchConditional %[[#]] %[[#lor_lhs_falseBlock:]] %[[#if_thenBlock:]]
369+
; CHECK: %[[#lor_lhs_falseBlock]] = OpLabel
370+
; CHECK: OpBranchConditional %[[#]] %[[#if_endBlock:]] %[[#if_thenBlock]]
371+
; CHECK: %[[#if_thenBlock]] = OpLabel
372+
; CHECK: OpBranch %[[#returnBlock:]]
373+
; CHECK: %[[#if_endBlock]] = OpLabel
374+
; CHECK: %[[#frac:]] = OpExtInst %[[#var2]] %[[#extinst_id]] modf %[[#d]] %[[#varPtr]]
375+
; CHECK: %[[#integral:]] = OpLoad %[[#var2]] %[[#varPtr]]
376+
; CHECK: OpStore %[[#fracPtr]] %[[#frac]]
377+
; CHECK: OpStore %[[#integralPtr]] %[[#integral]]
378+
; CHECK: OpFunctionEnd
379+
define dso_local void @TestModf2(double noundef %d, ptr noundef %frac, ptr noundef %integral) {
380+
entry:
381+
%0 = load ptr, ptr %frac, align 8
382+
%tobool = icmp ne ptr %0, null
383+
br i1 %tobool, label %lor.lhs.false, label %if.then
384+
385+
lor.lhs.false:
386+
%1 = load ptr, ptr %integral, align 8
387+
%tobool1 = icmp ne ptr %1, null
388+
br i1 %tobool1, label %if.end, label %if.then
389+
390+
if.then:
391+
br label %return
392+
393+
if.end:
394+
%6 = tail call { double, double } @llvm.modf.f64(double %d)
395+
%7 = extractvalue { double, double } %6, 0
396+
%8 = extractvalue { double, double } %6, 1
397+
store double %7, ptr %frac, align 4
398+
store double %8, ptr %integral, align 4
399+
br label %return
400+
401+
return:
402+
ret void
403+
}
404+
405+
declare { double, double } @llvm.modf.f64(double)

0 commit comments

Comments
 (0)