128 changes: 127 additions & 1 deletion libc/utils/gpu/loader/nvptx/Loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@
#include "Server.h"

#include "cuda.h"

#include "llvm/Object/ELF.h"
#include "llvm/Object/ELFObjectFile.h"

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

using namespace llvm;
using namespace object;

/// The arguments to the '_start' kernel.
struct kernel_args_t {
Expand Down Expand Up @@ -51,11 +59,122 @@ static void handle_error(const char *msg) {
exit(EXIT_FAILURE);
}

/// Builds the constructor and destructor callback arrays for a loaded image
/// and points the image's `__init_array_*` / `__fini_array_*` symbols at them.
///
/// We need to do this manually because the NVPTX toolchain does not contain
/// the necessary binary manipulation tools: the ELF symbol table is scanned
/// for globals named `__init_array_object_*` / `__fini_array_object_*`, whose
/// trailing `_<number>` suffix is parsed as a 16-bit priority.
///
/// \param image     Pointer to the ELF image that was loaded into \p binary.
/// \param size      Size of \p image in bytes.
/// \param allocator Callable taking a byte count and returning the storage
///                  used for both arrays (the caller in this file releases
///                  the result with cuMemFreeHost).
/// \param binary    The loaded CUDA module used to resolve symbol addresses.
/// \returns The pointer obtained from \p allocator; the constructors occupy
///          the front of the buffer and the destructors follow them. Every
///          failure is reported through handle_error, so no error value is
///          ever produced here.
template <typename Alloc>
Expected<void *> get_ctor_dtor_array(const void *image, const size_t size,
                                     Alloc allocator, CUmodule binary) {
  auto mem_buffer = MemoryBuffer::getMemBuffer(
      StringRef(reinterpret_cast<const char *>(image), size), "image",
      /*RequiresNullTerminator=*/false);
  Expected<ELF64LEObjectFile> elf_or_err =
      ELF64LEObjectFile::create(*mem_buffer);
  if (!elf_or_err)
    handle_error(toString(elf_or_err.takeError()).c_str());

  // Each entry is the symbol name (pointing into the image) and its priority.
  using entry_t = std::pair<const char *, uint16_t>;
  std::vector<entry_t> ctors;
  std::vector<entry_t> dtors;

  // CUDA has no way to iterate over all the symbols so we need to inspect the
  // ELF directly using the LLVM libraries.
  for (const auto &symbol : elf_or_err->symbols()) {
    auto name_or_err = symbol.getName();
    if (!name_or_err)
      handle_error(toString(name_or_err.takeError()).c_str());

    // Search for all symbols that contain a constructor or destructor.
    const bool is_ctor = name_or_err->starts_with("__init_array_object_");
    const bool is_dtor = name_or_err->starts_with("__fini_array_object_");
    if (!is_ctor && !is_dtor)
      continue;

    // The priority is whatever follows the last underscore in the name.
    uint16_t priority;
    if (name_or_err->rsplit('_').second.getAsInteger(10, priority))
      handle_error("Invalid priority for constructor or destructor");

    (is_ctor ? ctors : dtors).emplace_back(name_or_err->data(), priority);
  }

  // Lower priority constructors are run before higher ones. The reverse is true
  // for destructors, so they are sorted in descending order. This replaces a
  // discarded `llvm::reverse(dtors)` call: that helper only returns a reversed
  // view of the range and never reordered the vector.
  // NOTE(review): assumes the device startup code walks the fini array from
  // front to back -- confirm against the startup implementation.
  llvm::sort(ctors, [](const entry_t &lhs, const entry_t &rhs) {
    return lhs.second < rhs.second;
  });
  llvm::sort(dtors, [](const entry_t &lhs, const entry_t &rhs) {
    return lhs.second > rhs.second;
  });

  // Allocate host pinned memory to make these arrays visible to the GPU.
  CUdeviceptr *dev_memory = reinterpret_cast<CUdeviceptr *>(
      allocator((ctors.size() + dtors.size()) * sizeof(CUdeviceptr)));
  size_t global_size = 0;

  // For every entry, look up the address of its global in the module and copy
  // the function pointer stored in that global into the matching array slot.
  auto copy_callbacks = [&](const std::vector<entry_t> &entries,
                            CUdeviceptr *slots) {
    for (size_t i = 0; i < entries.size(); ++i) {
      CUdeviceptr dev_ptr;
      if (CUresult err = cuModuleGetGlobal(&dev_ptr, &global_size, binary,
                                           entries[i].first))
        handle_error(err);
      if (CUresult err = cuMemcpyDtoH(&slots[i], dev_ptr, sizeof(uintptr_t)))
        handle_error(err);
    }
  };

  // Constructors occupy the front of the buffer, destructors follow directly.
  CUdeviceptr *dev_ctors_start = dev_memory;
  CUdeviceptr *dev_ctors_end = dev_ctors_start + ctors.size();
  copy_callbacks(ctors, dev_ctors_start);

  CUdeviceptr *dev_dtors_start = dev_ctors_end;
  CUdeviceptr *dev_dtors_end = dev_dtors_start + dtors.size();
  copy_callbacks(dtors, dev_dtors_start);

  // Writes the array boundary `bound` into the device global `symbol`, which
  // the startup implementation reads to iterate the callbacks.
  auto publish_bound = [&](const char *symbol, CUdeviceptr *bound) {
    CUdeviceptr dev_symbol;
    if (CUresult err =
            cuModuleGetGlobal(&dev_symbol, &global_size, binary, symbol))
      handle_error(err);
    // Copies the pointer value itself (not what it points to) to the device.
    if (CUresult err = cuMemcpyHtoD(dev_symbol, &bound, sizeof(uintptr_t)))
      handle_error(err);
  };

  publish_bound("__init_array_start", dev_ctors_start);
  publish_bound("__init_array_end", dev_ctors_end);
  publish_bound("__fini_array_start", dev_dtors_start);
  publish_bound("__fini_array_end", dev_dtors_end);

  return dev_memory;
}

int load(int argc, char **argv, char **envp, void *image, size_t size,
const LaunchParameters &params) {

if (CUresult err = cuInit(0))
handle_error(err);

// Obtain the first device found on the system.
CUdevice device;
if (CUresult err = cuDeviceGet(&device, 0))
Expand Down Expand Up @@ -91,6 +210,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);
return dev_ptr;
};

auto memory_or_err = get_ctor_dtor_array(image, size, allocator, binary);
if (!memory_or_err)
handle_error(toString(memory_or_err.takeError()).c_str());

void *dev_argv = copy_argument_vector(argc, argv, allocator);
if (!dev_argv)
handle_error("Failed to allocate device argv");
Expand Down Expand Up @@ -153,6 +277,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
handle_error(err);

// Free the memory allocated for the device.
if (CUresult err = cuMemFreeHost(*memory_or_err))
handle_error(err);
if (CUresult err = cuMemFree(dev_ret))
handle_error(err);
if (CUresult err = cuMemFreeHost(dev_argv))
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(NVPTXCodeGen_sources
NVVMIntrRange.cpp
NVVMReflect.cpp
NVPTXProxyRegErasure.cpp
NVPTXCtorDtorLowering.cpp
)

add_llvm_target(NVPTXCodeGen
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ FunctionPass *createNVPTXISelDag(NVPTXTargetMachine &TM,
llvm::CodeGenOpt::Level OptLevel);
ModulePass *createNVPTXAssignValidGlobalNamesPass();
ModulePass *createGenericToNVVMLegacyPass();
ModulePass *createNVPTXCtorDtorLoweringLegacyPass();
FunctionPass *createNVVMIntrRangePass(unsigned int SmVersion);
FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
MachineFunctionPass *createNVPTXPrologEpilogPass();
Expand Down
11 changes: 9 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@

using namespace llvm;

// Hidden, off-by-default option. When it is set, doInitialization in this file
// no longer reports a fatal error for a module with a nontrivial
// llvm.global_ctors or llvm.global_dtors.
static cl::opt<bool>
LowerCtorDtor("nvptx-lower-global-ctor-dtor",
cl::desc("Lower GPU ctor / dtors to globals on the device."),
cl::init(false), cl::Hidden);

#define DEPOTNAME "__local_depot"

/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
Expand Down Expand Up @@ -788,12 +793,14 @@ bool NVPTXAsmPrinter::doInitialization(Module &M) {
report_fatal_error("Module has aliases, which NVPTX does not support.");
return true; // error
}
if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
!LowerCtorDtor) {
report_fatal_error(
"Module has a nontrivial global ctor, which NVPTX does not support.");
return true; // error
}
if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
!LowerCtorDtor) {
report_fatal_error(
"Module has a nontrivial global dtor, which NVPTX does not support.");
return true; // error
Expand Down
116 changes: 116 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===-- NVPTXCtorDtorLowering.cpp - Handle global ctors and dtors --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
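/// This pass replaces llvm.global_ctors and llvm.global_dtors with one
/// uniquely named global per constructor / destructor entry, so that a runtime
/// can look the entries up by name.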
//===----------------------------------------------------------------------===//

#include "NVPTXCtorDtorLowering.h"
#include "NVPTX.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;

#define DEBUG_TYPE "nvptx-lower-ctor-dtor"

// Hidden option, empty by default. A non-empty value is used in place of the
// hash of the module's source file name when naming the emitted ctor / dtor
// globals.
static cl::opt<std::string>
GlobalStr("nvptx-lower-global-ctor-dtor-id",
cl::desc("Override unique ID of ctor/dtor globals."),
cl::init(""), cl::Hidden);

namespace {

/// Returns the low word of the MD5 digest of \p Str as a lower-case
/// hexadecimal string.
static std::string getHash(StringRef Str) {
  llvm::MD5 Digest;
  Digest.update(Str);

  llvm::MD5::MD5Result Result;
  Digest.final(Result);

  return llvm::utohexstr(Result.low(), /*LowerCase=*/true);
}

/// Replaces the constructor / destructor list \p GlobalName with one exported
/// global per list entry.
///
/// Each emitted global is a constant in address space 4 that is initialized
/// with the entry's function and named
/// `__init_array_object_<function>_<id>_<priority>` (or `__fini_array_object_`
/// for destructors), with every '.' replaced by '_'. `<id>` is the
/// `nvptx-lower-global-ctor-dtor-id` option when it is non-empty and a hash of
/// the module's source file name otherwise.
///
/// \param M          Module to transform.
/// \param GlobalName Name of the list global to lower.
/// \param IsCtor     True to emit `__init_array_object_` globals, false to
///                   emit `__fini_array_object_` globals.
/// \returns true if the module was changed; false if the list global does not
///          exist, has no initializer, or is not a non-empty constant array.
static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
                                   bool IsCtor) {
  GlobalVariable *GV = M.getGlobalVariable(GlobalName);
  if (!GV || !GV->hasInitializer())
    return false;
  ConstantArray *GA = dyn_cast<ConstantArray>(GV->getInitializer());
  if (!GA || GA->getNumOperands() == 0)
    return false;

  // These only depend on the module and on IsCtor, so compute them once
  // instead of re-hashing the source file name for every entry.
  const std::string GlobalID =
      !GlobalStr.empty() ? GlobalStr : getHash(M.getSourceFileName());
  const char *NamePrefix =
      IsCtor ? "__init_array_object_" : "__fini_array_object_";
  const char *SectionPrefix = IsCtor ? ".init_array." : ".fini_array.";

  // NVPTX has no way to emit variables at specific sections or support for
  // the traditional constructor sections. Instead, we emit mangled global
  // names so the runtime can build the list manually.
  for (Value *V : GA->operands()) {
    auto *CS = cast<ConstantStruct>(V);
    auto *F = cast<Constant>(CS->getOperand(1));
    uint64_t Priority = cast<ConstantInt>(CS->getOperand(0))->getSExtValue();
    const std::string PriorityStr = std::to_string(Priority);

    // We append a semi-unique hash and the priority to the global name.
    std::string NameStr =
        (NamePrefix + F->getName() + "_" + GlobalID + "_" + PriorityStr).str();
    // PTX does not support exported names with '.' in them.
    llvm::transform(NameStr, NameStr.begin(),
                    [](char c) { return c == '.' ? '_' : c; });

    // Named distinctly from GV so it is obvious that the erase below removes
    // the original list and not one of the freshly created entries.
    auto *EntryGV = new GlobalVariable(
        M, F->getType(), /*IsConstant=*/true, GlobalValue::ExternalLinkage, F,
        NameStr, nullptr, GlobalValue::NotThreadLocal, /*AddressSpace=*/4);
    // This isn't respected by Nvidia, simply put here for clarity.
    EntryGV->setSection(SectionPrefix + PriorityStr);
    EntryGV->setVisibility(GlobalVariable::ProtectedVisibility);
    // Keep the otherwise unreferenced global alive via llvm.used.
    appendToUsed(M, {EntryGV});
  }

  // The original list has been fully lowered and is no longer needed.
  GV->eraseFromParent();
  return true;
}

/// Lowers both llvm.global_ctors and llvm.global_dtors in \p M.
/// Both lists are always processed; returns true if either lowering changed
/// the module.
static bool lowerCtorsAndDtors(Module &M) {
  const bool LoweredCtors =
      createInitOrFiniGlobls(M, "llvm.global_ctors", /*IsCtor =*/true);
  const bool LoweredDtors =
      createInitOrFiniGlobls(M, "llvm.global_dtors", /*IsCtor =*/false);
  return LoweredCtors || LoweredDtors;
}

/// Legacy pass manager wrapper that forwards to lowerCtorsAndDtors.
class NVPTXCtorDtorLoweringLegacy final : public ModulePass {
public:
  static char ID;

  NVPTXCtorDtorLoweringLegacy() : ModulePass(ID) {}

  /// Runs the lowering on \p M and reports whether the module changed.
  bool runOnModule(Module &M) override {
    const bool Changed = lowerCtorsAndDtors(M);
    return Changed;
  }
};

} // End anonymous namespace

/// New pass manager entry point: runs the lowering on \p M and preserves all
/// analyses only when the module was left untouched.
PreservedAnalyses NVPTXCtorDtorLoweringPass::run(Module &M,
                                                 ModuleAnalysisManager &AM) {
  if (!lowerCtorsAndDtors(M))
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}

// Definition of the static ID member that the legacy pass hands to the
// ModulePass constructor.
char NVPTXCtorDtorLoweringLegacy::ID = 0;
// Exposes the ID outside this file through the reference declared in
// NVPTXCtorDtorLowering.h.
char &llvm::NVPTXCtorDtorLoweringLegacyPassID = NVPTXCtorDtorLoweringLegacy::ID;
// Registers the legacy pass under the DEBUG_TYPE name ("nvptx-lower-ctor-dtor").
INITIALIZE_PASS(NVPTXCtorDtorLoweringLegacy, DEBUG_TYPE,
"Lower ctors and dtors for NVPTX", false, false)

/// Factory for the legacy pass manager version of the ctor / dtor lowering.
/// Returns a newly allocated pass object.
ModulePass *llvm::createNVPTXCtorDtorLoweringLegacyPass() {
  ModulePass *Pass = new NVPTXCtorDtorLoweringLegacy();
  return Pass;
}
30 changes: 30 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXCtorDtorLowering.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===-- NVPTXCtorDtorLowering.h --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXCTORDTORLOWERING_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXCTORDTORLOWERING_H

#include "llvm/IR/PassManager.h"

namespace llvm {
class Module;
class PassRegistry;

extern char &NVPTXCtorDtorLoweringLegacyPassID;
extern void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);

/// New pass manager pass that lowers llvm.global_ctors and llvm.global_dtors
/// into uniquely named globals, one per constructor / destructor entry (see
/// NVPTXCtorDtorLowering.cpp). No kernels are created.
class NVPTXCtorDtorLoweringPass
: public PassInfoMixin<NVPTXCtorDtorLoweringPass> {
public:
/// Runs the lowering on \p M.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_NVPTX_NVPTXCTORDTORLOWERING_H
9 changes: 9 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "NVPTXAliasAnalysis.h"
#include "NVPTXAllocaHoisting.h"
#include "NVPTXAtomicLower.h"
#include "NVPTXCtorDtorLowering.h"
#include "NVPTXLowerAggrCopies.h"
#include "NVPTXMachineFunctionInfo.h"
#include "NVPTXTargetObjectFile.h"
Expand Down Expand Up @@ -68,8 +69,10 @@ void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
// Forward declarations of the NVPTX pass initializers. Each initializer is
// declared exactly once; the ctor / dtor lowering entry was previously listed
// twice.
void initializeNVPTXAllocaHoistingPass(PassRegistry &);
void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry&);
void initializeNVPTXAtomicLowerPass(PassRegistry &);
void initializeNVPTXCtorDtorLoweringLegacyPass(PassRegistry &);
void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
void initializeNVPTXLowerAllocaPass(PassRegistry &);
void initializeNVPTXLowerArgsPass(PassRegistry &);
void initializeNVPTXProxyRegErasurePass(PassRegistry &);
void initializeNVVMIntrRangePass(PassRegistry &);
Expand All @@ -95,6 +98,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
initializeNVPTXAtomicLowerPass(PR);
initializeNVPTXLowerArgsPass(PR);
initializeNVPTXLowerAllocaPass(PR);
initializeNVPTXCtorDtorLoweringLegacyPass(PR);
initializeNVPTXLowerAggrCopiesPass(PR);
initializeNVPTXProxyRegErasurePass(PR);
initializeNVPTXDAGToDAGISelPass(PR);
Expand Down Expand Up @@ -249,6 +253,10 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
PB.registerPipelineParsingCallback(
[](StringRef PassName, ModulePassManager &PM,
ArrayRef<PassBuilder::PipelineElement>) {
if (PassName == "nvptx-lower-ctor-dtor") {
PM.addPass(NVPTXCtorDtorLoweringPass());
return true;
}
if (PassName == "generic-to-nvvm") {
PM.addPass(GenericToNVVMPass());
return true;
Expand Down Expand Up @@ -369,6 +377,7 @@ void NVPTXPassConfig::addIRPasses() {
}

addPass(createAtomicExpandPass());
addPass(createNVPTXCtorDtorLoweringLegacyPass());

// === LSR and other generic IR passes ===
TargetPassConfig::addIRPasses();
Expand Down
32 changes: 32 additions & 0 deletions llvm/test/CodeGen/NVPTX/lower-ctor-dtor.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
; RUN: opt -S -mtriple=nvptx64-- -nvptx-lower-ctor-dtor < %s | FileCheck %s
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor < %s | FileCheck %s
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor \
; RUN: -nvptx-lower-global-ctor-dtor-id=unique_id < %s | FileCheck %s --check-prefix=GLOBAL

; Make sure we get the same result if we run multiple times
; RUN: opt -S -mtriple=nvptx64-- -passes=nvptx-lower-ctor-dtor,nvptx-lower-ctor-dtor < %s | FileCheck %s
; RUN: llc -nvptx-lower-global-ctor-dtor -mtriple=nvptx64-amd-amdhsa -mcpu=sm_70 -filetype=asm -o - < %s | FileCheck %s -check-prefix=VISIBILITY

; Test input: one constructor entry (@foo) and one destructor entry (@bar),
; each registered with priority 1.
@llvm.global_ctors = appending addrspace(1) global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @foo, ptr null }]
@llvm.global_dtors = appending addrspace(1) global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 1, ptr @bar, ptr null }]

; CHECK-NOT: @llvm.global_ctors
; CHECK-NOT: @llvm.global_dtors

; CHECK: @__init_array_object_foo_[[HASH:[0-9a-f]+]]_1 = protected addrspace(4) constant ptr @foo, section ".init_array.1"
; CHECK: @__fini_array_object_bar_[[HASH:[0-9a-f]+]]_1 = protected addrspace(4) constant ptr @bar, section ".fini_array.1"
; CHECK: @llvm.used = appending global [2 x ptr] [ptr addrspacecast (ptr addrspace(4) @__init_array_object_foo_[[HASH]]_1 to ptr), ptr addrspacecast (ptr addrspace(4) @__fini_array_object_bar_[[HASH]]_1 to ptr)], section "llvm.metadata"
; GLOBAL: @__init_array_object_foo_unique_id_1 = protected addrspace(4) constant ptr @foo, section ".init_array.1"
; GLOBAL: @__fini_array_object_bar_unique_id_1 = protected addrspace(4) constant ptr @bar, section ".fini_array.1"
; GLOBAL: @llvm.used = appending global [2 x ptr] [ptr addrspacecast (ptr addrspace(4) @__init_array_object_foo_unique_id_1 to ptr), ptr addrspacecast (ptr addrspace(4) @__fini_array_object_bar_unique_id_1 to ptr)], section "llvm.metadata"

; VISIBILITY: .visible .const .align 8 .u64 __init_array_object_foo_[[HASH:[0-9a-f]+]]_1 = foo;
; VISIBILITY: .visible .const .align 8 .u64 __fini_array_object_bar_[[HASH:[0-9a-f]+]]_1 = bar;

; Empty function referenced by the constructor list of this test.
define internal void @foo() {
ret void
}

; Empty function referenced by the destructor list of this test.
define internal void @bar() {
ret void
}