diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 216fa5b10c8f4..36eb29d53766f 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -255,6 +255,9 @@ class DXILOpMapping; def Sin : DXILOpMapping<13, unary, int_sin, "Returns sine(theta) for theta in radians.", [llvm_halforfloat_ty, LLVMMatchType<0>]>; diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 11b24d0449236..a1eacc2d48009 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -229,13 +229,13 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { /// its specification in DXIL.td. /// \param OverloadTy Return type to be used to construct DXIL function type. static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, - Type *OverloadTy) { + Type *ReturnTy, Type *OverloadTy) { SmallVector ArgTys; auto ParamKinds = getOpCodeParameterKind(*Prop); - // Add OverloadTy as return type of the function - ArgTys.emplace_back(OverloadTy); + // Add ReturnTy as return type of the function + ArgTys.emplace_back(ReturnTy); // Add DXIL Opcode value type viz., Int32 as first argument ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext())); @@ -249,34 +249,33 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, ArgTys[0], ArrayRef(&ArgTys[1], ArgTys.size() - 1), false); } -static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp, - Type *OverloadTy, Module &M) { - const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); +namespace llvm { +namespace dxil { + +CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, + Type *OverloadTy, + llvm::iterator_range Args) { + const OpCodeProperty *Prop = getOpCodeProperty(OpCode); OverloadKind Kind = getOverloadKind(OverloadTy); if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false); } - std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); - // Dependent on name to dedup. - if (auto *Fn = M.getFunction(FnName)) - return FunctionCallee(Fn); - - FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy); - return M.getOrInsertFunction(FnName, DXILOpFT); -} - -namespace llvm { -namespace dxil { - -CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy, - llvm::iterator_range Args) { - auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M); + std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop); + FunctionCallee DXILFn; + // Get the function with name DXILFnName, if one exists + if (auto *Func = M.getFunction(DXILFnName)) { + DXILFn = FunctionCallee(Func); + } else { + // Construct and add a function with name DXILFnName + FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); + DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); + } SmallVector FullArgs; FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); FullArgs.append(Args.begin(), Args.end()); - return B.CreateCall(Fn, FullArgs); + return B.CreateCall(DXILFn, FullArgs); } Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index 1c15f109184ad..f3abcc6e02a4e 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -29,7 +29,13 @@ namespace dxil { class DXILOpBuilder { public: DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {} - CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy, + /// Create an instruction that calls DXIL Op with return type, specified + /// opcode, and call arguments. \param OpCode Opcode of the DXIL Op call + /// constructed \param ReturnTy Return type of the DXIL Op call constructed + /// \param OverloadTy Overload type of the DXIL Op call constructed + /// \return DXIL Op call constructed + CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, + Type *OverloadTy, llvm::iterator_range Args); Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT); static const char *getOpCodeName(dxil::OpCode DXILOp); diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index e5c2042e7d16a..3e334b0ec298d 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -32,7 +32,6 @@ using namespace llvm::dxil; static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { IRBuilder<> B(M.getContext()); - Value *DXILOpArg = B.getInt32(static_cast(DXILOp)); DXILOpBuilder DXILB(M, B); Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType()); for (User *U : make_early_inc_range(F.users())) { @@ -40,11 +39,9 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { if (!CI) continue; - SmallVector Args; - Args.emplace_back(DXILOpArg); - Args.append(CI->arg_begin(), CI->arg_end()); B.SetInsertPoint(CI); - CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args()); + CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(), + OverloadTy, CI->args()); CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); diff --git a/llvm/test/CodeGen/DirectX/isinf.ll b/llvm/test/CodeGen/DirectX/isinf.ll new file mode 100644 index 0000000000000..e2975da90bfc1 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/isinf.ll @@ -0,0 +1,25 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Make sure dxil operation function calls for isinf are generated for float and half. +; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}}) +; CHECK: call i1 @dx.op.isSpecialFloat.f16(i32 9, half %{{.*}}) + +; Function Attrs: noinline nounwind optnone +define noundef i1 @isinf_float(float noundef %a) #0 { +entry: + %a.addr = alloca float, align 4 + store float %a, ptr %a.addr, align 4 + %0 = load float, ptr %a.addr, align 4 + %dx.isinf = call i1 @llvm.dx.isinf.f32(float %0) + ret i1 %dx.isinf +} + +; Function Attrs: noinline nounwind optnone +define noundef i1 @isinf_half(half noundef %p0) #0 { +entry: + %p0.addr = alloca half, align 2 + store half %p0, ptr %p0.addr, align 2 + %0 = load half, ptr %p0.addr, align 2 + %dx.isinf = call i1 @llvm.dx.isinf.f16(half %0) + ret i1 %dx.isinf +} diff --git a/llvm/test/CodeGen/DirectX/isinf_error.ll b/llvm/test/CodeGen/DirectX/isinf_error.ll new file mode 100644 index 0000000000000..95b2d0cabcc43 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/isinf_error.ll @@ -0,0 +1,13 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation isinf does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload Type + +define noundef i1 @isinf_double(double noundef %a) #0 { +entry: + %a.addr = alloca double, align 8 + store double %a, ptr %a.addr, align 8 + %0 = load double, ptr %a.addr, align 8 + %dx.isinf = call i1 @llvm.dx.isinf.f64(double %0) + ret i1 %dx.isinf +} diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp index 59089929837eb..af1efb8aa99f7 100644 --- a/llvm/utils/TableGen/DXILEmitter.cpp +++ b/llvm/utils/TableGen/DXILEmitter.cpp @@ -119,7 +119,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { // Populate OpTypes with return type and parameter types // Parameter indices of overloaded parameters. - // This vector contains overload parameters in the order order used to + // This vector contains overload parameters in the order used to // resolve an LLVMMatchType in accordance with convention outlined in // the comment before the definition of class LLVMMatchType in // llvm/IR/Intrinsics.td @@ -398,10 +398,20 @@ static void emitDXILOperationTable(std::vector &Ops, OS << " static const OpCodeProperty OpCodeProps[] = {\n"; for (auto &Op : Ops) { + // Consider Op.OverloadParamIndex as the overload parameter index, by + // default + auto OLParamIdx = Op.OverloadParamIndex; + // If no overload parameter index is set, treat first parameter type as + // overload type - unless the Op has no parameters, in which case treat the + // return type - as overload parameter to emit the appropriate overload kind + // enum. + if (OLParamIdx < 0) { + OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0; + } OS << " { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName) << ", OpCodeClass::" << Op.OpClass << ", " << OpClassStrings.get(Op.OpClass.data()) << ", " - << getOverloadKindStr(Op.OpTypes[0]) << ", " + << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", " << emitDXILOperationAttr(Op.OpAttributes) << ", " << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", " << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";