Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DXIL] Add Float Dot Intrinsic Lowering #86071

Merged
merged 2 commits into from
Mar 25, 2024

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Mar 21, 2024

Completes #83626

  • CGBuiltin.cpp - modify getDotProductIntrinsic to be able to emit dot2, dot3, and dot4 intrinsics based on element count
  • IntrinsicsDirectX.td - for floating point add dot2, dot3, and dot4 inntrinsics -DXIL.td add dxilop intrinsic lowering for dot2, dot3, & dot4.
  • DXILOpLowering.cpp - add vector arg flattening for dot product.
  • DXILOpBuilder.h - modify createDXILOpCall to take a smallVector instead of an iterator
  • DXILOpBuilder.cpp - modify createDXILOpCall by moving the small vector up to the calling function in DXILOpLowering.cpp.
    • Moving one function up gives us access to the CallInst and Function which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.

@farzonl farzonl self-assigned this Mar 21, 2024
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen backend:DirectX HLSL HLSL Language Support llvm:ir labels Mar 21, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-llvm-ir

Author: Farzon Lotfi (farzonl)

Changes

Completes #83626

  • CGBuiltin.cpp - modify getDotProductIntrinsic to be able to emit dot2, dot3, and dot4 intrinsics based on element count
  • IntrinsicsDirectX.td - for floating point add dot2, dot3, and dot4 inntrinsics -DXIL.td add dxilop intrinsic lowering for dot2, dot3, & dot4.
  • DXILOpLowering.cpp - add vector arg flattening for dot product.
  • DXILOpBuilder.h - modify createDXILOpCall to take a smallVector instead of an iterator
  • DXILOpBuilder.cpp - modify createDXILOpCall by moving the small vector up to the calling function in DXILOpLowering.cpp.
    • Moving one function up gives us access to the CallInst and Function which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.

Patch is 20.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86071.diff

11 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16-9)
  • (modified) clang/test/CodeGenHLSL/builtins/dot.hlsl (+14-14)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+9-1)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+9)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+3-5)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2-3)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+53-2)
  • (added) llvm/test/CodeGen/DirectX/dot2_error.ll (+10)
  • (added) llvm/test/CodeGen/DirectX/dot3_error.ll (+10)
  • (added) llvm/test/CodeGen/DirectX/dot4_error.ll (+10)
  • (added) llvm/test/CodeGen/DirectX/fdot.ll (+94)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 77cb269d43c5a8..a4b99181769326 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18036,15 +18036,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
-Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
+  if (QT->hasFloatingRepresentation()) {
+    switch (elementCount) {
+    case 2:
+      return Intrinsic::dx_dot2;
+    case 3:
+      return Intrinsic::dx_dot3;
+    case 4:
+      return Intrinsic::dx_dot4;
+    }
+  }
   if (QT->hasSignedIntegerRepresentation())
     return Intrinsic::dx_sdot;
-  if (QT->hasUnsignedIntegerRepresentation())
-    return Intrinsic::dx_udot;
 
-  assert(QT->hasFloatingRepresentation());
-  return Intrinsic::dx_dot;
-  ;
+  assert(QT->hasUnsignedIntegerRepresentation());
+  return Intrinsic::dx_udot;
 }
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18098,8 +18105,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     assert(T0->getScalarType() == T1->getScalarType() &&
            "Dot product of vectors need the same element types.");
 
-    [[maybe_unused]] auto *VecTy0 =
-        E->getArg(0)->getType()->getAs<VectorType>();
+    auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
     [[maybe_unused]] auto *VecTy1 =
         E->getArg(1)->getType()->getAs<VectorType>();
     // A HLSLVectorTruncation should have happend
