-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DXIL] Add Float Dot
Intrinsic Lowering
#86071
Conversation
@llvm/pr-subscribers-backend-directx @llvm/pr-subscribers-llvm-ir Author: Farzon Lotfi (farzonl) ChangesCompletes #83626
Patch is 20.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86071.diff 11 Files Affected:
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 77cb269d43c5a8..a4b99181769326 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18036,15 +18036,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}
-Intrinsic::ID getDotProductIntrinsic(QualType QT) {
+Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
+ if (QT->hasFloatingRepresentation()) {
+ switch (elementCount) {
+ case 2:
+ return Intrinsic::dx_dot2;
+ case 3:
+ return Intrinsic::dx_dot3;
+ case 4:
+ return Intrinsic::dx_dot4;
+ }
+ }
if (QT->hasSignedIntegerRepresentation())
return Intrinsic::dx_sdot;
- if (QT->hasUnsignedIntegerRepresentation())
- return Intrinsic::dx_udot;
- assert(QT->hasFloatingRepresentation());
- return Intrinsic::dx_dot;
- ;
+ assert(QT->hasUnsignedIntegerRepresentation());
+ return Intrinsic::dx_udot;
}
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18098,8 +18105,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");
- [[maybe_unused]] auto *VecTy0 =
- E->getArg(0)->getType()->getAs<VectorType>();
+ auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happend
@@ -18108,7 +18114,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
- getDotProductIntrinsic(E->getArg(0)->getType()),
+ getDotProductIntrinsic(E->getArg(0)->getType(),
+ VecTy0->getNumElements()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 0f993193c00cce..307d71cce3cb6d 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -110,21 +110,21 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
// NO_HALF: ret float %dx.dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
-// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
// NATIVE_HALF: ret half %dx.dot
-// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
@@ -132,34 +132,34 @@ half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
// CHECK: ret float %dx.dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
+// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat)
// CHECK: ret float %dx.dot
float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
return dot(p0, p1);
@@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
+// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat)
// CHECK: ret float %dx.dot
float test_builtin_dot_float3_int_splat(float3 p0, int p1) {
return dot(p0, p1);
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1164b241ba7b0d..a871fac46b9fd5 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -24,7 +24,15 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def int_dx_dot :
+def int_dx_dot2 :
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot3 :
+ Intrinsic<[LLVMVectorElementType<0>],
+ [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
+ [IntrNoMem, IntrWillReturn, Commutative] >;
+def int_dx_dot4 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 36eb29d53766f0..9e393902e2f208 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -295,6 +295,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1",
+ [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2",
+ [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
+def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
+ "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3",
+ [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>;
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
"Reads the thread ID">;
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a1eacc2d48009c..990557710f8c53 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -254,7 +254,7 @@ namespace dxil {
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
- llvm::iterator_range<Use *> Args) {
+ SmallVector<Value *> Args) {
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
OverloadKind Kind = getOverloadKind(OverloadTy);
@@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
}
- SmallVector<Value *> FullArgs;
- FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
- FullArgs.append(Args.begin(), Args.end());
- return B.CreateCall(DXILFn, FullArgs);
+
+ return B.CreateCall(DXILFn, Args);
}
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index f3abcc6e02a4e3..5babeae470178b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -13,7 +13,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
#include "DXILConstants.h"
-#include "llvm/ADT/iterator_range.h"
+#include "llvm/ADT/SmallVector.h"
namespace llvm {
class Module;
@@ -35,8 +35,7 @@ class DXILOpBuilder {
/// \param OverloadTy Overload type of the DXIL Op call constructed
/// \return DXIL Op call constructed
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
- Type *OverloadTy,
- llvm::iterator_range<Use *> Args);
+ Type *OverloadTy, SmallVector<Value *> Args);
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
static const char *getOpCodeName(dxil::OpCode DXILOp);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3e334b0ec298d3..f09e322f88e1fd 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -30,6 +30,48 @@
using namespace llvm;
using namespace llvm::dxil;
+static bool isVectorArgExpansion(Function &F) {
+ switch (F.getIntrinsicID()) {
+ case Intrinsic::dx_dot2:
+ case Intrinsic::dx_dot3:
+ case Intrinsic::dx_dot4:
+ return true;
+ }
+ return false;
+}
+
+static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
+ SmallVector<Value *, 4> ExtractedElements;
+ auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
+ Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
+ Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
+ ExtractedElements.push_back(ExtractedElement);
+ }
+ return ExtractedElements;
+}
+
+static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
+ IRBuilder<> &Builder) {
+ // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
+ unsigned NumOperands = Orig->getNumOperands() - 1;
+ assert(NumOperands > 0);
+ Value *Arg0 = Orig->getOperand(0);
+ [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
+ assert(VecArg0);
+ SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
+ for (unsigned I = 1; I < NumOperands; ++I) {
+ Value *Arg = Orig->getOperand(I);
+ [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
+ assert(VecArg);
+ assert(VecArg0->getElementType() == VecArg->getElementType());
+ assert(VecArg0->getNumElements() == VecArg->getNumElements());
+ auto NextOperandList = populateOperands(Arg, Builder);
+ NewOperands.append(NextOperandList.begin(), NextOperandList.end());
+ }
+ return NewOperands;
+}
+
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
if (!CI)
continue;
+ SmallVector<Value *> Args;
+ Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+ Args.emplace_back(DXILOpArg);
B.SetInsertPoint(CI);
- CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
- OverloadTy, CI->args());
+ if (isVectorArgExpansion(F)) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ CallInst *DXILCI =
+ DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
new file mode 100644
index 00000000000000..a27bfaedacd573
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot2 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
+entry:
+ %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
+ ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll
new file mode 100644
index 00000000000000..eb69fb145038aa
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot3_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot3 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) {
+entry:
+ %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b)
+ ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll
new file mode 100644
index 00000000000000..5cd632684c0c01
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/dot4_error.ll
@@ -0,0 +1,10 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation dot4 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) {
+entry:
+ %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b)
+ ret double %dx.dot
+}
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
new file mode 100644
index 00000000000000..3e13b2ad2650c8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -0,0 +1,94 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for dot are generated for int/uint vectors.
+
+; CHECK-LABEL: dot_half2
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: extractelement <2 x half> %a, i32 0
+; CHECK: extractelement <2 x half> %a, i32 1
+; CHECK: extractelement <2 x half> %b, i32 0
+; CHECK: extractelement <2 x half> %b, i32 1
+; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+ ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half3
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: extractelement <3 x half> %a, i32 0
+; CHECK: extractelement <3 x half> %a, i32 1
+; CHECK: extractelement <3 x half> %a, i32 2
+; CHECK: extractelement <3 x half> %b, i32 0
+; CHECK: extractelement <3 x half> %b, i32 1
+; CHECK: extractelement <3 x half> %b, i32 2
+; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+ ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half4
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: extractelement <4 x half> %a, i32 0
+; CHECK: extractelement <4 x half> %a, i32 1
+; CHECK: extractelement <4 x half> %a, i32 2
+; CHECK: extractelement <4 x half> %a, i32 3
+; CHECK: extractelement <4 x half> %b, i32 0
+; CHECK: extractelement <4 x half> %b, i32 1
+; CHECK: extractelement <4 x half> %b, i32 2
+; CHECK: extractelement <4 x half> %b, i32 3
+; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+ ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_float2
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: extractelement <2 x float> %a, i32 0
+; CHECK: extractelement <2 x float> %a, i32 1
+; CHECK: extractelement <2 x float> %b, i32 0
+; CHECK: extractelement <2 x float> %b, i32 1
+; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+ ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float3
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: extractelement <3 x float> %a, i32 0
+; CHECK: extractelement <3 x float> %a, i32 1
+; CHECK: extractelement <3 x float> %a, i32 2
+; CHECK: extractelement <3 x float> %b, i32 0
+; CHECK: extractelement <3 x float> %b, i32 1
+; CHECK: extractelement <3 x float> %b, i32 2
+; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+ %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+ ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float4
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: extractelement <4 x float> %a, i32 0
+; CHECK: extractelement <4 x float> %a, i32 1
+; CHECK: extractelement <4 x float> %a, i32 2
+; CHECK: extractelement <4 x float> %a, i32 3
+; CHECK: extractelement <4 x float> %b, i32 0
+; CHECK: extractelement <4 x fl...
[truncated]
|
Dot
Intrinsic LoweringDot
Intrinsic Lowering
Completes llvm#83626 - `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit `dot2`, `dot3`, and `dot4` intrinsics based on element count - `IntrinsicsDirectX.td` - for floating point add `dot2`,`dot3`, and `dot4` inntrinsics -`DXIL.td` add dxilop intrinsic lowering for `dot2`,`dot3`, & `dot4`. -`DXILOpLowering.cpp` - add vector arg flattening for dot product. -`DXILOpBuilder.h` - modify `createDXILOpCall` to take a smallVector instead of an iterator - `DXILOpBuilder.cpp` - modify createDXILOpCall by moving the small vector up to the callee function in `DXILOpLowering.cpp`. Moving one function up gives us access to the callInst and Function Which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.
7e794b6
to
98fa816
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
✅ With the latest revision this PR passed the Python code formatter. |
Completes #83626
CGBuiltin.cpp
- modifygetDotProductIntrinsic
to be able to emitdot2
,dot3
, anddot4
intrinsics based on element countIntrinsicsDirectX.td
- for floating point adddot2
,dot3
, anddot4
inntrinsics -DXIL.td
add dxilop intrinsic lowering fordot2
,dot3
, &dot4
.DXILOpLowering.cpp
- add vector arg flattening for dot product.DXILOpBuilder.h
- modifycreateDXILOpCall
to take a smallVector instead of an iteratorDXILOpBuilder.cpp
- modifycreateDXILOpCall
by moving the small vector up to the calling function inDXILOpLowering.cpp
.CallInst
andFunction
which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.