46 changes: 46 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpBuilder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===- DXILOpBuilder.h - Helper class for build DIXLOp functions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains class to help build DXIL op functions.
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H

#include "DXILConstants.h"
#include "llvm/ADT/iterator_range.h"

namespace llvm {
class Module;
class IRBuilderBase;
class CallInst;
class Value;
class Type;
class FunctionType;
class Use;

namespace DXIL {

class DXILOpBuilder {
public:
DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
CallInst *createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy,
llvm::iterator_range<Use *> Args);
Type *getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT,
bool NoOpCodeParam);
static const char *getOpCodeName(DXIL::OpCode DXILOp);

private:
Module &M;
IRBuilderBase &B;
};

} // namespace DXIL
} // namespace llvm

#endif
167 changes: 6 additions & 161 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "DXILConstants.h"
#include "DXILOpBuilder.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/Passes.h"
Expand All @@ -28,168 +29,12 @@
using namespace llvm;
using namespace llvm::DXIL;

constexpr StringLiteral DXILOpNamePrefix = "dx.op.";

enum OverloadKind : uint16_t {
VOID = 1,
HALF = 1 << 1,
FLOAT = 1 << 2,
DOUBLE = 1 << 3,
I1 = 1 << 4,
I8 = 1 << 5,
I16 = 1 << 6,
I32 = 1 << 7,
I64 = 1 << 8,
UserDefineType = 1 << 9,
ObjectType = 1 << 10,
};

static const char *getOverloadTypeName(OverloadKind Kind) {
switch (Kind) {
case OverloadKind::HALF:
return "f16";
case OverloadKind::FLOAT:
return "f32";
case OverloadKind::DOUBLE:
return "f64";
case OverloadKind::I1:
return "i1";
case OverloadKind::I8:
return "i8";
case OverloadKind::I16:
return "i16";
case OverloadKind::I32:
return "i32";
case OverloadKind::I64:
return "i64";
case OverloadKind::VOID:
case OverloadKind::ObjectType:
case OverloadKind::UserDefineType:
break;
}
llvm_unreachable("invalid overload type for name");
return "void";
}

static OverloadKind getOverloadKind(Type *Ty) {
Type::TypeID T = Ty->getTypeID();
switch (T) {
case Type::VoidTyID:
return OverloadKind::VOID;
case Type::HalfTyID:
return OverloadKind::HALF;
case Type::FloatTyID:
return OverloadKind::FLOAT;
case Type::DoubleTyID:
return OverloadKind::DOUBLE;
case Type::IntegerTyID: {
IntegerType *ITy = cast<IntegerType>(Ty);
unsigned Bits = ITy->getBitWidth();
switch (Bits) {
case 1:
return OverloadKind::I1;
case 8:
return OverloadKind::I8;
case 16:
return OverloadKind::I16;
case 32:
return OverloadKind::I32;
case 64:
return OverloadKind::I64;
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
}
}
case Type::PointerTyID:
return OverloadKind::UserDefineType;
case Type::StructTyID:
return OverloadKind::ObjectType;
default:
llvm_unreachable("invalid overload type");
return OverloadKind::VOID;
}
}

static std::string getTypeName(OverloadKind Kind, Type *Ty) {
if (Kind < OverloadKind::UserDefineType) {
return getOverloadTypeName(Kind);
} else if (Kind == OverloadKind::UserDefineType) {
StructType *ST = cast<StructType>(Ty);
return ST->getStructName().str();
} else if (Kind == OverloadKind::ObjectType) {
StructType *ST = cast<StructType>(Ty);
return ST->getStructName().str();
} else {
std::string Str;
raw_string_ostream OS(Str);
Ty->print(OS);
return OS.str();
}
}

// Static properties.
struct OpCodeProperty {
DXIL::OpCode OpCode;
// Offset in DXILOpCodeNameTable.
unsigned OpCodeNameOffset;
DXIL::OpCodeClass OpCodeClass;
// Offset in DXILOpCodeClassNameTable.
unsigned OpCodeClassNameOffset;
uint16_t OverloadTys;
llvm::Attribute::AttrKind FuncAttr;
};

// Include getOpCodeClassName getOpCodeProperty and getOpCodeName which
// generated by tableGen.
#define DXIL_OP_OPERATION_TABLE
#include "DXILOperation.inc"
#undef DXIL_OP_OPERATION_TABLE

static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
const OpCodeProperty &Prop) {
if (Kind == OverloadKind::VOID) {
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
}
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
getTypeName(Kind, Ty))
.str();
}