@@ -18108,7 +18114,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
 
     return Builder.CreateIntrinsic(
         /*ReturnType=*/T0->getScalarType(),
-        getDotProductIntrinsic(E->getArg(0)->getType()),
+        getDotProductIntrinsic(E->getArg(0)->getType(),
+                               VecTy0->getNumElements()),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
   } break;
   case Builtin::BI__builtin_hlsl_lerp: {
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 0f993193c00cce..307d71cce3cb6d 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -110,21 +110,21 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
 // NO_HALF: ret float %dx.dot
 half test_dot_half(half p0, half p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
 // NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
 // NO_HALF: ret float %dx.dot
 half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
 // NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
 // NO_HALF: ret float %dx.dot
 half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
 
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
 // NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // NO_HALF: ret float %dx.dot
 half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 
@@ -132,34 +132,34 @@ half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
 // CHECK: ret float %dx.dot
 float test_dot_float(float p0, float p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
 
-// CHECK:  %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
 
 // CHECK: %conv = sitofp i32 %1 to float
 // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat)
 // CHECK: ret float %dx.dot
 float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
   return dot(p0, p1);
@@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
 // CHECK: %conv = sitofp i32 %1 to float
 // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat)
 // CHECK: ret float %dx.dot
 float test_builtin_dot_float3_int_splat(float3 p0, int p1) {
   return dot(p0, p1);
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1164b241ba7b0d..a871fac46b9fd5 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -24,7 +24,15 @@ def int_dx_any  : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; 
 
-def int_dx_dot : 
+def int_dx_dot2 : 
+    Intrinsic<[LLVMVectorElementType<0>], 
+    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot3 : 
+    Intrinsic<[LLVMVectorElementType<0>], 
+    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+    [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot4 : 
     Intrinsic<[LLVMVectorElementType<0>], 
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, IntrWillReturn, Commutative] >;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 36eb29d53766f0..9e393902e2f208 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -295,6 +295,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
                          "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
 def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
                          "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1",
+                         [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2",
+                         [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3",
+                         [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
                              "Reads the thread ID">;
 def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a1eacc2d48009c..990557710f8c53 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -254,7 +254,7 @@ namespace dxil {
 
 CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
                                           Type *OverloadTy,
-                                          llvm::iterator_range<Use *> Args) {
+                                          SmallVector<Value *> Args) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
 
   OverloadKind Kind = getOverloadKind(OverloadTy);
@@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
     FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
     DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
   }
-  SmallVector<Value *> FullArgs;
-  FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
-  FullArgs.append(Args.begin(), Args.end());
-  return B.CreateCall(DXILFn, FullArgs);
+
+  return B.CreateCall(DXILFn, Args);
 }
 
 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index f3abcc6e02a4e3..5babeae470178b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -13,7 +13,7 @@
 #define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
 
 #include "DXILConstants.h"
-#include "llvm/ADT/iterator_range.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace llvm {
 class Module;
@@ -35,8 +35,7 @@ class DXILOpBuilder {
   /// \param OverloadTy Overload type of the DXIL Op call constructed
   /// \return DXIL Op call constructed
   CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                             Type *OverloadTy,
-                             llvm::iterator_range<Use *> Args);
+                             Type *OverloadTy, SmallVector<Value *> Args);
   Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3e334b0ec298d3..f09e322f88e1fd 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -30,6 +30,48 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
+static bool isVectorArgExpansion(Function &F) {
+  switch (F.getIntrinsicID()) {
+  case Intrinsic::dx_dot2:
+  case Intrinsic::dx_dot3:
+  case Intrinsic::dx_dot4:
+    return true;
+  }
+  return false;
+}
+
+static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
+  SmallVector<Value *, 4> ExtractedElements;
+  auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+  for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
+    Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
+    Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
+    ExtractedElements.push_back(ExtractedElement);
+  }
+  return ExtractedElements;
+}
+
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+                                             IRBuilder<> &Builder) {
+  // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+  unsigned NumOperands = Orig->getNumOperands() - 1;
+  assert(NumOperands > 0);
+  Value *Arg0 = Orig->getOperand(0);
+  [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+  assert(VecArg0);
+  SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+  for (unsigned I = 1; I < NumOperands; ++I) {
+    Value *Arg = Orig->getOperand(I);
+    [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+    assert(VecArg);
+    assert(VecArg0->getElementType() == VecArg->getElementType());
+    assert(VecArg0->getNumElements() == VecArg->getNumElements());
+    auto NextOperandList = populateOperands(Arg, Builder);
+    NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+  }
+  return NewOperands;
+}
+
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     if (!CI)
       continue;
 
+    SmallVector<Value *> Args;
+    Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+    Args.emplace_back(DXILOpArg);
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
-                                              OverloadTy, CI->args());
+    if (isVectorArgExpansion(F)) {
+      SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
+      Args.append(NewArgs.begin(), NewArgs.end());
+    } else
+      Args.append(CI->arg_begin(), CI->arg_end());
+
+    CallInst *DXILCI =
+        DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
new file mode 100644
index 00000000000000..a27bfaedacd573
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot2 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
+entry:
+  %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
+  ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll
new file mode 100644
index 00000000000000..eb69fb145038aa
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot3_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot3 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) {
+entry:
+  %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b)
+  ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll
new file mode 100644
index 00000000000000..5cd632684c0c01
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot4_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot4 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) {
+entry:
+  %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b)
+  ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
new file mode 100644
index 00000000000000..3e13b2ad2650c8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -0,0 +1,94 @@
+; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s
+
+; Make sure dxil operation function calls for dot are generated for int/uint vectors.
+
+; CHECK-LABEL: dot_half2
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: extractelement <2 x half> %a, i32 0
+; CHECK: extractelement <2 x half> %a, i32 1
+; CHECK: extractelement <2 x half> %b, i32 0
+; CHECK: extractelement <2 x half> %b, i32 1
+; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half3
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: extractelement <3 x half> %a, i32 0
+; CHECK: extractelement <3 x half> %a, i32 1
+; CHECK: extractelement <3 x half> %a, i32 2
+; CHECK: extractelement <3 x half> %b, i32 0
+; CHECK: extractelement <3 x half> %b, i32 1
+; CHECK: extractelement <3 x half> %b, i32 2
+; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half4
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: extractelement <4 x half> %a, i32 0
+; CHECK: extractelement <4 x half> %a, i32 1
+; CHECK: extractelement <4 x half> %a, i32 2
+; CHECK: extractelement <4 x half> %a, i32 3
+; CHECK: extractelement <4 x half> %b, i32 0
+; CHECK: extractelement <4 x half> %b, i32 1
+; CHECK: extractelement <4 x half> %b, i32 2
+; CHECK: extractelement <4 x half> %b, i32 3
+; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_float2
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: extractelement <2 x float> %a, i32 0
+; CHECK: extractelement <2 x float> %a, i32 1
+; CHECK: extractelement <2 x float> %b, i32 0
+; CHECK: extractelement <2 x float> %b, i32 1
+; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+  ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float3
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: extractelement <3 x float> %a, i32 0
+; CHECK: extractelement <3 x float> %a, i32 1
+; CHECK: extractelement <3 x float> %a, i32 2
+; CHECK: extractelement <3 x float> %b, i32 0
+; CHECK: extractelement <3 x float> %b, i32 1
+; CHECK: extractelement <3 x float> %b, i32 2
+; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+  ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float4
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: extractelement <4 x float> %a, i32 0
+; CHECK: extractelement <4 x float> %a, i32 1
+; CHECK: extractelement <4 x float> %a, i32 2
+; CHECK: extractelement <4 x float> %a, i32 3
+; CHECK: extractelement <4 x float> %b, i32 0
+; CHECK: extractelement <4 x fl...
[truncated]

