Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 164 additions & 7 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,36 @@ CGHLSLRuntime::emitSPIRVUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
VariableName.str());
}

static void createSPIRVLocationStore(IRBuilder<> &B, llvm::Module &M,
llvm::Value *Source, unsigned Location,
StringRef Name) {
auto *GV = new llvm::GlobalVariable(
M, Source->getType(), /* isConstant= */ false,
llvm::GlobalValue::ExternalLinkage,
/* Initializer= */ nullptr, /* Name= */ Name, /* insertBefore= */ nullptr,
llvm::GlobalVariable::GeneralDynamicTLSModel,
/* AddressSpace */ 8, /* isExternallyInitialized= */ false);
GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
addLocationDecoration(GV, Location);
B.CreateStore(Source, GV);
}

void CGHLSLRuntime::emitSPIRVUserSemanticStore(
llvm::IRBuilder<> &B, llvm::Value *Source,
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
Twine BaseName = Twine(Semantic->getAttrName()->getName());
Twine VariableName = BaseName.concat(Twine(Index.value_or(0)));
unsigned Location = SPIRVLastAssignedOutputSemanticLocation;

// DXC completely ignores the semantic/index pair. Location are assigned from
// the first semantic to the last.
llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Source->getType());
unsigned ElementCount = AT ? AT->getNumElements() : 1;
SPIRVLastAssignedOutputSemanticLocation += ElementCount;
createSPIRVLocationStore(B, CGM.getModule(), Source, Location,
VariableName.str());
}

llvm::Value *
CGHLSLRuntime::emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
HLSLAppliedSemanticAttr *Semantic,
Expand All @@ -609,6 +639,23 @@ CGHLSLRuntime::emitDXILUserSemanticLoad(llvm::IRBuilder<> &B, llvm::Type *Type,
return Value;
}

void CGHLSLRuntime::emitDXILUserSemanticStore(llvm::IRBuilder<> &B,
llvm::Value *Source,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index) {
// DXIL packing rules etc shall be handled here.
// FIXME: generate proper sigpoint, index, col, row values.
SmallVector<Value *> Args{B.getInt32(4),
B.getInt32(0),
B.getInt32(0),
B.getInt8(0),
llvm::PoisonValue::get(B.getInt32Ty()),
Source};

llvm::Intrinsic::ID IntrinsicID = llvm::Intrinsic::dx_store_output;
B.CreateIntrinsic(/*ReturnType=*/CGM.VoidTy, IntrinsicID, Args, nullptr);
}

llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
Expand All @@ -621,6 +668,19 @@ llvm::Value *CGHLSLRuntime::emitUserSemanticLoad(
llvm_unreachable("Unsupported target for user-semantic load.");
}

void CGHLSLRuntime::emitUserSemanticStore(IRBuilder<> &B, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index) {
if (CGM.getTarget().getTriple().isSPIRV())
return emitSPIRVUserSemanticStore(B, Source, Semantic, Index);

if (CGM.getTarget().getTriple().isDXIL())
return emitDXILUserSemanticStore(B, Source, Semantic, Index);

llvm_unreachable("Unsupported target for user-semantic load.");
}

llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
IRBuilder<> &B, llvm::Type *Type, const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic, std::optional<unsigned> Index) {
Expand Down Expand Up @@ -669,6 +729,34 @@ llvm::Value *CGHLSLRuntime::emitSystemSemanticLoad(
llvm_unreachable("non-handled system semantic. FIXME.");
}

static void createSPIRVBuiltinStore(IRBuilder<> &B, llvm::Module &M,
llvm::Value *Source, const Twine &Name,
unsigned BuiltInID) {
auto *GV = new llvm::GlobalVariable(
M, Source->getType(), /* isConstant= */ false,
llvm::GlobalValue::ExternalLinkage,
/* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
llvm::GlobalVariable::GeneralDynamicTLSModel,
/* AddressSpace */ 8, /* isExternallyInitialized= */ false);
addSPIRVBuiltinDecoration(GV, BuiltInID);
GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
B.CreateStore(Source, GV);
}

void CGHLSLRuntime::emitSystemSemanticStore(IRBuilder<> &B, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index) {

std::string SemanticName = Semantic->getAttrName()->getName().upper();
if (SemanticName == "SV_POSITION")
createSPIRVBuiltinStore(B, CGM.getModule(), Source,
Semantic->getAttrName()->getName(),
/* BuiltIn::Position */ 0);
else
llvm_unreachable("non-handled system semantic. FIXME.");
}

llvm::Value *CGHLSLRuntime::handleScalarSemanticLoad(
IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
Expand All @@ -679,6 +767,16 @@ llvm::Value *CGHLSLRuntime::handleScalarSemanticLoad(
return emitUserSemanticLoad(B, Type, Decl, Semantic, Index);
}

void CGHLSLRuntime::handleScalarSemanticStore(
IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
const clang::DeclaratorDecl *Decl, HLSLAppliedSemanticAttr *Semantic) {
std::optional<unsigned> Index = Semantic->getSemanticIndex();
if (Semantic->getAttrName()->getName().starts_with_insensitive("SV_"))
emitSystemSemanticStore(B, Source, Decl, Semantic, Index);
else
emitUserSemanticStore(B, Source, Decl, Semantic, Index);
}

std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
CGHLSLRuntime::handleStructSemanticLoad(
IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
Expand All @@ -705,6 +803,35 @@ CGHLSLRuntime::handleStructSemanticLoad(
return std::make_pair(Aggregate, AttrBegin);
}

specific_attr_iterator<HLSLAppliedSemanticAttr>
CGHLSLRuntime::handleStructSemanticStore(
IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {

const llvm::StructType *ST = cast<StructType>(Source->getType());

const clang::RecordDecl *RD = nullptr;
if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(Decl))
RD = FD->getDeclaredReturnType()->getAsRecordDecl();
else
RD = Decl->getType()->getAsRecordDecl();
assert(RD);

assert(std::distance(RD->field_begin(), RD->field_end()) ==
ST->getNumElements());

auto FieldDecl = RD->field_begin();
for (unsigned I = 0; I < ST->getNumElements(); ++I) {
llvm::Value *Extract = B.CreateExtractValue(Source, I);
AttrBegin =
handleSemanticStore(B, FD, Extract, *FieldDecl, AttrBegin, AttrEnd);
}

return AttrBegin;
}

std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
CGHLSLRuntime::handleSemanticLoad(
IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
Expand All @@ -721,6 +848,22 @@ CGHLSLRuntime::handleSemanticLoad(
AttrBegin);
}

specific_attr_iterator<HLSLAppliedSemanticAttr>
CGHLSLRuntime::handleSemanticStore(
IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd) {
assert(AttrBegin != AttrEnd);
if (Source->getType()->isStructTy())
return handleStructSemanticStore(B, FD, Source, Decl, AttrBegin, AttrEnd);

HLSLAppliedSemanticAttr *Attr = *AttrBegin;
++AttrBegin;
handleScalarSemanticStore(B, FD, Source, Decl, Attr);
return AttrBegin;
}

void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
llvm::Function *Fn) {
llvm::Module &M = CGM.getModule();
Expand Down Expand Up @@ -752,20 +895,22 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
OB.emplace_back("convergencectrl", bundleArgs);
}

// FIXME: support struct parameters where semantics are on members.
// See: https://github.com/llvm/llvm-project/issues/57874
std::unordered_map<const DeclaratorDecl *, llvm::Value *> OutputSemantic;

unsigned SRetOffset = 0;
for (const auto &Param : Fn->args()) {
if (Param.hasStructRetAttr()) {
// FIXME: support output.
// See: https://github.com/llvm/llvm-project/issues/57874
SRetOffset = 1;
Args.emplace_back(PoisonValue::get(Param.getType()));
llvm::Type *VarType = Param.getParamStructRetType();
llvm::Value *Var = B.CreateAlloca(VarType);
OutputSemantic.emplace(FD, Var);
Args.push_back(Var);
continue;
}

const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
llvm::Value *SemanticValue = nullptr;
// FIXME: support inout/out parameters for semantics.
if ([[maybe_unused]] HLSLParamModifierAttr *MA =
PD->getAttr<HLSLParamModifierAttr>()) {
llvm_unreachable("Not handled yet");
Expand All @@ -792,8 +937,20 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,

CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
CI->setCallingConv(Fn->getCallingConv());
// FIXME: Handle codegen for return type semantics.
// See: https://github.com/llvm/llvm-project/issues/57875

if (Fn->getReturnType() != CGM.VoidTy)
OutputSemantic.emplace(FD, CI);

for (auto &[Decl, Source] : OutputSemantic) {
AllocaInst *AI = dyn_cast<AllocaInst>(Source);
llvm::Value *SourceValue =
AI ? B.CreateLoad(AI->getAllocatedType(), Source) : Source;

auto AttrBegin = Decl->specific_attr_begin<HLSLAppliedSemanticAttr>();
auto AttrEnd = Decl->specific_attr_end<HLSLAppliedSemanticAttr>();
handleSemanticStore(B, FD, SourceValue, Decl, AttrBegin, AttrEnd);
}

B.CreateRetVoid();

// Add and identify root signature to function, if applicable
Expand Down
34 changes: 34 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,25 +176,47 @@ class CGHLSLRuntime {
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index);

void emitSystemSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index);

llvm::Value *handleScalarSemanticLoad(llvm::IRBuilder<> &B,
const FunctionDecl *FD,
llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic);

void handleScalarSemanticStore(llvm::IRBuilder<> &B, const FunctionDecl *FD,
llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic);

std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
handleStructSemanticLoad(
llvm::IRBuilder<> &B, const FunctionDecl *FD, llvm::Type *Type,
const clang::DeclaratorDecl *Decl,
specific_attr_iterator<HLSLAppliedSemanticAttr> begin,
specific_attr_iterator<HLSLAppliedSemanticAttr> end);

specific_attr_iterator<HLSLAppliedSemanticAttr> handleStructSemanticStore(
llvm::IRBuilder<> &B, const FunctionDecl *FD, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd);

std::pair<llvm::Value *, specific_attr_iterator<HLSLAppliedSemanticAttr>>
handleSemanticLoad(llvm::IRBuilder<> &B, const FunctionDecl *FD,
llvm::Type *Type, const clang::DeclaratorDecl *Decl,
specific_attr_iterator<HLSLAppliedSemanticAttr> begin,
specific_attr_iterator<HLSLAppliedSemanticAttr> end);

specific_attr_iterator<HLSLAppliedSemanticAttr>
handleSemanticStore(llvm::IRBuilder<> &B, const FunctionDecl *FD,
llvm::Value *Source, const clang::DeclaratorDecl *Decl,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrBegin,
specific_attr_iterator<HLSLAppliedSemanticAttr> AttrEnd);

public:
CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
virtual ~CGHLSLRuntime() {}
Expand Down Expand Up @@ -249,10 +271,22 @@ class CGHLSLRuntime {
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index);

void emitSPIRVUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index);
void emitDXILUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index);
void emitUserSemanticStore(llvm::IRBuilder<> &B, llvm::Value *Source,
const clang::DeclaratorDecl *Decl,
HLSLAppliedSemanticAttr *Semantic,
std::optional<unsigned> Index);

llvm::Triple::ArchType getArch();

llvm::DenseMap<const clang::RecordType *, llvm::TargetExtType *> LayoutTypes;
unsigned SPIRVLastAssignedInputSemanticLocation = 0;
unsigned SPIRVLastAssignedOutputSemanticLocation = 0;
};

} // namespace CodeGen
Expand Down
31 changes: 21 additions & 10 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {

bool SemaHLSL::determineActiveSemanticOnScalar(
FunctionDecl *FD, DeclaratorDecl *OutputDecl, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic, llvm::StringSet<> &ActiveInputSemantics) {
SemanticInfo &ActiveSemantic, llvm::StringSet<> &UsedSemantics) {
if (ActiveSemantic.Semantic == nullptr) {
ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
if (ActiveSemantic.Semantic)
Expand Down Expand Up @@ -805,7 +805,7 @@ bool SemaHLSL::determineActiveSemanticOnScalar(
for (unsigned I = 0; I < ElementCount; ++I) {
Twine VariableName = BaseName.concat(Twine(Location + I));

auto [_, Inserted] = ActiveInputSemantics.insert(VariableName.str());
auto [_, Inserted] = UsedSemantics.insert(VariableName.str());
if (!Inserted) {
Diag(D->getLocation(), diag::err_hlsl_semantic_index_overlap)
<< VariableName.str();
Expand All @@ -816,26 +816,29 @@ bool SemaHLSL::determineActiveSemanticOnScalar(
return true;
}

bool SemaHLSL::determineActiveSemantic(
FunctionDecl *FD, DeclaratorDecl *OutputDecl, DeclaratorDecl *D,
SemanticInfo &ActiveSemantic, llvm::StringSet<> &ActiveInputSemantics) {
bool SemaHLSL::determineActiveSemantic(FunctionDecl *FD,
DeclaratorDecl *OutputDecl,
DeclaratorDecl *D,
SemanticInfo &ActiveSemantic,
llvm::StringSet<> &UsedSemantics) {
if (ActiveSemantic.Semantic == nullptr) {
ActiveSemantic.Semantic = D->getAttr<HLSLParsedSemanticAttr>();
if (ActiveSemantic.Semantic)
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
}

const Type *T = D->getType()->getUnqualifiedDesugaredType();
const Type *T = D == FD ? &*FD->getReturnType() : &*D->getType();
T = T->getUnqualifiedDesugaredType();

const RecordType *RT = dyn_cast<RecordType>(T);
if (!RT)
return determineActiveSemanticOnScalar(FD, OutputDecl, D, ActiveSemantic,
ActiveInputSemantics);
UsedSemantics);

const RecordDecl *RD = RT->getDecl();
for (FieldDecl *Field : RD->fields()) {
SemanticInfo Info = ActiveSemantic;
if (!determineActiveSemantic(FD, OutputDecl, Field, Info,
ActiveInputSemantics)) {
if (!determineActiveSemantic(FD, OutputDecl, Field, Info, UsedSemantics)) {
Diag(Field->getLocation(), diag::note_hlsl_semantic_used_here) << Field;
return false;
}
Expand Down Expand Up @@ -915,13 +918,21 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
if (ActiveSemantic.Semantic)
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();

// FIXME: Verify output semantics in parameters.
if (!determineActiveSemantic(FD, Param, Param, ActiveSemantic,
ActiveInputSemantics)) {
Diag(Param->getLocation(), diag::note_previous_decl) << Param;
FD->setInvalidDecl();
}
}
// FIXME: Verify return type semantic annotation.

SemanticInfo ActiveSemantic;
llvm::StringSet<> ActiveOutputSemantics;
ActiveSemantic.Semantic = FD->getAttr<HLSLParsedSemanticAttr>();
if (ActiveSemantic.Semantic)
ActiveSemantic.Index = ActiveSemantic.Semantic->getSemanticIndex();
if (!FD->getReturnType()->isVoidType())
determineActiveSemantic(FD, FD, FD, ActiveSemantic, ActiveOutputSemantics);
}

void SemaHLSL::checkSemanticAnnotation(
Expand Down
Loading
Loading