static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
Module &M) {
const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);

// Get return type as overload type for DXILOp.
// Only simple mapping case here, so return type is good enough.
Type *OverloadTy = F.getReturnType();

OverloadKind Kind = getOverloadKind(OverloadTy);
// FIXME: find the issue and report error in clang instead of check it in
// backend.
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
llvm_unreachable("invalid overload");
}

std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
assert(!M.getFunction(FnName) && "Function already exists");

auto &Ctx = M.getContext();
Type *OpCodeTy = Type::getInt32Ty(Ctx);

SmallVector<Type *> ArgTypes;
// DXIL has i32 opcode as first arg.
ArgTypes.emplace_back(OpCodeTy);
FunctionType *FT = F.getFunctionType();
ArgTypes.append(FT->param_begin(), FT->param_end());
FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
return M.getOrInsertFunction(FnName, DXILOpFT);
}

static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
IRBuilder<> B(M.getContext());
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
DXILOpBuilder DXILB(M, B);
Type *OverloadTy =
DXILB.getOverloadTy(DXILOp, F.getFunctionType(), /*NoOpCodeParam*/ true);
for (User *U : make_early_inc_range(F.users())) {
CallInst *CI = dyn_cast<CallInst>(U);
if (!CI)
Expand All @@ -199,8 +44,8 @@ static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
Args.emplace_back(DXILOpArg);
Args.append(CI->arg_begin(), CI->arg_end());
B.SetInsertPoint(CI);
CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp)));
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());

CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();
}
Expand Down
106 changes: 88 additions & 18 deletions llvm/utils/TableGen/DXILEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,29 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/DXILOperationCommon.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace llvm;
using namespace llvm::DXIL;

