Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DXIL] exp, any, lerp, & rcp Intrinsic Lowering #84526

Merged
merged 3 commits into from
Mar 15, 2024

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Mar 8, 2024

This change implements lowering for #70076, #70100, #70072, & #70102
CGBuiltin.cpp - - simplify lerp intrinsic
IntrinsicsDirectX.td - simplify lerp intrinsic
SemaChecking.cpp - remove unnecessary check
DXILIntrinsicExpansion.* - add intrinsic to instruction expansion cases
DXILOpLowering.cpp - make sure DXILIntrinsicExpansion happens first
DirectX.h - changes to support new pass
DirectXTargetMachine.cpp - changes to support new pass

Why any, and lerp as instruction expansion just for DXIL?

  • SPIR-V there is an OpAny
  • SPIR-V has a GLSL lerp extension via Fmix

Why exp instruction expansion?

  • We have an exp2 opcode and exp reuses that opcode. So instruction expansion is a convenient way to do preprocessing.
  • Further SPIR-V has a GLSL exp extension via Exp and Exp2

Why rcp as instruction expansion?
This one is a bit of the odd man out and might have to move to cgbuiltins when we better understand SPIRV requirements. However I included it because it seems like fast math mode has an AllowRecip flag which lets you compute the reciprocal without performing the division. We don't have that in DXIL so thought to include it.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen backend:DirectX HLSL HLSL Language Support llvm:ir labels Mar 8, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 8, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-hlsl

Author: Farzon Lotfi (farzonl)

Changes

This change implements lowering for #70076, #70100, #70072, & #70102
CGBuiltin.cpp - - simplify lerp intrinsic
IntrinsicsDirectX.td - simplify lerp intrinsic
SemaChecking.cpp - remove unnecessary check
DXILIntrinsicExpansion.* - add intrinsic to instruction expansion cases
DXILOpLowering.cpp - make sure DXILIntrinsicExpansion happens first
DirectX.h - changes to support new pass
DirectXTargetMachine.cpp - changes to support new pass


Patch is 28.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84526.diff

15 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+4-31)
  • (modified) clang/lib/Sema/SemaChecking.cpp (-2)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp.hlsl (+4-9)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1-4)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+187)
  • (added) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h (+33)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+5-1)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+2)
  • (added) llvm/test/CodeGen/DirectX/any.ll (+133)
  • (added) llvm/test/CodeGen/DirectX/exp-vec.ll (+23)
  • (added) llvm/test/CodeGen/DirectX/exp.ll (+38)
  • (added) llvm/test/CodeGen/DirectX/lerp.ll (+64)
  • (added) llvm/test/CodeGen/DirectX/rcp.ll (+63)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..1d12237fb25e05 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18009,38 +18009,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
     Value *S = EmitScalarExpr(E->getArg(2));
