Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/SYCLLowerIR/TargetHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct KernelPayload {
KernelPayload(Function *Kernel, MDNode *MD = nullptr);
Function *Kernel;
MDNode *MD;
SmallVector<MDNode *> DependentMDs;
};

ArchType getArchType(const Module &M);
Expand Down
9 changes: 7 additions & 2 deletions llvm/lib/SYCLLowerIR/LocalAccessorToSharedMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ Function *LocalAccessorToSharedMemoryPass::processKernel(Module &M,
/// Re-points kernel metadata at the rewritten kernel functions.
///
/// \param NewToOldKernels pairs of (new function, payload of the kernel it
/// replaces). For each pair, operand 0 of the payload's kernel metadata node
/// and of every node listed in its `DependentMDs` is replaced with the new
/// function.
void LocalAccessorToSharedMemoryPass::postProcessKernels(
    SmallVectorImpl<std::pair<Function *, KernelPayload>> &NewToOldKernels) {
  for (auto &Pair : NewToOldKernels) {
    // Bind the payload by reference: it carries a SmallVector of dependent
    // nodes, so copying it on every iteration would be wasted work.
    const KernelPayload &KP = Pair.second;
    Function *F = Pair.first;

    // Every node updated below receives the same operand, so build the
    // metadata wrapper for the new function once.
    auto *NewKernelMD = llvm::ConstantAsMetadata::get(F);
    KP.MD->replaceOperandWith(0, NewKernelMD);

    // The MD node of the kernel has been altered, make sure that all the
    // dependent nodes are kept up to date.
    for (MDNode *Dependent : KP.DependentMDs)
      Dependent->replaceOperandWith(0, NewKernelMD);
  }
}
34 changes: 32 additions & 2 deletions llvm/lib/SYCLLowerIR/TargetHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void populateKernels(Module &M, SmallVectorImpl<KernelPayload> &Kernels,
if (!AnnotationMetadata)
return;

SmallVector<MDNode *, 4> PossibleDependencies;
// It is possible that the annotations node contains multiple pointers to the
// same metadata, recognise visited ones.
SmallSet<MDNode *, 4> Visited;
Expand All @@ -70,9 +71,12 @@ void populateKernels(Module &M, SmallVectorImpl<KernelPayload> &Kernels,
auto *Type = dyn_cast<MDString>(MetadataNode->getOperand(1));
if (!Type)
continue;
// Only process kernel entry points.
if (Type->getString() != "kernel")
// Only process kernel entry points,
if (Type->getString() != "kernel") {
// but keep track of other nodes that point to the same function.
PossibleDependencies.push_back(MetadataNode);
continue;
}

// Get a pointer to the entry point function from the metadata.
const MDOperand &FuncOperand = MetadataNode->getOperand(0);
Expand All @@ -82,6 +86,32 @@ void populateKernels(Module &M, SmallVectorImpl<KernelPayload> &Kernels,
if (auto *Func = dyn_cast<Function>(FuncConstant->getValue()))
Kernels.push_back(KernelPayload(Func, MetadataNode));
}

// Match the non-kernel metadata nodes to the kernel nodes by kernel name. To
// avoid re-checking nodes that have already been matched, keep track of the
// handled entries.
SmallSet<unsigned, 4> HandledNodes;
for (auto &KP : Kernels) {
auto *KernelConstant = cast<ConstantAsMetadata>(KP.MD->getOperand(0));
auto KernelName =
cast<Function>(KernelConstant->getValue())->getFunction().getName();
for (unsigned I = 0; I < PossibleDependencies.size(); ++I) {
if (HandledNodes.contains(I))
continue;
MDNode *Dep = PossibleDependencies[I];
const MDOperand &FuncOperand = Dep->getOperand(0);
if (!FuncOperand)
continue;
if (auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand))
if (auto *Func = dyn_cast<Function>(FuncConstant->getValue()))
// We've found a match, append the dependent node to the kernel
// payload and keep track of matched entries.
if (KernelName == Func->getFunction().getName()) {
KP.DependentMDs.push_back(Dep);
HandledNodes.insert(I);
}
}
}
}

} // namespace TargetHelpers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ source_filename = "basic-transformation.ll"
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; This test checks that the transformation is applied in the basic case.
; This test checks that the transformation is applied in the basic case. It
; also makes sure that a non-kernel node using the function's signature gets
; correctly updated (`maxntidx`).

; CHECK: @_ZTS14example_kernel_shared_mem = external addrspace(3) global [0 x i8], align 4

Expand All @@ -23,13 +25,15 @@ entry:
ret void
}

!nvvm.annotations = !{!0, !1, !2, !1, !3, !3, !3, !3, !4, !4, !3}
!nvvmir.version = !{!5}
!nvvm.annotations = !{!0, !1, !2, !1, !3, !3, !3, !3, !4, !4, !3, !5}
!nvvmir.version = !{!6}

!0 = distinct !{void (i32 addrspace(3)*, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"kernel", i32 1}
; CHECK: !0 = distinct !{void (i32, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"kernel", i32 1}
!1 = !{null, !"align", i32 8}
!2 = !{null, !"align", i32 8, !"align", i32 65544, !"align", i32 131080}
!3 = !{null, !"align", i32 16}
!4 = !{null, !"align", i32 16, !"align", i32 65552, !"align", i32 131088}
!5 = !{i32 1, i32 4}
; CHECK: !5 = distinct !{void (i32, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"maxntidx", i32 256}
!5 = !{void (i32 addrspace(3)*, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"maxntidx", i32 256}
!6 = !{i32 1, i32 4}