diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index e965df810add5..e89691ab7921c 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18036,6 +18036,17 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments, return Arg; } +Intrinsic::ID getDotProductIntrinsic(QualType QT) { + if (QT->hasSignedIntegerRepresentation()) + return Intrinsic::dx_sdot; + if (QT->hasUnsignedIntegerRepresentation()) + return Intrinsic::dx_udot; + + assert(QT->hasFloatingRepresentation()); + return Intrinsic::dx_dot; + ; +} + Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E) { if (!getLangOpts().HLSL) @@ -18096,7 +18107,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, "Dot product requires vectors to be of the same size."); return Builder.CreateIntrinsic( - /*ReturnType=*/T0->getScalarType(), Intrinsic::dx_dot, + /*ReturnType=*/T0->getScalarType(), + getDotProductIntrinsic(E->getArg(0)->getType()), ArrayRef{Op0, Op1}, nullptr, "dx.dot"); } break; case Builtin::BI__builtin_hlsl_lerp: { diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index a0b256ab5579e..f9112a29027ac 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -5484,6 +5484,18 @@ bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) { checkFloatorHalf); } +bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) { + auto checkDoubleVector = [](clang::QualType PassedType) -> bool { + if (const auto *VecTy = dyn_cast(PassedType)) { + clang::QualType BaseType = VecTy->getElementType(); + return !BaseType->isHalfType() && !BaseType->isFloat32Type(); + } + return false; + }; + return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy, + checkDoubleVector); +} + void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType) { auto *VecTyA = TheCall->getArg(0)->getType()->getAs(); @@ -5520,6 +5532,8 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; if (SemaBuiltinVectorToScalarMath(TheCall)) return true; + if (CheckNoDoubleVectors(this, TheCall)) + return true; break; } case Builtin::BI__builtin_hlsl_elementwise_rcp: { diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl index c064d118caf3e..0f993193c00cc 100644 --- a/clang/test/CodeGenHLSL/builtins/dot.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -11,15 +11,15 @@ // NATIVE_HALF: ret i16 %dx.dot int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1) // NATIVE_HALF: ret i16 %dx.dot int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1) // NATIVE_HALF: ret i16 %dx.dot int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1) // NATIVE_HALF: ret i16 %dx.dot int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); } @@ -27,15 +27,15 @@ int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); } // NATIVE_HALF: ret i16 %dx.dot uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1) // NATIVE_HALF: ret i16 %dx.dot uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1) // NATIVE_HALF: ret i16 %dx.dot uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1) // NATIVE_HALF: ret i16 %dx.dot uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); } #endif @@ -44,15 +44,15 @@ uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); } // CHECK: ret i32 %dx.dot int test_dot_int(int p0, int p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1) // CHECK: ret i32 %dx.dot int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1) // CHECK: ret i32 %dx.dot int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1) // CHECK: ret i32 %dx.dot int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); } @@ -60,15 +60,15 @@ int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); } // CHECK: ret i32 %dx.dot uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1) // CHECK: ret i32 %dx.dot uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1) // CHECK: ret i32 %dx.dot uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1) // CHECK: ret i32 %dx.dot uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); } @@ -76,15 +76,15 @@ uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); } // CHECK: ret i64 %dx.dot int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1) // CHECK: ret i64 %dx.dot int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1) // CHECK: ret i64 %dx.dot int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1) // CHECK: ret i64 %dx.dot int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); } @@ -92,15 +92,15 @@ int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); } // CHECK: ret i64 %dx.dot uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1) // CHECK: ret i64 %dx.dot uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1) // CHECK: ret i64 %dx.dot uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1) // CHECK: ret i64 %dx.dot uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); } diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl index 59eb9482b9ef9..ba7ffc20484ae 100644 --- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl @@ -108,3 +108,12 @@ int test_builtin_dot_bool_type_promotion(bool p0, bool p1) { return __builtin_hlsl_dot(p0, p1); // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}} } + +double test_dot_double(double2 p0, double2 p1) { + return dot(p0, p1); + // expected-error@-1 {{call to 'dot' is ambiguous}} +} +double test_dot_double_builtin(double2 p0, double2 p1) { + return __builtin_hlsl_dot(p0, p1); + // expected-error@-1 {{passing 'double2' (aka 'vector') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 5c72f06f96ed1..1164b241ba7b0 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -23,9 +23,18 @@ def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">, def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; + def int_dx_dot : Intrinsic<[LLVMVectorElementType<0>], - [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; +def int_dx_sdot : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; +def int_dx_udot : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index bc38c10a1fceb..0db42bc0a0fb6 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -39,11 +39,44 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::dx_uclamp: case Intrinsic::dx_lerp: case Intrinsic::dx_rcp: + case Intrinsic::dx_sdot: + case Intrinsic::dx_udot: return true; } return false; } +static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { + assert(DotIntrinsic == Intrinsic::dx_sdot || + DotIntrinsic == Intrinsic::dx_udot); + Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot + ? Intrinsic::dx_imad + : Intrinsic::dx_umad; + Value *A = Orig->getOperand(0); + Value *B = Orig->getOperand(1); + Type *ATy = A->getType(); + Type *BTy = B->getType(); + assert(ATy->isVectorTy() && BTy->isVectorTy()); + + IRBuilder<> Builder(Orig->getParent()); + Builder.SetInsertPoint(Orig); + + auto *AVec = dyn_cast(A->getType()); + Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0); + Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0); + Value *Result = Builder.CreateMul(Elt0, Elt1); + for (unsigned I = 1; I < AVec->getNumElements(); I++) { + Elt0 = Builder.CreateExtractElement(A, I); + Elt1 = Builder.CreateExtractElement(B, I); + Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic, + ArrayRef{Elt0, Elt1, Result}, + nullptr, "dx.mad"); + } + Orig->replaceAllUsesWith(Result); + Orig->eraseFromParent(); + return true; +} + static bool expandExpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); IRBuilder<> Builder(Orig->getParent()); @@ -191,6 +224,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { return expandLerpIntrinsic(Orig); case Intrinsic::dx_rcp: return expandRcpIntrinsic(Orig); + case Intrinsic::dx_sdot: + case Intrinsic::dx_udot: + return expandIntegerDot(Orig, F.getIntrinsicID()); } return false; } diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll new file mode 100644 index 0000000000000..9f89a8d6d340d --- /dev/null +++ b/llvm/test/CodeGen/DirectX/idot.ll @@ -0,0 +1,100 @@ +; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK + +; Make sure dxil operation function calls for dot are generated for int/uint vectors. + +; CHECK-LABEL: dot_int16_t2 +define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) { +entry: +; CHECK: extractelement <2 x i16> %a, i64 0 +; CHECK: extractelement <2 x i16> %b, i64 0 +; CHECK: mul i16 %{{.*}}, %{{.*}} +; CHECK: extractelement <2 x i16> %a, i64 1 +; CHECK: extractelement <2 x i16> %b, i64 1 +; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) +; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) + %dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) + ret i16 %dx.dot +} + +; CHECK-LABEL: sdot_int4 +define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { +entry: +; CHECK: extractelement <4 x i32> %a, i64 0 +; CHECK: extractelement <4 x i32> %b, i64 0 +; CHECK: mul i32 %{{.*}}, %{{.*}} +; CHECK: extractelement <4 x i32> %a, i64 1 +; CHECK: extractelement <4 x i32> %b, i64 1 +; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; CHECK: extractelement <4 x i32> %a, i64 2 +; CHECK: extractelement <4 x i32> %b, i64 2 +; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; CHECK: extractelement <4 x i32> %a, i64 3 +; CHECK: extractelement <4 x i32> %b, i64 3 +; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dx.dot +} + +; CHECK-LABEL: dot_uint16_t3 +define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) { +entry: +; CHECK: extractelement <3 x i16> %a, i64 0 +; CHECK: extractelement <3 x i16> %b, i64 0 +; CHECK: mul i16 %{{.*}}, %{{.*}} +; CHECK: extractelement <3 x i16> %a, i64 1 +; CHECK: extractelement <3 x i16> %b, i64 1 +; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) +; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) +; CHECK: extractelement <3 x i16> %a, i64 2 +; CHECK: extractelement <3 x i16> %b, i64 2 +; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) +; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) + %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b) + ret i16 %dx.dot +} + +; CHECK-LABEL: dot_uint4 +define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) { +entry: +; CHECK: extractelement <4 x i32> %a, i64 0 +; CHECK: extractelement <4 x i32> %b, i64 0 +; CHECK: mul i32 %{{.*}}, %{{.*}} +; CHECK: extractelement <4 x i32> %a, i64 1 +; CHECK: extractelement <4 x i32> %b, i64 1 +; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; CHECK: extractelement <4 x i32> %a, i64 2 +; CHECK: extractelement <4 x i32> %b, i64 2 +; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; CHECK: extractelement <4 x i32> %a, i64 3 +; CHECK: extractelement <4 x i32> %b, i64 3 +; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) +; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dx.dot +} + +; CHECK-LABEL: dot_uint64_t4 +define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) { +entry: +; CHECK: extractelement <2 x i64> %a, i64 0 +; CHECK: extractelement <2 x i64> %b, i64 0 +; CHECK: mul i64 %{{.*}}, %{{.*}} +; CHECK: extractelement <2 x i64> %a, i64 1 +; CHECK: extractelement <2 x i64> %b, i64 1 +; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}}) +; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}}) + %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b) + ret i64 %dx.dot +} + +declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>) +declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>) +declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i16>) +declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>) +declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)