diff --git a/llvm/include/llvm/SYCLLowerIR/TargetHelpers.h b/llvm/include/llvm/SYCLLowerIR/TargetHelpers.h index b2b383237a705..fba50396e6be2 100644 --- a/llvm/include/llvm/SYCLLowerIR/TargetHelpers.h +++ b/llvm/include/llvm/SYCLLowerIR/TargetHelpers.h @@ -28,6 +28,7 @@ struct KernelPayload { KernelPayload(Function *Kernel, MDNode *MD = nullptr); Function *Kernel; MDNode *MD; + SmallVector DependentMDs; }; ArchType getArchType(const Module &M); diff --git a/llvm/lib/SYCLLowerIR/LocalAccessorToSharedMemory.cpp b/llvm/lib/SYCLLowerIR/LocalAccessorToSharedMemory.cpp index 18c5bb00fb488..febf46177a6c4 100644 --- a/llvm/lib/SYCLLowerIR/LocalAccessorToSharedMemory.cpp +++ b/llvm/lib/SYCLLowerIR/LocalAccessorToSharedMemory.cpp @@ -208,7 +208,12 @@ Function *LocalAccessorToSharedMemoryPass::processKernel(Module &M, void LocalAccessorToSharedMemoryPass::postProcessKernels( SmallVectorImpl> &NewToOldKernels) { for (auto &Pair : NewToOldKernels) { - std::get<1>(Pair).MD->replaceOperandWith( - 0, llvm::ConstantAsMetadata::get(std::get<0>(Pair))); + auto KP = std::get<1>(Pair); + auto *F = std::get<0>(Pair); + KP.MD->replaceOperandWith(0, llvm::ConstantAsMetadata::get(F)); + // The MD node of the kernel has been altered, make sure that all the + // dependent nodes are kept up to date. + for (MDNode *D : KP.DependentMDs) + D->replaceOperandWith(0, llvm::ConstantAsMetadata::get(F)); } } diff --git a/llvm/lib/SYCLLowerIR/TargetHelpers.cpp b/llvm/lib/SYCLLowerIR/TargetHelpers.cpp index a4a7e35cfc297..8c45148d181e4 100644 --- a/llvm/lib/SYCLLowerIR/TargetHelpers.cpp +++ b/llvm/lib/SYCLLowerIR/TargetHelpers.cpp @@ -56,6 +56,7 @@ void populateKernels(Module &M, SmallVectorImpl &Kernels, if (!AnnotationMetadata) return; + SmallVector PossibleDependencies; // It is possible that the annotations node contains multiple pointers to the // same metadata, recognise visited ones. 
SmallSet Visited; @@ -70,9 +71,12 @@ void populateKernels(Module &M, SmallVectorImpl &Kernels, auto *Type = dyn_cast(MetadataNode->getOperand(1)); if (!Type) continue; - // Only process kernel entry points. - if (Type->getString() != "kernel") + // Only process kernel entry points, + if (Type->getString() != "kernel") { + // but keep track of other nodes that point to the same function. + PossibleDependencies.push_back(MetadataNode); continue; + } // Get a pointer to the entry point function from the metadata. const MDOperand &FuncOperand = MetadataNode->getOperand(0); @@ -82,6 +86,32 @@ void populateKernels(Module &M, SmallVectorImpl &Kernels, if (auto *Func = dyn_cast(FuncConstant->getValue())) Kernels.push_back(KernelPayload(Func, MetadataNode)); } + + // We need to match non-kernel metadata nodes using the kernel name to the + // kernel nodes. To avoid checking matched nodes multiple times keep track of + // handled entries. + SmallSet HandledNodes; + for (auto &KP : Kernels) { + auto *KernelConstant = cast(KP.MD->getOperand(0)); + auto KernelName = + cast(KernelConstant->getValue())->getFunction().getName(); + for (unsigned I = 0; I < PossibleDependencies.size(); ++I) { + if (HandledNodes.contains(I)) + continue; + MDNode *Dep = PossibleDependencies[I]; + const MDOperand &FuncOperand = Dep->getOperand(0); + if (!FuncOperand) + continue; + if (auto *FuncConstant = dyn_cast(FuncOperand)) + if (auto *Func = dyn_cast(FuncConstant->getValue())) + // We've found a match, append the dependent node to the kernel + // payload and keep track of matched entries. 
+ if (KernelName == Func->getFunction().getName()) { + KP.DependentMDs.push_back(Dep); + HandledNodes.insert(I); + } + } + } } } // namespace TargetHelpers diff --git a/llvm/test/CodeGen/NVPTX/local-accessor-to-shared-memory-basic-transformation.ll b/llvm/test/CodeGen/NVPTX/local-accessor-to-shared-memory-basic-transformation.ll index b7e0103f5949d..31785f3303a49 100644 --- a/llvm/test/CodeGen/NVPTX/local-accessor-to-shared-memory-basic-transformation.ll +++ b/llvm/test/CodeGen/NVPTX/local-accessor-to-shared-memory-basic-transformation.ll @@ -4,7 +4,9 @@ source_filename = "basic-transformation.ll" target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64" target triple = "nvptx64-nvidia-cuda" -; This test checks that the transformation is applied in the basic case. +; This test checks that the transformation is applied in the basic case. It +; also makes sure that a non-kernel node using the function's signature gets +; correctly updated (`maxntid`). ; CHECK: @_ZTS14example_kernel_shared_mem = external addrspace(3) global [0 x i8], align 4 @@ -23,8 +25,8 @@ entry: ret void } -!nvvm.annotations = !{!0, !1, !2, !1, !3, !3, !3, !3, !4, !4, !3} -!nvvmir.version = !{!5} +!nvvm.annotations = !{!0, !1, !2, !1, !3, !3, !3, !3, !4, !4, !3, !5} +!nvvmir.version = !{!6} !0 = distinct !{void (i32 addrspace(3)*, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"kernel", i32 1} ; CHECK: !0 = distinct !{void (i32, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"kernel", i32 1} @@ -32,4 +34,6 @@ entry: !2 = !{null, !"align", i32 8, !"align", i32 65544, !"align", i32 131080} !3 = !{null, !"align", i32 16} !4 = !{null, !"align", i32 16, !"align", i32 65552, !"align", i32 131088} -!5 = !{i32 1, i32 4} +; CHECK: !5 = distinct !{void (i32, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"maxntidx", i32 256} +!5 = !{void (i32 addrspace(3)*, i32 addrspace(1)*, i32)* @_ZTS14example_kernel, !"maxntidx", i32 256} +!6 = !{i32 1, i32 4}