diff --git a/offload/include/omptarget.h b/offload/include/omptarget.h index 8fd722bb15022..3317441f04eba 100644 --- a/offload/include/omptarget.h +++ b/offload/include/omptarget.h @@ -94,6 +94,8 @@ enum OpenMPOffloadingDeclareTargetFlags { OMP_DECLARE_TARGET_INDIRECT = 0x08, /// This is an entry corresponding to a requirement to be registered. OMP_REGISTER_REQUIRES = 0x10, + /// Mark the entry global as being an indirect vtable. + OMP_DECLARE_TARGET_INDIRECT_VTABLE = 0x20, }; enum TargetAllocTy : int32_t { diff --git a/offload/libomptarget/PluginManager.cpp b/offload/libomptarget/PluginManager.cpp index b57a2f815cba6..6fc330b92f0f5 100644 --- a/offload/libomptarget/PluginManager.cpp +++ b/offload/libomptarget/PluginManager.cpp @@ -434,7 +434,8 @@ static int loadImagesOntoDevice(DeviceTy &Device) { llvm::offloading::EntryTy DeviceEntry = Entry; if (Entry.Size) { - if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, + if (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) && + Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &DeviceEntry.Address) != OFFLOAD_SUCCESS) REPORT("Failed to load symbol %s\n", Entry.SymbolName); @@ -443,7 +444,9 @@ static int loadImagesOntoDevice(DeviceTy &Device) { // the device to point to the memory on the host. if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) || (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) { - if (Device.RTL->data_submit(DeviceId, DeviceEntry.Address, + if (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) && + !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) && + Device.RTL->data_submit(DeviceId, DeviceEntry.Address, Entry.Address, Entry.Size) != OFFLOAD_SUCCESS) REPORT("Failed to write symbol for USM %s\n", Entry.SymbolName); diff --git a/offload/libomptarget/device.cpp b/offload/libomptarget/device.cpp index 71423ae0c94d9..d5436bde47ba5 100644 --- a/offload/libomptarget/device.cpp +++ b/offload/libomptarget/device.cpp @@ -112,21 +112,58 @@ setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image, llvm::SmallVector> IndirectCallTable; for (const auto &Entry : Entries) { if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP || - Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT)) + Entry.Size == 0 || + (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) && + !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE))) continue; - assert(Entry.Size == sizeof(void *) && "Global not a function pointer?"); - auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back(); - - void *Ptr; - if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr)) - return error::createOffloadError(error::ErrorCode::INVALID_BINARY, - "failed to load %s", Entry.SymbolName); - - HstPtr = Entry.Address; - if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo)) - return error::createOffloadError(error::ErrorCode::INVALID_BINARY, - "failed to load %s", Entry.SymbolName); + size_t PtrSize = sizeof(void *); + if (Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) { + // This is a VTable entry, the current entry is the first index of the + // VTable and Entry.Size is the total size of the VTable. Unlike the + // indirect function case below, the Global is not of size Entry.Size and + // is instead of size PtrSize (sizeof(void*)). + void *Vtable; + void *res; + if (Device.RTL->get_global(Binary, PtrSize, Entry.SymbolName, &Vtable)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to load %s", Entry.SymbolName); + + // HstPtr = Entry.Address; + if (Device.retrieveData(&res, Vtable, PtrSize, AsyncInfo)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to load %s", Entry.SymbolName); + if (Device.synchronize(AsyncInfo)) + return error::createOffloadError( + error::ErrorCode::INVALID_BINARY, + "failed to synchronize after retrieving %s", Entry.SymbolName); + // Calculate and emplace entire Vtable from first Vtable byte + for (uint64_t i = 0; i < Entry.Size / PtrSize; ++i) { + auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back(); + HstPtr = reinterpret_cast( + reinterpret_cast(Entry.Address) + i * PtrSize); + DevPtr = reinterpret_cast(reinterpret_cast(res) + + i * PtrSize); + } + } else { + // Indirect function case: Entry.Size should equal PtrSize since we're + // dealing with a single function pointer (not a VTable) + assert(Entry.Size == PtrSize && "Global not a function pointer?"); + auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back(); + void *Ptr; + if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to load %s", Entry.SymbolName); + + HstPtr = Entry.Address; + if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo)) + return error::createOffloadError(error::ErrorCode::INVALID_BINARY, + "failed to load %s", Entry.SymbolName); + } + if (Device.synchronize(AsyncInfo)) + return error::createOffloadError( + error::ErrorCode::INVALID_BINARY, + "failed to synchronize after retrieving %s", Entry.SymbolName); } // If we do not have any indirect globals we exit early. diff --git a/offload/test/api/omp_indirect_call_table_manual.c b/offload/test/api/omp_indirect_call_table_manual.c new file mode 100644 index 0000000000000..e958d47d69dad --- /dev/null +++ b/offload/test/api/omp_indirect_call_table_manual.c @@ -0,0 +1,107 @@ +// RUN: %libomptarget-compile-run-and-check-generic +#include +#include +#include + +// --------------------------------------------------------------------------- +// Various definitions copied from OpenMP RTL + +typedef struct { + uint64_t Reserved; + uint16_t Version; + uint16_t Kind; // OpenMP==1 + uint32_t Flags; + void *Address; + char *SymbolName; + uint64_t Size; + uint64_t Data; + void *AuxAddr; +} __tgt_offload_entry; + +enum OpenMPOffloadingDeclareTargetFlags { + /// Mark the entry global as having a 'link' attribute. + OMP_DECLARE_TARGET_LINK = 0x01, + /// Mark the entry global as being an indirectly callable function. + OMP_DECLARE_TARGET_INDIRECT = 0x08, + /// This is an entry corresponding to a requirement to be registered. + OMP_REGISTER_REQUIRES = 0x10, + /// Mark the entry global as being an indirect vtable. + OMP_DECLARE_TARGET_INDIRECT_VTABLE = 0x20, +}; + +#pragma omp begin declare variant match(device = {kind(gpu)}) +// Provided by the runtime. +void *__llvm_omp_indirect_call_lookup(void *host_ptr); +#pragma omp declare target to(__llvm_omp_indirect_call_lookup) \ + device_type(nohost) +#pragma omp end declare variant + +#pragma omp begin declare variant match(device = {kind(cpu)}) +// We assume unified addressing on the CPU target. +void *__llvm_omp_indirect_call_lookup(void *host_ptr) { return host_ptr; } +#pragma omp end declare variant + +#pragma omp begin declare target +void foo(int *i) { *i += 1; } +void bar(int *i) { *i += 10; } +void baz(int *i) { *i += 100; } +#pragma omp end declare target + +typedef void (*fptr_t)(int *i); + +// Dispatch Table - declare separately on host and device to avoid +// registering with the library; this also allows us to use separate +// names, which is convenient for debugging. This dispatchTable is +// intended to mimic what Clang emits for C++ vtables. +fptr_t dispatchTable[] = {foo, bar, baz}; +#pragma omp begin declare target device_type(nohost) +fptr_t GPUdispatchTable[] = {foo, bar, baz}; +fptr_t *GPUdispatchTablePtr = GPUdispatchTable; +#pragma omp end declare target + +// Define "manual" OpenMP offload entries, where we emit Clang +// offloading entry structure definitions in the appropriate ELF +// section. This allows us to emulate the offloading entries that Clang would +// normally emit for us + +__attribute__((weak, section("llvm_offload_entries"), aligned(8))) +const __tgt_offload_entry __offloading_entry[] = {{ + 0ULL, // Reserved + 1, // Version + 1, // Kind + OMP_DECLARE_TARGET_INDIRECT_VTABLE, // Flags + &dispatchTable, // Address + "GPUdispatchTablePtr", // SymbolName + (size_t)(sizeof(dispatchTable)), // Size + 0ULL, // Data + NULL // AuxAddr +}}; + +// Mimic how Clang emits vtable pointers for C++ classes +typedef struct { + fptr_t *dispatchPtr; +} myClass; + +// --------------------------------------------------------------------------- +int main() { + myClass obj_foo = {dispatchTable + 0}; + myClass obj_bar = {dispatchTable + 1}; + myClass obj_baz = {dispatchTable + 2}; + int aaa = 0; + +#pragma omp target map(aaa) map(to : obj_foo, obj_bar, obj_baz) + { + // Lookup + fptr_t *foo_ptr = __llvm_omp_indirect_call_lookup(obj_foo.dispatchPtr); + fptr_t *bar_ptr = __llvm_omp_indirect_call_lookup(obj_bar.dispatchPtr); + fptr_t *baz_ptr = __llvm_omp_indirect_call_lookup(obj_baz.dispatchPtr); + foo_ptr[0](&aaa); + bar_ptr[0](&aaa); + baz_ptr[0](&aaa); + } + + assert(aaa == 111); + // CHECK: PASS + printf("PASS\n"); + return 0; +}