-    llvm::Type *Xty = X->getType();
-    llvm::Type *Yty = Y->getType();
-    llvm::Type *Sty = S->getType();
-    if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
-      if (Xty->isFloatingPointTy()) {
-        auto V = Builder.CreateFSub(Y, X);
-        V = Builder.CreateFMul(S, V);
-        return Builder.CreateFAdd(X, V, "dx.lerp");
-      }
-      llvm_unreachable("Scalar Lerp is only supported on floats.");
-    }
-    // A VectorSplat should have happened
-    assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
-           "Lerp of vector and scalar is not supported.");
-
-    [[maybe_unused]] auto *XVecTy =
-        E->getArg(0)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *YVecTy =
-        E->getArg(1)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *SVecTy =
-        E->getArg(2)->getType()->getAs<VectorType>();
-    // A HLSLVectorTruncation should have happend
-    assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
-           XVecTy->getNumElements() == SVecTy->getNumElements() &&
-           "Lerp requires vectors to be of the same size.");
-    assert(XVecTy->getElementType()->isRealFloatingType() &&
-           XVecTy->getElementType() == YVecTy->getElementType() &&
-           XVecTy->getElementType() == SVecTy->getElementType() &&
-           "Lerp requires float vectors to be of the same type.");
+    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+      llvm_unreachable("lerp operand must have a float representation");
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
-        nullptr, "dx.lerp");
+        /*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
+        ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
   }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a5f42b630c3fa2..8a2b7384a0b0d5 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5300,8 +5300,6 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return true;
-    if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
-      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index a6b3d9643d674c..38a71ed3bcec94 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -6,13 +6,10 @@
 // RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
-// NATIVE_HALF: %3 = fsub half %1, %0
-// NATIVE_HALF: %4 = fmul half %2, %3
-// NATIVE_HALF: %dx.lerp = fadd half %0, %4
+
+// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
 // NATIVE_HALF: ret half %dx.lerp
-// NO_HALF: %3 = fsub float %1, %0
-// NO_HALF: %4 = fmul float %2, %3
-// NO_HALF: %dx.lerp = fadd float %0, %4
+// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // NO_HALF: ret float %dx.lerp
 half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 
@@ -34,9 +31,7 @@ half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
 // NO_HALF: ret <4 x float> %dx.lerp
 half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
 
-// CHECK: %3 = fsub float %1, %0
-// CHECK: %4 = fmul float %2, %3
-// CHECK: %dx.lerp = fadd float %0, %4
+// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // CHECK: ret float %dx.lerp
 float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
 
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 7229292e377a83..c3597432cfb974 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -29,10 +29,7 @@ def int_dx_dot :
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 
-def int_dx_lerp :
-    Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn] >;
+def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], [IntrNoMem, IntrWillReturn] >;
 
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index bf93280779bf8b..4c70b3f9230edb 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
   DXContainerGlobals.cpp
+  DXILIntrinsicExpansion.cpp
   DXILMetadata.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
new file mode 100644
index 00000000000000..54a01d20e20548
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -0,0 +1,187 @@
+//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains DXIL intrinsic expansions for those that don't have
+//  opcodes in DirectX Intermediate Language (DXIL).
+//===----------------------------------------------------------------------===//
+
+#include "DXILIntrinsicExpansion.h"
+#include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "dxil-intrinsic-expansion"
+#define M_LOG2E_F 1.44269504088896340735992468100189214f
+
+using namespace llvm;
+
+static bool isIntrinsicExpansion(Function &F) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+  case Intrinsic::dx_any:
+  case Intrinsic::dx_lerp:
+  case Intrinsic::dx_rcp:
+    return true;
+  }
+  return false;
+}
+
+static bool expandExpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *Log2eConst =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, M_LOG2E_F))
+          : ConstantFP::get(EltTy, M_LOG2E_F);
+  Value *NewX = Builder.CreateFMul(Log2eConst, X);
+  auto *Exp2Call =
+      Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
+  Exp2Call->setTailCall(Orig->isTailCall());
+  Exp2Call->setAttributes(Orig->getAttributes());
+  Orig->replaceAllUsesWith(Exp2Call);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandAnyIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+
+  if (!Ty->isVectorTy()) {
+    Value *Cond = EltTy->isFloatingPointTy()
+                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
+    Orig->replaceAllUsesWith(Cond);
+  } else {
+    auto *XVec = dyn_cast<FixedVectorType>(Ty);
+    Value *Cond =
+        EltTy->isFloatingPointTy()
+            ? Builder.CreateFCmpUNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantFP::get(EltTy, 0)))
+            : Builder.CreateICmpNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantInt::get(EltTy, 0)));
+    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    for (unsigned I = 1; I < XVec->getNumElements(); I++) {
+      Value *Elt = Builder.CreateExtractElement(Cond, I);
+      Result = Builder.CreateOr(Result, Elt);
+    }
+    Orig->replaceAllUsesWith(Result);
+  }
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandLerpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Value *Y = Orig->getOperand(1);
+  Value *S = Orig->getOperand(2);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  auto *V = Builder.CreateFSub(Y, X);
+  V = Builder.CreateFMul(S, V);
+  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandReciprocalIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *One =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, 1.0))
+          : ConstantFP::get(EltTy, 1.0);
+  auto *Result = Builder.CreateFDiv(One, X, "dx.rcp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+    return expandExpIntrinsic(Orig);
+  case Intrinsic::dx_any:
+    return expandAnyIntrinsic(Orig);
+  case Intrinsic::dx_lerp:
+    return expandLerpIntrinsic(Orig);
+  case Intrinsic::dx_rcp:
+    return expandReciprocalIntrinsic(Orig);
+  }
+  return false;
+}
+
+static bool intrinsicExpansion(Module &M) {
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (!isIntrinsicExpansion(F))
+      continue;
+    bool IntrinsicExpanded = false;
+    for (User *U : make_early_inc_range(F.users())) {
+      auto *IntrinsicCall = dyn_cast<CallInst>(U);
+      if (!IntrinsicCall)
+        continue;
+      IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
+    }
+    if (F.user_empty() && IntrinsicExpanded)
+      F.eraseFromParent();
+  }
+  return true;
+}
+
+PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (intrinsicExpansion(M))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
+
+bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
+  return intrinsicExpansion(M);
+}
+
+char DXILIntrinsicExpansionLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                      "DXIL Intrinsic Expansion", false, false)
+INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                    "DXIL Intrinsic Expansion", false, false)
+
+ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
+  return new DXILIntrinsicExpansionLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
new file mode 100644
index 00000000000000..c86681af7a3712
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
@@ -0,0 +1,33 @@
+//===- DXILIntrinsicExpansion.h - Prepare LLVM Module for DXIL encoding----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+#define LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms DXIL Intrinsics that don't have DXIL opCodes
+class DXILIntrinsicExpansion : public PassInfoMixin<DXILIntrinsicExpansion> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+class DXILIntrinsicExpansionLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 6b649b76beecdf..e5c2042e7d16ae 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILConstants.h"
+#include "DXILIntrinsicExpansion.h"
 #include "DXILOpBuilder.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
@@ -94,9 +95,12 @@ class DXILOpLoweringLegacy : public ModulePass {
   DXILOpLoweringLegacy() : ModulePass(ID) {}
 
   static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    // Specify the passes that your pass depends on
+    AU.addRequired<DXILIntrinsicExpansionLegacy>();
+  }
 };
 char DXILOpLoweringLegacy::ID = 0;
-
 } // end anonymous namespace
 
 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index eaecc3ac280c4c..11b5412c21d783 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -28,6 +28,12 @@ void initializeDXILPrepareModulePass(PassRegistry &);
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
 
+/// Initializer for DXIL Intrinsic Expansion
+void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
+
+/// Pass to expand intrinsic operations that lack DXIL opCodes
+ModulePass *createDXILIntrinsicExpansionLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 06938f8c74f155..03c825b3977db3 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -39,6 +39,7 @@ using namespace llvm;
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
+  initializeDXILIntrinsicExpansionLegacyPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
   initializeWriteDXILPassPass(*PR);
@@ -76,6 +77,7 @@ class DirectXPassConfig : public TargetPassConfig {
 
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
+    addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
diff --git a/llvm/test/CodeGen/DirectX/any.ll b/llvm/test/CodeGen/DirectX/any.ll
new file mode 100644
index 00000000000000..516afa101e948d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/any.ll
@@ -0,0 +1,133 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for any are generated for float and half.
+
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+
+; CHECK:icmp ne i1 %{{.*}}, false
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_bool(i1 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i8, align 1
+  %frombool = zext i1 %p0 to i8
+  store i8 %frombool, ptr %p0.addr, align 1
+  %0 = load i8, ptr %p0.addr, align 1
+  %tobool = trunc i8 %0 to i1
+  %dx.any = call i1 @llvm.dx.any.i1(i1 %tobool)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i1(i1) #1
+
+; CHECK:icmp ne i64 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int64_t(i64 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.i64(i64 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i64(i64) #1
+
+; CHECK:icmp ne i32 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int(i32 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.i32(i32 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i32(i32) #1
+
+; CHECK:icmp ne i16 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int16_t(i16 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.i16(i16 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i16(i16) #1
+
+; CHECK:fcmp une double %{{.*}}, 0.000000e+00
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_double(double noundef %p0) #0 {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.f64(double %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f64(double) #1
+
+; CHECK:fcmp une float %{{.*}}, 0.000000e+00
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_float(float noundef %p0) #0 {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.f32(float %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f32(float) #1
+
+; CHECK:fcmp une half %{{.*}}, 0xH0000
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_half(half noundef %p0) #0 {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.f16(half %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f16(half) #1
+
+; CHECK:icmp ne <4 x i1> %extractvec, zeroinitialize
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK:or i1  %{{.*}}, %{{.*}}
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 2
+...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 8, 2024

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes

This change implements lowering for #70076, #70100, #70072, & #70102
CGBuiltin.cpp - - simplify lerp intrinsic
IntrinsicsDirectX.td - simplify lerp intrinsic
SemaChecking.cpp - remove unnecessary check
DXILIntrinsicExpansion.* - add intrinsic to instruction expansion cases
DXILOpLowering.cpp - make sure DXILIntrinsicExpansion happens first
DirectX.h - changes to support new pass
DirectXTargetMachine.cpp - changes to support new pass


Patch is 28.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84526.diff

15 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+4-31)
  • (modified) clang/lib/Sema/SemaChecking.cpp (-2)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp.hlsl (+4-9)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1-4)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+187)
  • (added) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h (+33)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+5-1)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+2)
  • (added) llvm/test/CodeGen/DirectX/any.ll (+133)
  • (added) llvm/test/CodeGen/DirectX/exp-vec.ll (+23)
  • (added) llvm/test/CodeGen/DirectX/exp.ll (+38)
  • (added) llvm/test/CodeGen/DirectX/lerp.ll (+64)
  • (added) llvm/test/CodeGen/DirectX/rcp.ll (+63)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..1d12237fb25e05 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18009,38 +18009,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
     Value *S = EmitScalarExpr(E->getArg(2));
-    llvm::Type *Xty = X->getType();
-    llvm::Type *Yty = Y->getType();
-    llvm::Type *Sty = S->getType();
-    if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
-      if (Xty->isFloatingPointTy()) {
-        auto V = Builder.CreateFSub(Y, X);
-        V = Builder.CreateFMul(S, V);
-        return Builder.CreateFAdd(X, V, "dx.lerp");
-      }
-      llvm_unreachable("Scalar Lerp is only supported on floats.");
-    }
-    // A VectorSplat should have happened
-    assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
-           "Lerp of vector and scalar is not supported.");
-
-    [[maybe_unused]] auto *XVecTy =
-        E->getArg(0)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *YVecTy =
-        E->getArg(1)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *SVecTy =
-        E->getArg(2)->getType()->getAs<VectorType>();
-    // A HLSLVectorTruncation should have happend
-    assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
-           XVecTy->getNumElements() == SVecTy->getNumElements() &&
-           "Lerp requires vectors to be of the same size.");
-    assert(XVecTy->getElementType()->isRealFloatingType() &&
-           XVecTy->getElementType() == YVecTy->getElementType() &&
-           XVecTy->getElementType() == SVecTy->getElementType() &&
-           "Lerp requires float vectors to be of the same type.");
+    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+      llvm_unreachable("lerp operand must have a float representation");
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
-        nullptr, "dx.lerp");
+        /*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
+        ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
   }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a5f42b630c3fa2..8a2b7384a0b0d5 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5300,8 +5300,6 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return true;
-    if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
-      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index a6b3d9643d674c..38a71ed3bcec94 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -6,13 +6,10 @@
 // RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
-// NATIVE_HALF: %3 = fsub half %1, %0
-// NATIVE_HALF: %4 = fmul half %2, %3
-// NATIVE_HALF: %dx.lerp = fadd half %0, %4
+
+// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
 // NATIVE_HALF: ret half %dx.lerp
-// NO_HALF: %3 = fsub float %1, %0
-// NO_HALF: %4 = fmul float %2, %3
-// NO_HALF: %dx.lerp = fadd float %0, %4
+// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // NO_HALF: ret float %dx.lerp
 half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 
@@ -34,9 +31,7 @@ half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
 // NO_HALF: ret <4 x float> %dx.lerp
 half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
 
-// CHECK: %3 = fsub float %1, %0
-// CHECK: %4 = fmul float %2, %3
-// CHECK: %dx.lerp = fadd float %0, %4
+// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // CHECK: ret float %dx.lerp
 float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
 
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 7229292e377a83..c3597432cfb974 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -29,10 +29,7 @@ def int_dx_dot :
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 
-def int_dx_lerp :
-    Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn] >;
+def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], [IntrNoMem, IntrWillReturn] >;
 
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index bf93280779bf8b..4c70b3f9230edb 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
   DXContainerGlobals.cpp
+  DXILIntrinsicExpansion.cpp
   DXILMetadata.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
new file mode 100644
index 00000000000000..54a01d20e20548
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -0,0 +1,187 @@
+//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains DXIL intrinsic expansions for those that don't have
+//  opcodes in DirectX Intermediate Language (DXIL).
+//===----------------------------------------------------------------------===//
+
+#include "DXILIntrinsicExpansion.h"
+#include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "dxil-intrinsic-expansion"
+#define M_LOG2E_F 1.44269504088896340735992468100189214f
+
+using namespace llvm;
+
+static bool isIntrinsicExpansion(Function &F) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+  case Intrinsic::dx_any:
+  case Intrinsic::dx_lerp:
+  case Intrinsic::dx_rcp:
+    return true;
+  }
+  return false;
+}
+
+static bool expandExpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *Log2eConst =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, M_LOG2E_F))
+          : ConstantFP::get(EltTy, M_LOG2E_F);
+  Value *NewX = Builder.CreateFMul(Log2eConst, X);
+  auto *Exp2Call =
+      Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
+  Exp2Call->setTailCall(Orig->isTailCall());
+  Exp2Call->setAttributes(Orig->getAttributes());
+  Orig->replaceAllUsesWith(Exp2Call);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandAnyIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+
+  if (!Ty->isVectorTy()) {
+    Value *Cond = EltTy->isFloatingPointTy()
+                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
+    Orig->replaceAllUsesWith(Cond);
+  } else {
+    auto *XVec = dyn_cast<FixedVectorType>(Ty);
+    Value *Cond =
+        EltTy->isFloatingPointTy()
+            ? Builder.CreateFCmpUNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantFP::get(EltTy, 0)))
+            : Builder.CreateICmpNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantInt::get(EltTy, 0)));
+    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    for (unsigned I = 1; I < XVec->getNumElements(); I++) {
+      Value *Elt = Builder.CreateExtractElement(Cond, I);
+      Result = Builder.CreateOr(Result, Elt);
+    }
+    Orig->replaceAllUsesWith(Result);
+  }
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandLerpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Value *Y = Orig->getOperand(1);
+  Value *S = Orig->getOperand(2);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  auto *V = Builder.CreateFSub(Y, X);
+  V = Builder.CreateFMul(S, V);
+  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandReciprocalIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *One =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, 1.0))
+          : ConstantFP::get(EltTy, 1.0);
+  auto *Result = Builder.CreateFDiv(One, X, "dx.rcp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+    return expandExpIntrinsic(Orig);
+  case Intrinsic::dx_any:
+    return expandAnyIntrinsic(Orig);
+  case Intrinsic::dx_lerp:
+    return expandLerpIntrinsic(Orig);
+  case Intrinsic::dx_rcp:
+    return expandReciprocalIntrinsic(Orig);
+  }
+  return false;
+}
+
+static bool intrinsicExpansion(Module &M) {
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (!isIntrinsicExpansion(F))
+      continue;
+    bool IntrinsicExpanded = false;
+    for (User *U : make_early_inc_range(F.users())) {
+      auto *IntrinsicCall = dyn_cast<CallInst>(U);
+      if (!IntrinsicCall)
+        continue;
+      IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
+    }
+    if (F.user_empty() && IntrinsicExpanded)
+      F.eraseFromParent();
+  }
+  return true;
+}
+
+PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (intrinsicExpansion(M))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
+
+bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
+  return intrinsicExpansion(M);
+}
+
+char DXILIntrinsicExpansionLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                      "DXIL Intrinsic Expansion", false, false)
+INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                    "DXIL Intrinsic Expansion", false, false)
+
+ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
+  return new DXILIntrinsicExpansionLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
new file mode 100644
index 00000000000000..c86681af7a3712
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
@@ -0,0 +1,33 @@
+//===- DXILIntrinsicExpansion.h - Prepare LLVM Module for DXIL encoding----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+#define LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms DXIL Intrinsics that don't have DXIL opCodes
+class DXILIntrinsicExpansion : public PassInfoMixin<DXILIntrinsicExpansion> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+class DXILIntrinsicExpansionLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 6b649b76beecdf..e5c2042e7d16ae 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILConstants.h"
+#include "DXILIntrinsicExpansion.h"
 #include "DXILOpBuilder.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
@@ -94,9 +95,12 @@ class DXILOpLoweringLegacy : public ModulePass {
   DXILOpLoweringLegacy() : ModulePass(ID) {}
 
   static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    // Specify the passes that your pass depends on
+    AU.addRequired<DXILIntrinsicExpansionLegacy>();
+  }
 };
 char DXILOpLoweringLegacy::ID = 0;
-
 } // end anonymous namespace
 
 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index eaecc3ac280c4c..11b5412c21d783 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -28,6 +28,12 @@ void initializeDXILPrepareModulePass(PassRegistry &);
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
 
+/// Initializer for DXIL Intrinsic Expansion
+void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
+
+/// Pass to expand intrinsic operations that lack DXIL opCodes
+ModulePass *createDXILIntrinsicExpansionLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 06938f8c74f155..03c825b3977db3 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -39,6 +39,7 @@ using namespace llvm;
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
+  initializeDXILIntrinsicExpansionLegacyPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
   initializeWriteDXILPassPass(*PR);
@@ -76,6 +77,7 @@ class DirectXPassConfig : public TargetPassConfig {
 
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
+    addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
diff --git a/llvm/test/CodeGen/DirectX/any.ll b/llvm/test/CodeGen/DirectX/any.ll
new file mode 100644
index 00000000000000..516afa101e948d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/any.ll
@@ -0,0 +1,133 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for any are generated for float and half.
+
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+
+; CHECK:icmp ne i1 %{{.*}}, false
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_bool(i1 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i8, align 1
+  %frombool = zext i1 %p0 to i8
+  store i8 %frombool, ptr %p0.addr, align 1
+  %0 = load i8, ptr %p0.addr, align 1
+  %tobool = trunc i8 %0 to i1
+  %dx.any = call i1 @llvm.dx.any.i1(i1 %tobool)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i1(i1) #1
+
+; CHECK:icmp ne i64 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int64_t(i64 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.i64(i64 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i64(i64) #1
+
+; CHECK:icmp ne i32 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int(i32 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.i32(i32 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i32(i32) #1
+
+; CHECK:icmp ne i16 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int16_t(i16 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.i16(i16 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i16(i16) #1
+
+; CHECK:fcmp une double %{{.*}}, 0.000000e+00
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_double(double noundef %p0) #0 {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.f64(double %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f64(double) #1
+
+; CHECK:fcmp une float %{{.*}}, 0.000000e+00
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_float(float noundef %p0) #0 {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.f32(float %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f32(float) #1
+
+; CHECK:fcmp une half %{{.*}}, 0xH0000
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_half(half noundef %p0) #0 {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.f16(half %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f16(half) #1
+
+; CHECK:icmp ne <4 x i1> %extractvec, zeroinitialize
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK:or i1  %{{.*}}, %{{.*}}
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 2
+...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 8, 2024

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

Changes

This change implements lowering for #70076, #70100, #70072, & #70102
CGBuiltin.cpp - - simplify lerp intrinsic
IntrinsicsDirectX.td - simplify lerp intrinsic
SemaChecking.cpp - remove unnecessary check
DXILIntrinsicExpansion.* - add intrinsic to instruction expansion cases
DXILOpLowering.cpp - make sure DXILIntrinsicExpansion happens first
DirectX.h - changes to support new pass
DirectXTargetMachine.cpp - changes to support new pass


Patch is 28.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84526.diff

15 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+4-31)
  • (modified) clang/lib/Sema/SemaChecking.cpp (-2)
  • (modified) clang/test/CodeGenHLSL/builtins/lerp.hlsl (+4-9)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1-4)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+187)
  • (added) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h (+33)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+5-1)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+2)
  • (added) llvm/test/CodeGen/DirectX/any.ll (+133)
  • (added) llvm/test/CodeGen/DirectX/exp-vec.ll (+23)
  • (added) llvm/test/CodeGen/DirectX/exp.ll (+38)
  • (added) llvm/test/CodeGen/DirectX/lerp.ll (+64)
  • (added) llvm/test/CodeGen/DirectX/rcp.ll (+63)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..1d12237fb25e05 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18009,38 +18009,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *X = EmitScalarExpr(E->getArg(0));
     Value *Y = EmitScalarExpr(E->getArg(1));
     Value *S = EmitScalarExpr(E->getArg(2));
-    llvm::Type *Xty = X->getType();
-    llvm::Type *Yty = Y->getType();
-    llvm::Type *Sty = S->getType();
-    if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
-      if (Xty->isFloatingPointTy()) {
-        auto V = Builder.CreateFSub(Y, X);
-        V = Builder.CreateFMul(S, V);
-        return Builder.CreateFAdd(X, V, "dx.lerp");
-      }
-      llvm_unreachable("Scalar Lerp is only supported on floats.");
-    }
-    // A VectorSplat should have happened
-    assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
-           "Lerp of vector and scalar is not supported.");
-
-    [[maybe_unused]] auto *XVecTy =
-        E->getArg(0)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *YVecTy =
-        E->getArg(1)->getType()->getAs<VectorType>();
-    [[maybe_unused]] auto *SVecTy =
-        E->getArg(2)->getType()->getAs<VectorType>();
-    // A HLSLVectorTruncation should have happend
-    assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
-           XVecTy->getNumElements() == SVecTy->getNumElements() &&
-           "Lerp requires vectors to be of the same size.");
-    assert(XVecTy->getElementType()->isRealFloatingType() &&
-           XVecTy->getElementType() == YVecTy->getElementType() &&
-           XVecTy->getElementType() == SVecTy->getElementType() &&
-           "Lerp requires float vectors to be of the same type.");
+    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
+      llvm_unreachable("lerp operand must have a float representation");
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
-        nullptr, "dx.lerp");
+        /*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
+        ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
   }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a5f42b630c3fa2..8a2b7384a0b0d5 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5300,8 +5300,6 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return true;
-    if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
-      return true;
     break;
   }
   case Builtin::BI__builtin_hlsl_mad: {
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index a6b3d9643d674c..38a71ed3bcec94 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -6,13 +6,10 @@
 // RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
 
-// NATIVE_HALF: %3 = fsub half %1, %0
-// NATIVE_HALF: %4 = fmul half %2, %3
-// NATIVE_HALF: %dx.lerp = fadd half %0, %4
+
+// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
 // NATIVE_HALF: ret half %dx.lerp
-// NO_HALF: %3 = fsub float %1, %0
-// NO_HALF: %4 = fmul float %2, %3
-// NO_HALF: %dx.lerp = fadd float %0, %4
+// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // NO_HALF: ret float %dx.lerp
 half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
 
@@ -34,9 +31,7 @@ half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
 // NO_HALF: ret <4 x float> %dx.lerp
 half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
 
-// CHECK: %3 = fsub float %1, %0
-// CHECK: %4 = fmul float %2, %3
-// CHECK: %dx.lerp = fadd float %0, %4
+// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
 // CHECK: ret float %dx.lerp
 float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
 
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 7229292e377a83..c3597432cfb974 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -29,10 +29,7 @@ def int_dx_dot :
 
 def int_dx_frac  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 
-def int_dx_lerp :
-    Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, IntrWillReturn] >;
+def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], [IntrNoMem, IntrWillReturn] >;
 
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index bf93280779bf8b..4c70b3f9230edb 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
   DXContainerGlobals.cpp
+  DXILIntrinsicExpansion.cpp
   DXILMetadata.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
new file mode 100644
index 00000000000000..54a01d20e20548
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -0,0 +1,187 @@
+//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains DXIL intrinsic expansions for those that don't have
+//  opcodes in DirectX Intermediate Language (DXIL).
+//===----------------------------------------------------------------------===//
+
+#include "DXILIntrinsicExpansion.h"
+#include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "dxil-intrinsic-expansion"
+#define M_LOG2E_F 1.44269504088896340735992468100189214f
+
+using namespace llvm;
+
+static bool isIntrinsicExpansion(Function &F) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+  case Intrinsic::dx_any:
+  case Intrinsic::dx_lerp:
+  case Intrinsic::dx_rcp:
+    return true;
+  }
+  return false;
+}
+
+static bool expandExpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *Log2eConst =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, M_LOG2E_F))
+          : ConstantFP::get(EltTy, M_LOG2E_F);
+  Value *NewX = Builder.CreateFMul(Log2eConst, X);
+  auto *Exp2Call =
+      Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
+  Exp2Call->setTailCall(Orig->isTailCall());
+  Exp2Call->setAttributes(Orig->getAttributes());
+  Orig->replaceAllUsesWith(Exp2Call);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandAnyIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+
+  if (!Ty->isVectorTy()) {
+    Value *Cond = EltTy->isFloatingPointTy()
+                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
+    Orig->replaceAllUsesWith(Cond);
+  } else {
+    auto *XVec = dyn_cast<FixedVectorType>(Ty);
+    Value *Cond =
+        EltTy->isFloatingPointTy()
+            ? Builder.CreateFCmpUNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantFP::get(EltTy, 0)))
+            : Builder.CreateICmpNE(
+                  X, ConstantVector::getSplat(
+                         ElementCount::getFixed(XVec->getNumElements()),
+                         ConstantInt::get(EltTy, 0)));
+    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    for (unsigned I = 1; I < XVec->getNumElements(); I++) {
+      Value *Elt = Builder.CreateExtractElement(Cond, I);
+      Result = Builder.CreateOr(Result, Elt);
+    }
+    Orig->replaceAllUsesWith(Result);
+  }
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandLerpIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Value *Y = Orig->getOperand(1);
+  Value *S = Orig->getOperand(2);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  auto *V = Builder.CreateFSub(Y, X);
+  V = Builder.CreateFMul(S, V);
+  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandReciprocalIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  Type *EltTy = Ty->getScalarType();
+  Constant *One =
+      Ty->isVectorTy()
+          ? ConstantVector::getSplat(
+                ElementCount::getFixed(
+                    dyn_cast<FixedVectorType>(Ty)->getNumElements()),
+                ConstantFP::get(EltTy, 1.0))
+          : ConstantFP::get(EltTy, 1.0);
+  auto *Result = Builder.CreateFDiv(One, X, "dx.rcp");
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
+static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::exp:
+    return expandExpIntrinsic(Orig);
+  case Intrinsic::dx_any:
+    return expandAnyIntrinsic(Orig);
+  case Intrinsic::dx_lerp:
+    return expandLerpIntrinsic(Orig);
+  case Intrinsic::dx_rcp:
+    return expandReciprocalIntrinsic(Orig);
+  }
+  return false;
+}
+
+static bool intrinsicExpansion(Module &M) {
+  for (auto &F : make_early_inc_range(M.functions())) {
+    if (!isIntrinsicExpansion(F))
+      continue;
+    bool IntrinsicExpanded = false;
+    for (User *U : make_early_inc_range(F.users())) {
+      auto *IntrinsicCall = dyn_cast<CallInst>(U);
+      if (!IntrinsicCall)
+        continue;
+      IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
+    }
+    if (F.user_empty() && IntrinsicExpanded)
+      F.eraseFromParent();
+  }
+  return true;
+}
+
+PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (intrinsicExpansion(M))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
+
+bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
+  return intrinsicExpansion(M);
+}
+
+char DXILIntrinsicExpansionLegacy::ID = 0;
+
+INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                      "DXIL Intrinsic Expansion", false, false)
+INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
+                    "DXIL Intrinsic Expansion", false, false)
+
+ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
+  return new DXILIntrinsicExpansionLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
new file mode 100644
index 00000000000000..c86681af7a3712
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.h
@@ -0,0 +1,33 @@
+//===- DXILIntrinsicExpansion.h - Prepare LLVM Module for DXIL encoding----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+#define LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
+
+#include "DXILResource.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// A pass that transforms DXIL Intrinsics that don't have DXIL opCodes
+class DXILIntrinsicExpansion : public PassInfoMixin<DXILIntrinsicExpansion> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+class DXILIntrinsicExpansionLegacy : public ModulePass {
+
+public:
+  bool runOnModule(Module &M) override;
+  DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 6b649b76beecdf..e5c2042e7d16ae 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILConstants.h"
+#include "DXILIntrinsicExpansion.h"
 #include "DXILOpBuilder.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
@@ -94,9 +95,12 @@ class DXILOpLoweringLegacy : public ModulePass {
   DXILOpLoweringLegacy() : ModulePass(ID) {}
 
   static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    // Specify the passes that your pass depends on
+    AU.addRequired<DXILIntrinsicExpansionLegacy>();
+  }
 };
 char DXILOpLoweringLegacy::ID = 0;
-
 } // end anonymous namespace
 
 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index eaecc3ac280c4c..11b5412c21d783 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -28,6 +28,12 @@ void initializeDXILPrepareModulePass(PassRegistry &);
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
 
+/// Initializer for DXIL Intrinsic Expansion
+void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
+
+/// Pass to expand intrinsic operations that lack DXIL opCodes
+ModulePass *createDXILIntrinsicExpansionLegacyPass();
+
 /// Initializer for DXILOpLowering
 void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 06938f8c74f155..03c825b3977db3 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -39,6 +39,7 @@ using namespace llvm;
 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
+  initializeDXILIntrinsicExpansionLegacyPass(*PR);
   initializeDXILPrepareModulePass(*PR);
   initializeEmbedDXILPassPass(*PR);
   initializeWriteDXILPassPass(*PR);
@@ -76,6 +77,7 @@ class DirectXPassConfig : public TargetPassConfig {
 
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
+    addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
diff --git a/llvm/test/CodeGen/DirectX/any.ll b/llvm/test/CodeGen/DirectX/any.ll
new file mode 100644
index 00000000000000..516afa101e948d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/any.ll
@@ -0,0 +1,133 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for any are generated for float and half.
+
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+
+; CHECK:icmp ne i1 %{{.*}}, false
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_bool(i1 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i8, align 1
+  %frombool = zext i1 %p0 to i8
+  store i8 %frombool, ptr %p0.addr, align 1
+  %0 = load i8, ptr %p0.addr, align 1
+  %tobool = trunc i8 %0 to i1
+  %dx.any = call i1 @llvm.dx.any.i1(i1 %tobool)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i1(i1) #1
+
+; CHECK:icmp ne i64 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int64_t(i64 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.i64(i64 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i64(i64) #1
+
+; CHECK:icmp ne i32 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int(i32 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.i32(i32 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i32(i32) #1
+
+; CHECK:icmp ne i16 %{{.*}}, 0
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_int16_t(i16 noundef %p0) #0 {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.i16(i16 %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.i16(i16) #1
+
+; CHECK:fcmp une double %{{.*}}, 0.000000e+00
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_double(double noundef %p0) #0 {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %dx.any = call i1 @llvm.dx.any.f64(double %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f64(double) #1
+
+; CHECK:fcmp une float %{{.*}}, 0.000000e+00
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_float(float noundef %p0) #0 {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %dx.any = call i1 @llvm.dx.any.f32(float %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f32(float) #1
+
+; CHECK:fcmp une half %{{.*}}, 0xH0000
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @any_half(half noundef %p0) #0 {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.any = call i1 @llvm.dx.any.f16(half %0)
+  ret i1 %dx.any
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn
+declare i1 @llvm.dx.any.f16(half) #1
+
+; CHECK:icmp ne <4 x i1> %extractvec, zeroinitialize
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK:or i1  %{{.*}}, %{{.*}}
+; CHECK:extractelement <4 x i1> %{{.*}}, i64 2
+...
[truncated]

@farzonl farzonl self-assigned this Mar 8, 2024
@farzonl farzonl requested a review from bogner March 11, 2024 16:54
@farzonl farzonl force-pushed the dxil-intrinsic-lowering branch 2 times, most recently from bb65d5b to e91291b Compare March 12, 2024 16:36
clang/lib/CodeGen/CGBuiltin.cpp Show resolved Hide resolved
clang/lib/Sema/SemaChecking.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/any.ll Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/any.ll Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/any.ll Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/any.ll Outdated Show resolved Hide resolved
@farzonl farzonl linked an issue Mar 12, 2024 that may be closed by this pull request
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple more very minor notes on the tests, then this LGTM

llvm/test/CodeGen/DirectX/any.ll Show resolved Hide resolved
llvm/test/CodeGen/DirectX/any.ll Outdated Show resolved Hide resolved
This change implements lowering for llvm#70076, llvm#70100, llvm#70072, & llvm#70102
`CGBuiltin.cpp` - - simplify `lerp` intrinsic
`IntrinsicsDirectX.td` - simplify `lerp` intrinsic
`SemaChecking.cpp` - remove unnecessary check
`DXILIntrinsicExpansion.*` - add intrinsic to instruction expansion cases
`DXILOpLowering.cpp` -  make sure `DXILIntrinsicExpansion` happens first
`DirectX.h` - changes to support new pass
`DirectXTargetMachine.cpp` - changes to support new pass

Why `any`, and `lerp` as instruction expansion just for DXIL?
-  SPIR-V there is an [OpAny](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpAny)
- SPIR-V has a GLSL lerp extension via [Fmix](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#FMix)

Why `exp` instruction expansion?
- We have an `exp2` opcode and `exp` reuses that opcode. So instruction expansion is a convenient way to do preprocessing.
- Further SPIR-V has a GLSL exp extension via  [Exp](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp) and [Exp2](https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html#Exp2)

Why `rcp` as instruction expansion?
This one is a bit of the odd man out and might have to move to `cgbuiltins` when we better understand SPIRV requirements. However I included it because it seems like [fast math mode has an AllowRecip flag](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_fp_fast_math_mode) which lets you compute the reciprocal without performing the division.  We don't have that in DXIL so thought to include it.
@farzonl farzonl merged commit de1a97d into llvm:main Mar 15, 2024
3 of 4 checks passed
@farzonl farzonl deleted the dxil-intrinsic-lowering branch March 15, 2024 00:34
Copy link
Contributor

@coopp coopp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me for what I can understand.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[HLSL] implement lerp intrinsic
5 participants