namespace {

struct DXILShaderModel {
int Major;
int Minor;
};

struct DXILParam {
int Pos; // position in parameter list
StringRef Type; // llvm type name, $o for overload, $r for resource
// type, $cb for legacy cbuffer, $u4 for u4 struct
int Pos; // position in parameter list
ParameterKind Kind;
StringRef Name; // short, unique name
StringRef Doc; // the documentation description of this parameter
bool IsConst; // whether this argument requires a constant value in the IR
StringRef EnumName; // the name of the enum type if applicable
int MaxValue; // the maximum value for this parameter if applicable
DXILParam(const Record *R) {
Name = R->getValueAsString("name");
Pos = R->getValueAsInt("pos");
Type = R->getValueAsString("llvm_type");
if (R->getValue("doc"))
Doc = R->getValueAsString("doc");
IsConst = R->getValueAsBit("is_const");
EnumName = R->getValueAsString("enum_name");
MaxValue = R->getValueAsInt("max_value");
}
DXILParam(const Record *R);
};

struct DXILOperationData {
Expand Down Expand Up @@ -74,7 +67,9 @@ struct DXILOperationData {
DXILShaderModel ShaderModel; // minimum shader model required
DXILShaderModel ShaderModelTranslated; // minimum shader model required with
// translation by linker
SmallVector<StringRef, 4> counters; // counters for this inst.
int OverloadParamIndex; // parameter index which control the overload.
// When < 0, should be only 1 overload type.
SmallVector<StringRef, 4> counters; // counters for this inst.
DXILOperationData(const Record *R) {
Name = R->getValueAsString("name");
DXILOp = R->getValueAsString("dxil_op");
Expand All @@ -93,16 +88,64 @@ struct DXILOperationData {
Doc = R->getValueAsString("doc");

ListInit *ParamList = R->getValueAsListInit("ops");
for (unsigned i = 0; i < ParamList->size(); ++i) {
Record *Param = ParamList->getElementAsRecord(i);
OverloadParamIndex = -1;
for (unsigned I = 0; I < ParamList->size(); ++I) {
Record *Param = ParamList->getElementAsRecord(I);
Params.emplace_back(DXILParam(Param));
auto &CurParam = Params.back();
if (CurParam.Kind >= ParameterKind::OVERLOAD)
OverloadParamIndex = I;
}
OverloadTypes = R->getValueAsString("oload_types");
FnAttr = R->getValueAsString("fn_attr");
}
};
} // end anonymous namespace

DXILParam::DXILParam(const Record *R) {
Name = R->getValueAsString("name");
Pos = R->getValueAsInt("pos");
Kind = parameterTypeNameToKind(R->getValueAsString("llvm_type"));
if (R->getValue("doc"))
Doc = R->getValueAsString("doc");
IsConst = R->getValueAsBit("is_const");
EnumName = R->getValueAsString("enum_name");
MaxValue = R->getValueAsInt("max_value");
}

static std::string parameterKindToString(ParameterKind Kind) {
switch (Kind) {
case ParameterKind::INVALID:
return "INVALID";
case ParameterKind::VOID:
return "VOID";
case ParameterKind::HALF:
return "HALF";
case ParameterKind::FLOAT:
return "FLOAT";
case ParameterKind::DOUBLE:
return "DOUBLE";
case ParameterKind::I1:
return "I1";
case ParameterKind::I8:
return "I8";
case ParameterKind::I16:
return "I16";
case ParameterKind::I32:
return "I32";
case ParameterKind::I64:
return "I64";
case ParameterKind::OVERLOAD:
return "OVERLOAD";
case ParameterKind::CBUFFER_RET:
return "CBUFFER_RET";
case ParameterKind::RESOURCE_RET:
return "RESOURCE_RET";
case ParameterKind::DXIL_HANDLE:
return "DXIL_HANDLE";
}
}

static void emitDXILOpEnum(DXILOperationData &DXILOp, raw_ostream &OS) {
// Name = ID, // Doc
OS << DXILOp.Name << " = " << DXILOp.DXILOpID << ", // " << DXILOp.Doc
Expand Down Expand Up @@ -271,7 +314,9 @@ static void emitDXILOperationTable(std::vector<DXILOperationData> &DXILOps,
// Collect Names.
SequenceToOffsetTable<std::string> OpClassStrings;
SequenceToOffsetTable<std::string> OpStrings;
SequenceToOffsetTable<SmallVector<ParameterKind>> Parameters;

StringMap<SmallVector<ParameterKind>> ParameterMap;
StringSet<> ClassSet;
for (auto &DXILOp : DXILOps) {
OpStrings.add(DXILOp.DXILOp.str());
Expand All @@ -280,16 +325,24 @@ static void emitDXILOperationTable(std::vector<DXILOperationData> &DXILOps,
continue;
ClassSet.insert(DXILOp.DXILClass);
OpClassStrings.add(getDXILOpClassName(DXILOp.DXILClass));
SmallVector<ParameterKind> ParamKindVec;
for (auto &Param : DXILOp.Params) {
ParamKindVec.emplace_back(Param.Kind);
}
ParameterMap[DXILOp.DXILClass] = ParamKindVec;
Parameters.add(ParamKindVec);
}

// Layout names.
OpStrings.layout();
OpClassStrings.layout();
Parameters.layout();

// Emit the DXIL operation table.
//{DXIL::OpCode::Sin, OpCodeNameIndex, OpCodeClass::Unary,
// OpCodeClassNameIndex,
// OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
// OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
// 3, ParameterTableOffset},
OS << "static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) "
"{\n";

Expand All @@ -300,7 +353,9 @@ static void emitDXILOperationTable(std::vector<DXILOperationData> &DXILOps,
<< ", OpCodeClass::" << DXILOp.DXILClass << ", "
<< OpClassStrings.get(getDXILOpClassName(DXILOp.DXILClass)) << ", "
<< getDXILOperationOverload(DXILOp.OverloadTypes) << ", "
<< emitDXILOperationFnAttr(DXILOp.FnAttr) << " },\n";
<< emitDXILOperationFnAttr(DXILOp.FnAttr) << ", "
<< DXILOp.OverloadParamIndex << ", " << DXILOp.Params.size() << ", "
<< Parameters.get(ParameterMap[DXILOp.DXILClass]) << " },\n";
}
OS << " };\n";

Expand Down Expand Up @@ -338,6 +393,21 @@ static void emitDXILOperationTable(std::vector<DXILOperationData> &DXILOps,
OS << " unsigned Index = Prop.OpCodeClassNameOffset;\n";
OS << " return DXILOpCodeClassNameTable + Index;\n";
OS << "}\n ";

OS << "static const ParameterKind *getOpCodeParameterKind(const "
"OpCodeProperty &Prop) "
"{\n\n";
OS << " static const ParameterKind DXILOpParameterKindTable[] = {\n";
Parameters.emit(
OS,
[](raw_ostream &ParamOS, ParameterKind Kind) {
ParamOS << "ParameterKind::" << parameterKindToString(Kind);
},
"ParameterKind::INVALID");
OS << " };\n\n";
OS << " unsigned Index = Prop.ParameterTableOffset;\n";
OS << " return DXILOpParameterKindTable + Index;\n";
OS << "}\n ";
}

namespace llvm {
Expand Down