@farzonl farzonl changed the title Add Float Dot Intrinsic Lowering [DXIL] Add Float Dot Intrinsic Lowering Mar 21, 2024
@farzonl farzonl linked an issue Mar 21, 2024 that may be closed by this pull request
Completes llvm#83626
- `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit
  `dot2`, `dot3`, and `dot4` intrinsics based on element count
- `IntrinsicsDirectX.td` - for floating point add `dot2`,`dot3`, and `dot4`
  inntrinsics
-`DXIL.td` add dxilop intrinsic lowering for  `dot2`,`dot3`, & `dot4`.
-`DXILOpLowering.cpp` - add vector arg flattening for dot product.
-`DXILOpBuilder.h` - modify `createDXILOpCall` to take a smallVector
instead of an iterator
- `DXILOpBuilder.cpp` - modify createDXILOpCall by moving the small
  vector up to the callee function  in `DXILOpLowering.cpp`. Moving one
  function up gives us access to the callInst and Function Which were
  needed to distinguish the dot product intrinsics and get the operands
  without using the iterator.
Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link

✅ With the latest revision this PR passed the Python code formatter.

@farzonl farzonl merged commit 060df78 into llvm:main Mar 25, 2024
3 of 4 checks passed
@farzonl farzonl deleted the dxil-dot-float-lowering branch March 27, 2024 01:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX clang:codegen clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[DXIL] implement dot intrinsic lowering
4 participants