Skip to content

Commit

Permalink
[DirectX][DXIL] Distinguish return type for overload type resolution. (
Browse files Browse the repository at this point in the history
…#85646)

Return type of DXIL Ops may be different from valid overload type of the
parameters, if any. Such DXIL Ops are correctly represented in DXIL.td.
However, DXILEmitter assumes the return type to be the same as parameter
overload type, if one exists. This results in generation in incorrect
overload index value in DXILOperation.inc for the DXIL Op and incorrect
DXIL operation function call in DXILOpLowering pass.

This change distinguishes return types correctly from parameter overload
types in DXILEmitter backend to handle such DXIL ops.

Add specification for DXIL Op `isinf` and corresponding tests to verify
the above change.

Fixes issue #85125
  • Loading branch information
bharadwajy committed Mar 20, 2024
1 parent 891172d commit 3f39571
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 30 deletions.
3 changes: 3 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ class DXILOpMapping<int opCode, DXILOpClass opClass,
}

// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
"Determines if the specified value is infinite.",
[llvm_i1_ty, llvm_halforfloat_ty]>;
def Sin : DXILOpMapping<13, unary, int_sin,
"Returns sine(theta) for theta in radians.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
Expand Down
43 changes: 21 additions & 22 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,13 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
/// its specification in DXIL.td.
/// \param OverloadTy Return type to be used to construct DXIL function type.
static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
Type *OverloadTy) {
Type *ReturnTy, Type *OverloadTy) {
SmallVector<Type *> ArgTys;

auto ParamKinds = getOpCodeParameterKind(*Prop);

// Add OverloadTy as return type of the function
ArgTys.emplace_back(OverloadTy);
// Add ReturnTy as return type of the function
ArgTys.emplace_back(ReturnTy);

// Add DXIL Opcode value type viz., Int32 as first argument
ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
Expand All @@ -249,34 +249,33 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
}

static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
Type *OverloadTy, Module &M) {
const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
namespace llvm {
namespace dxil {

CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
llvm::iterator_range<Use *> Args) {
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);

OverloadKind Kind = getOverloadKind(OverloadTy);
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
}

std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
// Dependent on name to dedup.
if (auto *Fn = M.getFunction(FnName))
return FunctionCallee(Fn);

FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
return M.getOrInsertFunction(FnName, DXILOpFT);
}

namespace llvm {
namespace dxil {

CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
llvm::iterator_range<Use *> Args) {
auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
FunctionCallee DXILFn;
// Get the function with name DXILFnName, if one exists
if (auto *Func = M.getFunction(DXILFnName)) {
DXILFn = FunctionCallee(Func);
} else {
// Construct and add a function with name DXILFnName
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
}
SmallVector<Value *> FullArgs;
FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
FullArgs.append(Args.begin(), Args.end());
return B.CreateCall(Fn, FullArgs);
return B.CreateCall(DXILFn, FullArgs);
}

Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
Expand Down
8 changes: 7 additions & 1 deletion llvm/lib/Target/DirectX/DXILOpBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ namespace dxil {
class DXILOpBuilder {
public:
DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
/// Create an instruction that calls DXIL Op with return type, specified
/// opcode, and call arguments. \param OpCode Opcode of the DXIL Op call
/// constructed \param ReturnTy Return type of the DXIL Op call constructed
/// \param OverloadTy Overload type of the DXIL Op call constructed
/// \return DXIL Op call constructed
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
llvm::iterator_range<Use *> Args);
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
static const char *getOpCodeName(dxil::OpCode DXILOp);
Expand Down
7 changes: 2 additions & 5 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,16 @@ using namespace llvm::dxil;

static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
DXILOpBuilder DXILB(M, B);
Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
for (User *U : make_early_inc_range(F.users())) {
CallInst *CI = dyn_cast<CallInst>(U);
if (!CI)
continue;

SmallVector<Value *> Args;
Args.emplace_back(DXILOpArg);
Args.append(CI->arg_begin(), CI->arg_end());
B.SetInsertPoint(CI);
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
OverloadTy, CI->args());

CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();
Expand Down
25 changes: 25 additions & 0 deletions llvm/test/CodeGen/DirectX/isinf.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s

; Make sure dxil operation function calls for isinf are generated for float and half.
; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
; CHECK: call i1 @dx.op.isSpecialFloat.f16(i32 9, half %{{.*}})

; Function Attrs: noinline nounwind optnone
define noundef i1 @isinf_float(float noundef %a) #0 {
entry:
%a.addr = alloca float, align 4
store float %a, ptr %a.addr, align 4
%0 = load float, ptr %a.addr, align 4
%dx.isinf = call i1 @llvm.dx.isinf.f32(float %0)
ret i1 %dx.isinf
}

; Function Attrs: noinline nounwind optnone
define noundef i1 @isinf_half(half noundef %p0) #0 {
entry:
%p0.addr = alloca half, align 2
store half %p0, ptr %p0.addr, align 2
%0 = load half, ptr %p0.addr, align 2
%dx.isinf = call i1 @llvm.dx.isinf.f16(half %0)
ret i1 %dx.isinf
}
13 changes: 13 additions & 0 deletions llvm/test/CodeGen/DirectX/isinf_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s

; DXIL operation isinf does not support double overload type
; CHECK: LLVM ERROR: Invalid Overload Type

define noundef i1 @isinf_double(double noundef %a) #0 {
entry:
%a.addr = alloca double, align 8
store double %a, ptr %a.addr, align 8
%0 = load double, ptr %a.addr, align 8
%dx.isinf = call i1 @llvm.dx.isinf.f64(double %0)
ret i1 %dx.isinf
}
14 changes: 12 additions & 2 deletions llvm/utils/TableGen/DXILEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
// Populate OpTypes with return type and parameter types

// Parameter indices of overloaded parameters.
// This vector contains overload parameters in the order order used to
// This vector contains overload parameters in the order used to
// resolve an LLVMMatchType in accordance with convention outlined in
// the comment before the definition of class LLVMMatchType in
// llvm/IR/Intrinsics.td
Expand Down Expand Up @@ -398,10 +398,20 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,

OS << " static const OpCodeProperty OpCodeProps[] = {\n";
for (auto &Op : Ops) {
// Consider Op.OverloadParamIndex as the overload parameter index, by
// default
auto OLParamIdx = Op.OverloadParamIndex;
// If no overload parameter index is set, treat first parameter type as
// overload type - unless the Op has no parameters, in which case treat the
// return type - as overload parameter to emit the appropriate overload kind
// enum.
if (OLParamIdx < 0) {
OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
}
OS << " { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
<< ", OpCodeClass::" << Op.OpClass << ", "
<< OpClassStrings.get(Op.OpClass.data()) << ", "
<< getOverloadKindStr(Op.OpTypes[0]) << ", "
<< getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
<< emitDXILOperationAttr(Op.OpAttributes) << ", "
<< Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
<< Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
Expand Down

0 comments on commit 3f39571

Please sign in to comment.