diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 2eaceeba61770..8f4817258e3b1 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18066,15 +18066,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments, return Arg; } -Intrinsic::ID getDotProductIntrinsic(QualType QT) { +Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) { + if (QT->hasFloatingRepresentation()) { + switch (elementCount) { + case 2: + return Intrinsic::dx_dot2; + case 3: + return Intrinsic::dx_dot3; + case 4: + return Intrinsic::dx_dot4; + } + } if (QT->hasSignedIntegerRepresentation()) return Intrinsic::dx_sdot; - if (QT->hasUnsignedIntegerRepresentation()) - return Intrinsic::dx_udot; - assert(QT->hasFloatingRepresentation()); - return Intrinsic::dx_dot; - ; + assert(QT->hasUnsignedIntegerRepresentation()); + return Intrinsic::dx_udot; } Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, @@ -18128,8 +18135,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, assert(T0->getScalarType() == T1->getScalarType() && "Dot product of vectors need the same element types."); - [[maybe_unused]] auto *VecTy0 = - E->getArg(0)->getType()->getAs(); + auto *VecTy0 = E->getArg(0)->getType()->getAs(); [[maybe_unused]] auto *VecTy1 = E->getArg(1)->getType()->getAs(); // A HLSLVectorTruncation should have happend @@ -18138,7 +18144,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, return Builder.CreateIntrinsic( /*ReturnType=*/T0->getScalarType(), - getDotProductIntrinsic(E->getArg(0)->getType()), + getDotProductIntrinsic(E->getArg(0)->getType(), + VecTy0->getNumElements()), ArrayRef{Op0, Op1}, nullptr, "dx.dot"); } break; case Builtin::BI__builtin_hlsl_lerp: { diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl index 0f993193c00cc..307d71cce3cb6 100644 --- a/clang/test/CodeGenHLSL/builtins/dot.hlsl +++ 
b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -110,21 +110,21 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); } // NO_HALF: ret float %dx.dot half test_dot_half(half p0, half p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1) +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1) // NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1) // NO_HALF: ret float %dx.dot half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1) +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1) // NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1) // NO_HALF: ret float %dx.dot half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1) +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1) // NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1) // NO_HALF: ret float %dx.dot half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); } @@ -132,34 +132,34 @@ half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); } // CHECK: ret float %dx.dot float test_dot_float(float p0, float p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1) // 
CHECK: ret float %dx.dot float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1) // CHECK: ret float %dx.dot float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1) // CHECK: ret float %dx.dot float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1) // CHECK: ret float %dx.dot float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1) // CHECK: ret float %dx.dot float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1) // CHECK: ret float %dx.dot float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); } // CHECK: %conv = sitofp i32 %1 to float // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer -// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat) +// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat) // CHECK: ret float %dx.dot float test_builtin_dot_float2_int_splat(float2 p0, int p1) 
{ return dot(p0, p1); @@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) { // CHECK: %conv = sitofp i32 %1 to float // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer -// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat) +// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat) // CHECK: ret float %dx.dot float test_builtin_dot_float3_int_splat(float3 p0, int p1) { return dot(p0, p1); diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 1164b241ba7b0..a871fac46b9fd 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -24,7 +24,15 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; -def int_dx_dot : +def int_dx_dot2 : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; +def int_dx_dot3 : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; +def int_dx_dot4 : Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index f7e69ebae15b6..2e6d58e14fd32 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -303,6 +303,15 @@ def IMad : 
DXILOpMapping<48, tertiary, int_dx_imad, "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">; def UMad : DXILOpMapping<49, tertiary, int_dx_umad, "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in + def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in + def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in + def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">; def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id, "Reads the thread ID">; def GroupId : DXILOpMapping<94, groupId, int_dx_group_id, diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 0841ae95423c7..0b3982ea0f438 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -254,7 +254,7 @@ namespace dxil { CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, Type *OverloadTy, - llvm::iterator_range Args) { + SmallVector Args) { const OpCodeProperty *Prop = getOpCodeProperty(OpCode); OverloadKind Kind = getOverloadKind(OverloadTy); @@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); } - SmallVector FullArgs; - FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); - FullArgs.append(Args.begin(), 
Args.end()); - return B.CreateCall(DXILFn, FullArgs); + + return B.CreateCall(DXILFn, Args); } Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index f3abcc6e02a4e..5babeae470178 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -13,7 +13,7 @@ #define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H #include "DXILConstants.h" -#include "llvm/ADT/iterator_range.h" +#include "llvm/ADT/SmallVector.h" namespace llvm { class Module; @@ -35,8 +35,7 @@ class DXILOpBuilder { /// \param OverloadTy Overload type of the DXIL Op call constructed /// \return DXIL Op call constructed CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, - Type *OverloadTy, - llvm::iterator_range Args); + Type *OverloadTy, SmallVector Args); Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT); static const char *getOpCodeName(dxil::OpCode DXILOp); diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 3e334b0ec298d..f09e322f88e1f 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -30,6 +30,48 @@ using namespace llvm; using namespace llvm::dxil; +static bool isVectorArgExpansion(Function &F) { + switch (F.getIntrinsicID()) { + case Intrinsic::dx_dot2: + case Intrinsic::dx_dot3: + case Intrinsic::dx_dot4: + return true; + } + return false; +} + +static SmallVector populateOperands(Value *Arg, IRBuilder<> &Builder) { + SmallVector ExtractedElements; + auto *VecArg = dyn_cast(Arg->getType()); + for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { + Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); + Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); + ExtractedElements.push_back(ExtractedElement); + } + return ExtractedElements; +} + +static SmallVector 
argVectorFlatten(CallInst *Orig, + IRBuilder<> &Builder) { + // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. + unsigned NumOperands = Orig->getNumOperands() - 1; + assert(NumOperands > 0); + Value *Arg0 = Orig->getOperand(0); + [[maybe_unused]] auto *VecArg0 = dyn_cast(Arg0->getType()); + assert(VecArg0); + SmallVector NewOperands = populateOperands(Arg0, Builder); + for (unsigned I = 1; I < NumOperands; ++I) { + Value *Arg = Orig->getOperand(I); + [[maybe_unused]] auto *VecArg = dyn_cast(Arg->getType()); + assert(VecArg); + assert(VecArg0->getElementType() == VecArg->getElementType()); + assert(VecArg0->getNumElements() == VecArg->getNumElements()); + auto NextOperandList = populateOperands(Arg, Builder); + NewOperands.append(NextOperandList.begin(), NextOperandList.end()); + } + return NewOperands; +} + static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { IRBuilder<> B(M.getContext()); DXILOpBuilder DXILB(M, B); @@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { if (!CI) continue; + SmallVector Args; + Value *DXILOpArg = B.getInt32(static_cast(DXILOp)); + Args.emplace_back(DXILOpArg); B.SetInsertPoint(CI); - CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(), - OverloadTy, CI->args()); + if (isVectorArgExpansion(F)) { + SmallVector NewArgs = argVectorFlatten(CI, B); + Args.append(NewArgs.begin(), NewArgs.end()); + } else + Args.append(CI->arg_begin(), CI->arg_end()); + + CallInst *DXILCI = + DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args); CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll new file mode 100644 index 0000000000000..a27bfaedacd57 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot2_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation dot2 does not support double overload type 
+; CHECK: LLVM ERROR: Invalid Overload + +define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) { +entry: + %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b) + ret double %dx.dot +} diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll new file mode 100644 index 0000000000000..eb69fb145038a --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot3_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation dot3 does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload + +define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) { +entry: + %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b) + ret double %dx.dot +} diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll new file mode 100644 index 0000000000000..5cd632684c0c0 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot4_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation dot4 does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload + +define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) { +entry: + %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b) + ret double %dx.dot +} diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll new file mode 100644 index 0000000000000..3e13b2ad2650c --- /dev/null +++ b/llvm/test/CodeGen/DirectX/fdot.ll @@ -0,0 +1,94 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Make sure dxil operation function calls for dot are generated for half/float vectors.
+
+; CHECK-LABEL: dot_half2
+define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
+entry:
+; CHECK: extractelement <2 x half> %a, i32 0
+; CHECK: extractelement <2 x half> %a, i32 1
+; CHECK: extractelement <2 x half> %b, i32 0
+; CHECK: extractelement <2 x half> %b, i32 1
+; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half3
+define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
+entry:
+; CHECK: extractelement <3 x half> %a, i32 0
+; CHECK: extractelement <3 x half> %a, i32 1
+; CHECK: extractelement <3 x half> %a, i32 2
+; CHECK: extractelement <3 x half> %b, i32 0
+; CHECK: extractelement <3 x half> %b, i32 1
+; CHECK: extractelement <3 x half> %b, i32 2
+; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_half4
+define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+; CHECK: extractelement <4 x half> %a, i32 0
+; CHECK: extractelement <4 x half> %a, i32 1
+; CHECK: extractelement <4 x half> %a, i32 2
+; CHECK: extractelement <4 x half> %a, i32 3
+; CHECK: extractelement <4 x half> %b, i32 0
+; CHECK: extractelement <4 x half> %b, i32 1
+; CHECK: extractelement <4 x half> %b, i32 2
+; CHECK: extractelement <4 x half> %b, i32 3
+; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
+  ret half %dx.dot
+}
+
+; CHECK-LABEL: dot_float2
+define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
+entry:
+; CHECK: extractelement <2 x float> %a, i32 0
+; CHECK: extractelement <2 x float> %a, i32 1
+; CHECK: extractelement <2 x float> %b, i32 0
+; CHECK: extractelement <2 x float> %b, i32 1
+; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
+  ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float3
+define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
+entry:
+; CHECK: extractelement <3 x float> %a, i32 0
+; CHECK: extractelement <3 x float> %a, i32 1
+; CHECK: extractelement <3 x float> %a, i32 2
+; CHECK: extractelement <3 x float> %b, i32 0
+; CHECK: extractelement <3 x float> %b, i32 1
+; CHECK: extractelement <3 x float> %b, i32 2
+; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
+  ret float %dx.dot
+}
+
+; CHECK-LABEL: dot_float4
+define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+; CHECK: extractelement <4 x float> %a, i32 0
+; CHECK: extractelement <4 x float> %a, i32 1
+; CHECK: extractelement <4 x float> %a, i32 2
+; CHECK: extractelement <4 x float> %a, i32 3
+; CHECK: extractelement <4 x float> %b, i32 0
+; CHECK: extractelement <4 x float> %b, i32 1
+; CHECK: extractelement <4 x float> %b, i32 2
+; CHECK: extractelement <4 x float> %b, i32 3
+; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
+  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
+  ret float %dx.dot
+}
+
+declare half @llvm.dx.dot2.v2f16(<2 x half>, <2 x half>)
+declare half @llvm.dx.dot3.v3f16(<3 x half>, <3 x half>)
+declare half @llvm.dx.dot4.v4f16(<4 x half>, <4 x half>)
+declare float @llvm.dx.dot2.v2f32(<2 x float>, <2 x float>)
+declare float @llvm.dx.dot3.v3f32(<3 x float>, <3 x float>)
+declare float @llvm.dx.dot4.v4f32(<4 x float>, <4 x float>)