Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ struct EntryProperties {
unsigned NumThreadsX{0}; // X component
unsigned NumThreadsY{0}; // Y component
unsigned NumThreadsZ{0}; // Z component
unsigned WaveSizeMin{0}; // Minimum component
unsigned WaveSizeMax{0}; // Maximum component
unsigned WaveSizePref{0}; // Preferred component

EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
};
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Analysis/DXILMetadataAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
Success = llvm::to_integer(NumThreadsVec[2], EFP.NumThreadsZ, 10);
assert(Success && "Failed to parse Z component of numthreads");
}
// Get wavesize attribute value, if one exists
StringRef WaveSizeStr =
F.getFnAttribute("hlsl.wavesize").getValueAsString();
if (!WaveSizeStr.empty()) {
SmallVector<StringRef> WaveSizeVec;
WaveSizeStr.split(WaveSizeVec, ',');
assert(WaveSizeVec.size() == 3 && "Invalid wavesize specified");
// Read in the three component values of numthreads
[[maybe_unused]] bool Success =
llvm::to_integer(WaveSizeVec[0], EFP.WaveSizeMin, 10);
assert(Success && "Failed to parse Min component of wavesize");
Success = llvm::to_integer(WaveSizeVec[1], EFP.WaveSizeMax, 10);
assert(Success && "Failed to parse Max component of wavesize");
Success = llvm::to_integer(WaveSizeVec[2], EFP.WaveSizePref, 10);
assert(Success && "Failed to parse Preferred component of wavesize");
}
MMDAI.EntryPropertyVec.push_back(EFP);
}
return MMDAI;
Expand Down
66 changes: 53 additions & 13 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ enum class EntryPropsTag {
ASStateTag,
WaveSize,
EntryRootSig,
WaveRange = 23,
};

} // namespace
Expand Down Expand Up @@ -177,14 +178,15 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
case EntryPropsTag::ASStateTag:
case EntryPropsTag::WaveSize:
case EntryPropsTag::EntryRootSig:
case EntryPropsTag::WaveRange:
llvm_unreachable("NYI: Unhandled entry property tag");
}
return MDVals;
}

static MDTuple *
getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
const Triple::EnvironmentType ShaderProfile) {
static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
uint64_t EntryShaderFlags,
const ModuleMetadataInfo &MMDI) {
SmallVector<Metadata *> MDVals;
LLVMContext &Ctx = EP.Entry->getContext();
if (EntryShaderFlags != 0)
Expand All @@ -195,12 +197,13 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
// FIXME: support more props.
// See https://github.com/llvm/llvm-project/issues/57948.
// Add shader kind for lib entries.
if (ShaderProfile == Triple::EnvironmentType::Library &&
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
EP.ShaderStage != Triple::EnvironmentType::Library)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
getShaderStage(EP.ShaderStage), Ctx));

if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
// Handle mandatory "hlsl.numthreads"
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
Expand All @@ -210,8 +213,48 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));

// Handle optional "hlsl.wavesize". The fields are optionally represented
// if they are non-zero.
if (EP.WaveSizeMin != 0) {
bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
bool IsWaveSize =
!IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;

if (!IsWaveRange && !IsWaveSize) {
reportError(M, "Shader model 6.6 or greater is required to specify "
"the \"hlsl.wavesize\" function attribute");
return nullptr;
}

// A range is being specified if EP.WaveSizeMax != 0
if (EP.WaveSizeMax && !IsWaveRange) {
reportError(
M, "Shader model 6.8 or greater is required to specify "
"wave size range values of the \"hlsl.wavesize\" function "
"attribute");
return nullptr;
}

EntryPropsTag Tag =
IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
MDVals.emplace_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));

SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
if (IsWaveRange) {
WaveSizeVals.push_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
WaveSizeVals.push_back(ConstantAsMetadata::get(
ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
}

MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
}
}
}

if (MDVals.empty())
return nullptr;
return MDNode::get(Ctx, MDVals);
Expand All @@ -236,12 +279,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
return MDNode::get(Ctx, MDVals);
}

static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
MDNode *MDResources,
static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
MDTuple *Signatures, MDNode *MDResources,
const uint64_t EntryShaderFlags,
const Triple::EnvironmentType ShaderProfile) {
MDTuple *Properties =
getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
const ModuleMetadataInfo &MMDI) {
MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
EP.Entry->getContext());
}
Expand Down Expand Up @@ -523,10 +565,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}

EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
EntryShaderFlags,
MMDI.ShaderProfile));
EntryFnMDNodes.emplace_back(emitEntryMD(
M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
}

NamedMDNode *EntryPointsNamedMD =
Expand Down
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/DirectX/wavesize-md-errs.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
; RUN: split-file %s %t
; RUN: not opt -S --dxil-translate-metadata %t/low-sm.ll 2>&1 | FileCheck %t/low-sm.ll
; RUN: not opt -S --dxil-translate-metadata %t/low-sm-for-range.ll 2>&1 | FileCheck %t/low-sm-for-range.ll

; Test that wavesize metadata is only allowed on applicable shader model versions

;--- low-sm.ll

; CHECK: Shader model 6.6 or greater is required to specify the "hlsl.wavesize" function attribute

target triple = "dxil-unknown-shadermodel6.5-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- low-sm-for-range.ll

; CHECK: Shader model 6.8 or greater is required to specify wave size range values of the "hlsl.wavesize" function attribute

target triple = "dxil-unknown-shadermodel6.7-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
71 changes: 71 additions & 0 deletions llvm/test/CodeGen/DirectX/wavesize-md-valid.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
; RUN: split-file %s %t
; RUN: opt -S --dxil-translate-metadata %t/only.ll | FileCheck %t/only.ll
; RUN: opt -S --dxil-translate-metadata %t/min.ll | FileCheck %t/min.ll
; RUN: opt -S --dxil-translate-metadata %t/max.ll | FileCheck %t/max.ll
; RUN: opt -S --dxil-translate-metadata %t/pref.ll | FileCheck %t/pref.ll

; Test that wave size/range metadata is correctly generated with the correct tag

;--- only.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 11, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16}

target triple = "dxil-unknown-shadermodel6.6-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- min.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 0, i32 0}

target triple = "dxil-unknown-shadermodel6.8-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,0,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- max.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 32, i32 0}

target triple = "dxil-unknown-shadermodel6.8-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,32,0" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

;--- pref.ll

; CHECK: !dx.entryPoints = !{![[#ENTRY:]]}
; CHECK: ![[#ENTRY]] = !{ptr @main, !"main", null, null, ![[#PROPS:]]}
; CHECK: ![[#PROPS]] = !{{{.*}}i32 23, ![[#WAVE_SIZE:]]{{.*}}}
; CHECK: ![[#WAVE_SIZE]] = !{i32 16, i32 64, i32 32}

target triple = "dxil-unknown-shadermodel6.8-compute"

define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.wavesize"="16,64,32" "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }