Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Implement log10 for logical SPIR-V #66921

Merged
merged 5 commits into from
Oct 6, 2023
Merged

Conversation

sudonatalie
Copy link
Member

There is no log10 instruction in the GLSL Extended Instruction Set so to implement the HLSL log10 intrinsic when targeting Vulkan this change adds the logic to derive the result using the following formula:

log10(x) = log2(x) * (1 / log2(10))
         = log2(x) * 0.30103

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 20, 2023

@llvm/pr-subscribers-backend-spir-v

Changes

There is no log10 instruction in the GLSL Extended Instruction Set so to implement the HLSL log10 intrinsic when targeting Vulkan this change adds the logic to derive the result using the following formula:

log10(x) = log2(x) * (1 / log2(10))
         = log2(x) * 0.30103

Full diff: https://github.com/llvm/llvm-project/pull/66921.diff

5 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+12-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+55-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+1-2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll (+42)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d68454f26a80282..68547e62abf823e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -242,7 +242,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
 
 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
                                               MachineIRBuilder &MIRBuilder,
-                                              SPIRVType *SpvType) {
+                                              SPIRVType *SpvType, bool EmitIR) {
   auto &MF = MIRBuilder.getMF();
   const Type *LLVMFPTy;
   if (SpvType) {
@@ -260,8 +260,18 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
     DT.add(ConstFP, &MF, Res);
-    MIRBuilder.buildFConstant(Res, *ConstFP);
+    if (EmitIR) {
+      MIRBuilder.buildFConstant(Res, *ConstFP);
+    } else {
+      MachineInstrBuilder MIB;
+      assert(SpvType);
+      MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
+                .addDef(Res)
+                .addUse(getSPIRVTypeID(SpvType));
+      addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
+    }
   }
+
   return Res;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 88769f84b3e504b..304af2440f95ae5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -234,7 +234,7 @@ class SPIRVGlobalRegistry {
   Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII);
   Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
-                           SPIRVType *SpvType = nullptr);
+                           SPIRVType *SpvType = nullptr, bool EmitIR = true);
   Register getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
                                     SPIRVType *SpvType,
                                     const SPIRVInstrInfo &TII);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index afa34ff3ce1fa40..edb5ce69b0dc6ce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -173,6 +173,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectExtInst(Register ResVReg, const SPIRVType *ResType,
                      MachineInstr &I, const ExtInstList &ExtInsts) const;
 
+  bool selectLog10(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
   Register buildI32Constant(uint32_t Val, MachineInstr &I,
                             const SPIRVType *ResType = nullptr) const;
 
@@ -361,7 +364,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FLOG2:
     return selectExtInst(ResVReg, ResType, I, CL::log2, GL::Log2);
   case TargetOpcode::G_FLOG10:
-    return selectExtInst(ResVReg, ResType, I, CL::log10);
+    return selectLog10(ResVReg, ResType, I);
 
   case TargetOpcode::G_FABS:
     return selectExtInst(ResVReg, ResType, I, CL::fabs, GL::FAbs);
@@ -1540,6 +1543,57 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   return Reg.isValid();
 }
 
+bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           MachineInstr &I) const {
+  if (STI.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
+    return selectExtInst(ResVReg, ResType, I, CL::log10);
+  }
+
+  // There is no log10 instruction in the GLSL Extended Instruction set, so it
+  // is implemented as:
+  // log10(x) = log2(x) * (1 / log2(10))
+  //          = log2(x) * 0.30103
+
+  MachineIRBuilder MIRBuilder(I);
+  MachineBasicBlock &BB = *I.getParent();
+
+  // Build log2(x).
+  Register VarReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  bool Result =
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+          .addDef(VarReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+          .addImm(GL::Log)
+          .add(I.getOperand(1))
+          .constrainAllUses(TII, TRI, RBI);
+
+  // Build 0.30103.
+  assert(ResType->getOpcode() == SPIRV::OpTypeVector ||
+         ResType->getOpcode() == SPIRV::OpTypeFloat);
+  // TODO: Add matrix implementeation once supported by the HLSL frontend.
+  const SPIRVType *SpirvScalarType =
+      ResType->getOpcode() == SPIRV::OpTypeVector
+          ? GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg())
+          : ResType;
+  Register ScaleReg =
+      GR.buildConstantFP(APFloat(0.30103f), MIRBuilder, SpirvScalarType, false);
+
+  // Multiply log2(x) by 0.30103 to get log10(x) result.
+  auto Opcode = ResType->getOpcode() == SPIRV::OpTypeVector
+                    ? SPIRV::OpVectorTimesScalar
+                    : SPIRV::OpFMulS;
+  Result &= BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+                .addDef(ResVReg)
+                .addUse(GR.getSPIRVTypeID(ResType))
+                .addUse(VarReg)
+                .addUse(ScaleReg)
+                .constrainAllUses(TII, TRI, RBI);
+
+  return Result;
+}
+
 namespace llvm {
 InstructionSelector *
 createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index b0028f8c80a406e..cffa0acd25226c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -234,6 +234,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                                G_FEXP2,
                                G_FLOG,
                                G_FLOG2,
+                               G_FLOG10,
                                G_FABS,
                                G_FMINNUM,
                                G_FMAXNUM,
@@ -259,8 +260,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       allFloatScalarsAndVectors, allIntScalarsAndVectors);
 
   if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
