diff --git a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp index eb4a46c35a9b7e..dc563ee40f7bfc 100644 --- a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp +++ b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp @@ -8,10 +8,10 @@ using core::atl_is_atmi_initialized; -atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, - const char *symbol, - void **var_addr, - unsigned int *var_size) { +atmi_status_t atmi_interop_hsa_get_symbol_info( + const std::map &SymbolInfoTable, + atmi_mem_place_t place, const char *symbol, void **var_addr, + unsigned int *var_size) { /* // Typical usage: void *var_addr; @@ -32,9 +32,9 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, // get the symbol info std::string symbolStr = std::string(symbol); - if (SymbolInfoTable[place.dev_id].find(symbolStr) != - SymbolInfoTable[place.dev_id].end()) { - atl_symbol_info_t info = SymbolInfoTable[place.dev_id][symbolStr]; + auto It = SymbolInfoTable.find(symbolStr); + if (It != SymbolInfoTable.end()) { + atl_symbol_info_t info = It->second; *var_addr = reinterpret_cast(info.addr); *var_size = info.size; return ATMI_STATUS_SUCCESS; @@ -46,6 +46,7 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, } atmi_status_t atmi_interop_hsa_get_kernel_info( + const std::map &KernelInfoTable, atmi_mem_place_t place, const char *kernel_name, hsa_executable_symbol_info_t kernel_info, uint32_t *value) { /* @@ -68,9 +69,9 @@ atmi_status_t atmi_interop_hsa_get_kernel_info( atmi_status_t status = ATMI_STATUS_SUCCESS; // get the kernel info std::string kernelStr = std::string(kernel_name); - if (KernelInfoTable[place.dev_id].find(kernelStr) != - KernelInfoTable[place.dev_id].end()) { - atl_kernel_info_t info = KernelInfoTable[place.dev_id][kernelStr]; + auto It = KernelInfoTable.find(kernelStr); + if (It != KernelInfoTable.end()) { + atl_kernel_info_t info = It->second; switch (kernel_info) { case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE: *value = info.group_segment_size; diff --git a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h index c0f588215e8a27..20da1173a8dbae 100644 --- a/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h +++ b/openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h @@ -9,6 +9,10 @@ #include "atmi_runtime.h" #include "hsa.h" #include "hsa_ext_amd.h" +#include "internal.h" + +#include +#include #ifdef __cplusplus extern "C" { @@ -44,11 +48,10 @@ extern "C" { * * @retval ::ATMI_STATUS_UNKNOWN The function encountered errors. */ -atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, - const char *symbol, - void **var_addr, - unsigned int *var_size); - +atmi_status_t atmi_interop_hsa_get_symbol_info( + const std::map &SymbolInfoTable, + atmi_mem_place_t place, const char *symbol, void **var_addr, + unsigned int *var_size); /** * @brief Get the HSA-specific kernel info from a kernel name * @@ -75,8 +78,10 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place, * @retval ::ATMI_STATUS_UNKNOWN The function encountered errors. */ atmi_status_t atmi_interop_hsa_get_kernel_info( + const std::map &KernelInfoTable, atmi_mem_place_t place, const char *kernel_name, hsa_executable_symbol_info_t info, uint32_t *value); + /** @} */ #ifdef __cplusplus diff --git a/openmp/libomptarget/plugins/amdgpu/impl/internal.h b/openmp/libomptarget/plugins/amdgpu/impl/internal.h index ef068398734987..98d9ee487fe96f 100644 --- a/openmp/libomptarget/plugins/amdgpu/impl/internal.h +++ b/openmp/libomptarget/plugins/amdgpu/impl/internal.h @@ -106,9 +106,6 @@ typedef struct atl_symbol_info_s { uint32_t size; } atl_symbol_info_t; -extern std::vector> KernelInfoTable; -extern std::vector> SymbolInfoTable; - // ---------------------- Kernel End ------------- namespace core { diff --git a/openmp/libomptarget/plugins/amdgpu/impl/system.cpp b/openmp/libomptarget/plugins/amdgpu/impl/system.cpp index f3a7d20be0ddd9..ac171022bad417 100644 --- a/openmp/libomptarget/plugins/amdgpu/impl/system.cpp +++ b/openmp/libomptarget/plugins/amdgpu/impl/system.cpp @@ -146,9 +146,6 @@ ATLMachine g_atl_machine; std::vector atl_gpu_kernarg_pools; -std::vector> KernelInfoTable; -std::vector> SymbolInfoTable; - bool g_atmi_initialized = false; /* @@ -208,15 +205,6 @@ atmi_status_t Runtime::Initialize() { atmi_status_t Runtime::Finalize() { atmi_status_t rc = ATMI_STATUS_SUCCESS; - for (uint32_t i = 0; i < SymbolInfoTable.size(); i++) { - SymbolInfoTable[i].clear(); - } - SymbolInfoTable.clear(); - for (uint32_t i = 0; i < KernelInfoTable.size(); i++) { - KernelInfoTable[i].clear(); - } - KernelInfoTable.clear(); - atl_reset_atmi_initialized(); hsa_status_t err = hsa_shut_down(); if (err != HSA_STATUS_SUCCESS) { @@ -556,13 +544,6 @@ hsa_status_t init_hsa() { return err; } - int gpu_count = g_atl_machine.processorCount(); - KernelInfoTable.resize(gpu_count); - SymbolInfoTable.resize(gpu_count); - for (uint32_t i = 0; i < SymbolInfoTable.size(); i++) - SymbolInfoTable[i].clear(); - for (uint32_t i = 0; i < KernelInfoTable.size(); i++) - KernelInfoTable[i].clear(); atlc.g_hsa_initialized = true; DEBUG_PRINT("done\n"); } @@ -835,8 +816,9 @@ int populate_kernelArgMD(msgpack::byte_range args_element, } } // namespace -static hsa_status_t get_code_object_custom_metadata(void *binary, - size_t binSize, int gpu) { +static hsa_status_t get_code_object_custom_metadata( + void *binary, size_t binSize, int gpu, + std::map &KernelInfoTable) { // parse code object with different keys from v2 // also, the kernel name is not the same as the symbol name -- so a // symbol->name map is needed @@ -1003,14 +985,16 @@ static hsa_status_t get_code_object_custom_metadata(void *binary, kernel_segment_size, info.kernel_segment_size); // kernel received, now add it to the kernel info table - KernelInfoTable[gpu][kernelName] = info; + KernelInfoTable[kernelName] = info; } return HSA_STATUS_SUCCESS; } -static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol, - int gpu) { +static hsa_status_t +populate_InfoTables(hsa_executable_symbol_t symbol, int gpu, + std::map &KernelInfoTable, + std::map &SymbolInfoTable) { hsa_symbol_kind_t type; uint32_t name_length; @@ -1047,11 +1031,16 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol, // by now, the kernel info table should already have an entry // because the non-ROCr custom code object parsing is called before // iterating over the code object symbols using ROCr - if (KernelInfoTable[gpu].find(kernelName) == KernelInfoTable[gpu].end()) { - return HSA_STATUS_ERROR; + if (KernelInfoTable.find(kernelName) == KernelInfoTable.end()) { + if (HSA_STATUS_ERROR_INVALID_CODE_OBJECT != HSA_STATUS_SUCCESS) { + printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__, + "Finding the entry kernel info table", + get_error_string(HSA_STATUS_ERROR_INVALID_CODE_OBJECT)); + exit(1); + } } // found, so assign and update - info = KernelInfoTable[gpu][kernelName]; + info = KernelInfoTable[kernelName]; /* Extract dispatch information from the symbol */ err = hsa_executable_symbol_get_info( @@ -1089,7 +1078,7 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol, info.private_segment_size, info.kernel_segment_size); // assign it back to the kernel info table - KernelInfoTable[gpu][kernelName] = info; + KernelInfoTable[kernelName] = info; free(name); } else if (type == HSA_SYMBOL_KIND_VARIABLE) { err = hsa_executable_symbol_get_info( @@ -1135,7 +1124,7 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol, if (err != HSA_STATUS_SUCCESS) { return err; } - SymbolInfoTable[gpu][std::string(name)] = info; + SymbolInfoTable[std::string(name)] = info; free(name); } else { DEBUG_PRINT("Symbol is an indirect function\n"); @@ -1143,7 +1132,9 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol, return HSA_STATUS_SUCCESS; } -atmi_status_t Runtime::RegisterModuleFromMemory( +atmi_status_t RegisterModuleFromMemory( + std::map &KernelInfoTable, + std::map &SymbolInfoTable, void *module_bytes, size_t module_size, atmi_place_t place, atmi_status_t (*on_deserialized_data)(void *data, size_t size, void *cb_state), @@ -1183,7 +1174,8 @@ atmi_status_t Runtime::RegisterModuleFromMemory( // Some metadata info is not available through ROCr API, so use custom // code object metadata parsing to collect such metadata info - err = get_code_object_custom_metadata(module_bytes, module_size, gpu); + err = get_code_object_custom_metadata(module_bytes, module_size, gpu, + KernelInfoTable); if (err != HSA_STATUS_SUCCESS) { DEBUG_PRINT("[%s:%d] %s failed: %s\n", __FILE__, __LINE__, "Getting custom code object metadata", @@ -1240,9 +1232,9 @@ atmi_status_t Runtime::RegisterModuleFromMemory( err = hsa::executable_iterate_symbols( executable, [&](hsa_executable_t, hsa_executable_symbol_t symbol) -> hsa_status_t { - return populate_InfoTables(symbol, gpu); + return populate_InfoTables(symbol, gpu, KernelInfoTable, + SymbolInfoTable); }); - if (err != HSA_STATUS_SUCCESS) { printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__, "Iterating over symbols for execuatable", get_error_string(err)); diff --git a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp index b67f3cf45023b7..4883288e0725c8 100644 --- a/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp +++ b/openmp/libomptarget/plugins/amdgpu/src/rtl.cpp @@ -86,6 +86,16 @@ int print_kernel_trace; #include "elf_common.h" +namespace core { +atmi_status_t RegisterModuleFromMemory( + std::map &KernelInfo, + std::map &SymbolInfoTable, void *, size_t, + atmi_place_t, + atmi_status_t (*on_deserialized_data)(void *data, size_t size, + void *cb_state), + void *cb_state, std::vector &HSAExecutables); +} + /// Keep entries table per device struct FuncOrGblEntryTy { __tgt_target_table Table; @@ -339,6 +349,9 @@ class RTLDeviceInfoTy { std::vector HSAExecutables; + std::vector> KernelInfoTable; + std::vector> SymbolInfoTable; + struct atmiFreePtrDeletor { void operator()(void *p) { atmi_free(p); // ignore failure to free @@ -482,6 +495,8 @@ class RTLDeviceInfoTy { NumTeams.resize(NumberOfDevices); NumThreads.resize(NumberOfDevices); deviceStateStore.resize(NumberOfDevices); + KernelInfoTable.resize(NumberOfDevices); + SymbolInfoTable.resize(NumberOfDevices); for (int i = 0; i < NumberOfDevices; i++) { HSAQueues[i] = nullptr; @@ -993,15 +1008,17 @@ atmi_status_t interop_get_symbol_info(char *base, size_t img_size, template atmi_status_t module_register_from_memory_to_place( + std::map &KernelInfoTable, + std::map &SymbolInfoTable, void *module_bytes, size_t module_size, atmi_place_t place, C cb, std::vector &HSAExecutables) { auto L = [](void *data, size_t size, void *cb_state) -> atmi_status_t { C *unwrapped = static_cast(cb_state); return (*unwrapped)(data, size); }; - return core::Runtime::RegisterModuleFromMemory( - module_bytes, module_size, place, L, static_cast(&cb), - HSAExecutables); + return core::RegisterModuleFromMemory( + KernelInfoTable, SymbolInfoTable, module_bytes, module_size, place, L, + static_cast(&cb), HSAExecutables); } } // namespace @@ -1116,11 +1133,12 @@ struct device_environment { DP("Setting global device environment after load (%u bytes)\n", si.size); int device_id = host_device_env.device_num; - + auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id]; void *state_ptr; uint32_t state_ptr_size; atmi_status_t err = atmi_interop_hsa_get_symbol_info( - get_gpu_mem_place(device_id), sym(), &state_ptr, &state_ptr_size); + SymbolInfo, get_gpu_mem_place(device_id), sym(), &state_ptr, + &state_ptr_size); if (err != ATMI_STATUS_SUCCESS) { DP("failed to find %s in loaded image\n", sym()); return err; @@ -1205,8 +1223,11 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id, auto env = device_environment(device_id, DeviceInfo.NumberOfDevices, image, img_size); + auto &KernelInfo = DeviceInfo.KernelInfoTable[device_id]; + auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id]; atmi_status_t err = module_register_from_memory_to_place( - (void *)image->ImageStart, img_size, get_gpu_place(device_id), + KernelInfo, SymbolInfo, (void *)image->ImageStart, img_size, + get_gpu_place(device_id), [&](void *data, size_t size) { if (image_contains_symbol(data, size, "needs_hostcall_buffer")) { __atomic_store_n(&DeviceInfo.hostcall_required, true, @@ -1241,9 +1262,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id, void *state_ptr; uint32_t state_ptr_size; + auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id]; atmi_status_t err = atmi_interop_hsa_get_symbol_info( - get_gpu_mem_place(device_id), "omptarget_nvptx_device_State", - &state_ptr, &state_ptr_size); + SymbolInfoMap, get_gpu_mem_place(device_id), + "omptarget_nvptx_device_State", &state_ptr, &state_ptr_size); if (err != ATMI_STATUS_SUCCESS) { DP("No device_state symbol found, skipping initialization\n"); @@ -1325,8 +1347,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id, void *varptr; uint32_t varsize; + auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id]; atmi_status_t err = atmi_interop_hsa_get_symbol_info( - get_gpu_mem_place(device_id), e->name, &varptr, &varsize); + SymbolInfoMap, get_gpu_mem_place(device_id), e->name, &varptr, + &varsize); if (err != ATMI_STATUS_SUCCESS) { // Inform the user what symbol prevented offloading @@ -1367,8 +1391,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id, atmi_mem_place_t place = get_gpu_mem_place(device_id); uint32_t kernarg_segment_size; + auto &KernelInfoMap = DeviceInfo.KernelInfoTable[device_id]; atmi_status_t err = atmi_interop_hsa_get_kernel_info( - place, e->name, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, + KernelInfoMap, place, e->name, + HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, &kernarg_segment_size); // each arg is a void * in this openmp implementation @@ -1794,6 +1820,7 @@ int32_t __tgt_rtl_run_target_team_region_locked( KernelTy *KernelInfo = (KernelTy *)tgt_entry_ptr; std::string kernel_name = std::string(KernelInfo->Name); + auto &KernelInfoTable = DeviceInfo.KernelInfoTable; if (KernelInfoTable[device_id].find(kernel_name) == KernelInfoTable[device_id].end()) { DP("Kernel %s not found\n", kernel_name.c_str());