-    getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);
-
     getActionDefinitionsBuilder(
         {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
         .legalForCartesianProduct(allIntScalarsAndVectors,
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll
new file mode 100644
index 000000000000000..7a3a611f5e7f1d8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll
@@ -0,0 +1,42 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s
+
+; CHECK: OpExtInstImport "GLSL.std.450"
+
+; CHECK: %[[#float:]] = OpTypeFloat 32
+; CHECK: %[[#v4float:]] = OpTypeVector %[[#float]] 4
+; CHECK: %[[#float_0_30103001:]] = OpConstant %[[#float]] 0.30103000998497009
+; CHECK: %[[#_ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
+; CHECK: %[[#_ptr_Function_float:]] = OpTypePointer Function %[[#float]]
+
+define void @main() {
+entry:
+; CHECK: %[[#f:]] = OpVariable %[[#_ptr_Function_float]] Function
+; CHECK: %[[#logf:]] = OpVariable %[[#_ptr_Function_float]] Function
+; CHECK: %[[#f4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
+; CHECK: %[[#logf4:]] = OpVariable %[[#_ptr_Function_v4float]] Function
+  %f = alloca float, align 4
+  %logf = alloca float, align 4
+  %f4 = alloca <4 x float>, align 16
+  %logf4 = alloca <4 x float>, align 16
+
+; CHECK: %[[#load:]] = OpLoad %[[#float]] %[[#f]] Aligned 4
+; CHECK: %[[#log2:]] = OpExtInst %[[#float]] %15 Log %[[#load]]
+; CHECK: %[[#res:]] = OpFMul %[[#float]] %[[#log2]] %[[#float_0_30103001]]
+; CHECK: OpStore %[[#logf]] %[[#res]] Aligned 4
+  %0 = load float, ptr %f, align 4
+  %elt.log10 = call float @llvm.log10.f32(float %0)
+  store float %elt.log10, ptr %logf, align 4
+
+; CHECK: %[[#load:]] = OpLoad %[[#v4float]] %[[#f4]] Aligned 16
+; CHECK: %[[#log2:]] = OpExtInst %[[#v4float]] %15 Log %[[#load]]
+; CHECK: %[[#res:]] = OpVectorTimesScalar %[[#v4float]] %[[#log2]] %[[#float_0_30103001]]
+; CHECK: OpStore %[[#logf4]] %[[#res]] Aligned 16
+  %1 = load <4 x float>, ptr %f4, align 16
+  %elt.log101 = call <4 x float> @llvm.log10.v4f32(<4 x float> %1)
+  store <4 x float> %elt.log101, ptr %logf4, align 16
+
+  ret void
+}
+
+declare float @llvm.log10.f32(float)
+declare <4 x float> @llvm.log10.v4f32(<4 x float>)

Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just some questions.

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp Outdated Show resolved Hide resolved
There is no log10 instruction in the GLSL Extended Instruction Set so to
implement the HLSL log10 intrinsic when targeting Vulkan this change
adds the logic to derive the result using the following formula:

```
log10(x) = log2(x) * (1 / log2(10))
         = log2(x) * 0.30103
```
Copy link
Contributor

@iliya-diyachkov iliya-diyachkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patch looks good to me.

@sudonatalie sudonatalie merged commit 0a2aaab into llvm:main Oct 6, 2023
3 checks passed
@sudonatalie sudonatalie deleted the log10 branch October 6, 2023 13:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants