From 5c38bcc7068f486733584bd251c9596059c31b7f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 5 Jun 2023 17:29:11 -0400 Subject: [PATCH] [cuda] Implement basics for a CUDA HAL driver rewrite (#13942) This commit starts a CUDA HAL driver rewrite under `experimental/`. We create a new `cuda2/` directory to host the new code to avoid interrupting the current CodeGen development. This commit just brings in the basics for booting up a new HAL driver, including dynamic symbols management, error status management, and IREE HAL driver implementation. Most of the code is directly copied from the existing HAL driver, with notable changes: * Split CUDA and NCCL dynamic symbols into separate structures for better organization and allowing optionality. * Fleshed out CUDA error to IREE status conversions. * Better organized code blocks and improved error messages and various comments. Building this commit with `-DIREE_EXTERNAL_HAL_DRIVERS=cuda2`, we can have `tools/iree-run-module --dump_devices` showing `cuda2` devices, in parallel to the existing CUDA one. 
Progress towards https://github.com/openxla/iree/issues/13245 --- CMakeLists.txt | 9 + experimental/cuda2/CMakeLists.txt | 64 +++ experimental/cuda2/api.h | 47 ++ experimental/cuda2/cuda_driver.c | 450 ++++++++++++++++++ experimental/cuda2/cuda_headers.h | 13 + experimental/cuda2/dynamic_symbol_tables.h | 116 +++++ experimental/cuda2/dynamic_symbols.c | 229 +++++++++ experimental/cuda2/dynamic_symbols.h | 97 ++++ experimental/cuda2/dynamic_symbols_test.cc | 84 ++++ .../cuda2/registration/CMakeLists.txt | 23 + .../cuda2/registration/driver_module.c | 106 +++++ .../cuda2/registration/driver_module.h | 25 + experimental/cuda2/status_util.c | 185 +++++++ experimental/cuda2/status_util.h | 108 +++++ 14 files changed, 1556 insertions(+) create mode 100644 experimental/cuda2/CMakeLists.txt create mode 100644 experimental/cuda2/api.h create mode 100644 experimental/cuda2/cuda_driver.c create mode 100644 experimental/cuda2/cuda_headers.h create mode 100644 experimental/cuda2/dynamic_symbol_tables.h create mode 100644 experimental/cuda2/dynamic_symbols.c create mode 100644 experimental/cuda2/dynamic_symbols.h create mode 100644 experimental/cuda2/dynamic_symbols_test.cc create mode 100644 experimental/cuda2/registration/CMakeLists.txt create mode 100644 experimental/cuda2/registration/driver_module.c create mode 100644 experimental/cuda2/registration/driver_module.h create mode 100644 experimental/cuda2/status_util.c create mode 100644 experimental/cuda2/status_util.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 73ba986b8f22..034a6d274122 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -308,6 +308,15 @@ if(IREE_HAL_EXECUTABLE_PLUGIN_SYSTEM_LIBRARY) message(STATUS " - system-library") endif() +#------------------------------------------------------------------------------- +# Experimental next-generation CUDA HAL driver +#------------------------------------------------------------------------------- + +set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_SOURCE_DIR 
"${CMAKE_CURRENT_SOURCE_DIR}/experimental/cuda2") +set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/experimental/cuda2") +set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_TARGET "iree::experimental::cuda2::registration") +set(IREE_EXTERNAL_CUDA2_HAL_DRIVER_REGISTER "iree_hal_cuda2_driver_module_register") + #------------------------------------------------------------------------------- # Experimental ROCM HAL driver #------------------------------------------------------------------------------- diff --git a/experimental/cuda2/CMakeLists.txt b/experimental/cuda2/CMakeLists.txt new file mode 100644 index 000000000000..dcb0e8a08b27 --- /dev/null +++ b/experimental/cuda2/CMakeLists.txt @@ -0,0 +1,64 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Set the root for package namespacing to the current directory. 
+set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}") +set(IREE_PACKAGE_ROOT_PREFIX "iree/experimental/cuda2") + +iree_add_all_subdirs() + +iree_cc_library( + NAME + cuda2 + HDRS + "api.h" + SRCS + "api.h" + "cuda_driver.c" + DEPS + ::dynamic_symbols + iree::base + iree::base::core_headers + iree::base::tracing + iree::hal + iree::schemas::cuda_executable_def_c_fbs + PUBLIC +) + +iree_cc_library( + NAME + dynamic_symbols + HDRS + "dynamic_symbols.h" + "status_util.h" + TEXTUAL_HDRS + "dynamic_symbol_tables.h" + SRCS + "cuda_headers.h" + "dynamic_symbols.c" + "status_util.c" + DEPS + iree::base + iree::base::core_headers + iree::base::internal::dynamic_library + iree::base::tracing + iree_cuda::headers + PUBLIC +) + +iree_cc_test( + NAME + dynamic_symbols_test + SRCS + "dynamic_symbols_test.cc" + DEPS + ::dynamic_symbols + iree::base + iree::testing::gtest + iree::testing::gtest_main + LABELS + "driver=cuda2" +) diff --git a/experimental/cuda2/api.h b/experimental/cuda2/api.h new file mode 100644 index 000000000000..565fd50ca972 --- /dev/null +++ b/experimental/cuda2/api.h @@ -0,0 +1,47 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// See iree/base/api.h for documentation on the API conventions used. + +#ifndef IREE_EXPERIMENTAL_CUDA2_API_H_ +#define IREE_EXPERIMENTAL_CUDA2_API_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_hal_cuda2_driver_t +//===----------------------------------------------------------------------===// + +// CUDA HAL driver creation options. +typedef struct iree_hal_cuda2_driver_options_t { + // The index of the default CUDA device to use within the list of available + // devices. 
+ int default_device_index; +} iree_hal_cuda2_driver_options_t; + +// Initializes the given |out_options| with default driver creation options. +IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize( + iree_hal_cuda2_driver_options_t* out_options); + +// Creates a CUDA HAL driver with the given |options|, from which CUDA devices +// can be enumerated and created with specific parameters. +// +// |out_driver| must be released by the caller (see iree_hal_driver_release). +IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create( + iree_string_view_t identifier, + const iree_hal_cuda2_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_EXPERIMENTAL_CUDA2_API_H_ diff --git a/experimental/cuda2/cuda_driver.c b/experimental/cuda2/cuda_driver.c new file mode 100644 index 000000000000..b88cd9ec4f28 --- /dev/null +++ b/experimental/cuda2/cuda_driver.c @@ -0,0 +1,450 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "experimental/cuda2/api.h" +#include "experimental/cuda2/dynamic_symbols.h" +#include "experimental/cuda2/status_util.h" +#include "iree/base/api.h" +#include "iree/base/assert.h" +#include "iree/base/tracing.h" +#include "iree/hal/api.h" + +// Maximum device name length supported by the CUDA HAL driver. +#define IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH 128 + +// Utility macros to convert between CUDevice and iree_hal_device_id_t. +#define IREE_CUDEVICE_TO_DEVICE_ID(device) (iree_hal_device_id_t)((device) + 1) +#define IREE_DEVICE_ID_TO_CUDEVICE(device_id) (CUdevice)((device_id)-1) + +typedef struct iree_hal_cuda2_driver_t { + // Abstract resource used for injecting reference counting and vtable; + // must be at offset 0. 
+ iree_hal_resource_t resource; + + iree_allocator_t host_allocator; + + // Identifier used for registering the driver in the IREE driver registry. + iree_string_view_t identifier; + // CUDA driver API dynamic symbols to interact with the CUDA system. + iree_hal_cuda2_dynamic_symbols_t cuda_symbols; + // NCCL API dynamic symbols to interact with the CUDA system. + iree_hal_cuda2_nccl_dynamic_symbols_t nccl_symbols; + + // The index of the default CUDA device to use if multiple ones are available. + int default_device_index; +} iree_hal_cuda2_driver_t; + +static const iree_hal_driver_vtable_t iree_hal_cuda2_driver_vtable; + +static iree_hal_cuda2_driver_t* iree_hal_cuda2_driver_cast( + iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda2_driver_vtable); + return (iree_hal_cuda2_driver_t*)base_value; +} + +IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize( + iree_hal_cuda2_driver_options_t* out_options) { + IREE_ASSERT_ARGUMENT(out_options); + memset(out_options, 0, sizeof(*out_options)); + out_options->default_device_index = 0; +} + +static iree_status_t iree_hal_cuda2_driver_create_internal( + iree_string_view_t identifier, + const iree_hal_cuda2_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + iree_hal_cuda2_driver_t* driver = NULL; + iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size; + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&driver)); + + iree_hal_resource_initialize(&iree_hal_cuda2_driver_vtable, + &driver->resource); + driver->host_allocator = host_allocator; + iree_string_view_append_to_buffer( + identifier, &driver->identifier, + (char*)driver + iree_sizeof_struct(*driver)); + driver->default_device_index = options->default_device_index; + + iree_status_t status = iree_hal_cuda2_dynamic_symbols_initialize( + host_allocator, &driver->cuda_symbols); + + if (iree_status_is_ok(status)) { + // Try to 
dynamically load NCCL. This will fail if NCCL is unavailable or + // incompatible. We only fail on unavailability when the user tries to + // create a channel and otherwise defer reporting. + status = iree_hal_cuda2_nccl_dynamic_symbols_initialize( + host_allocator, &driver->cuda_symbols, &driver->nccl_symbols); + if (iree_status_is_unavailable(status)) status = iree_status_ignore(status); + } + + if (iree_status_is_ok(status)) { + *out_driver = (iree_hal_driver_t*)driver; + } else { + iree_hal_driver_release((iree_hal_driver_t*)driver); + } + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create( + iree_string_view_t identifier, + const iree_hal_cuda2_driver_options_t* options, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(out_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_hal_cuda2_driver_create_internal( + identifier, options, host_allocator, out_driver); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_cuda2_driver_destroy(iree_hal_driver_t* base_driver) { + IREE_ASSERT_ARGUMENT(base_driver); + + iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); + iree_allocator_t host_allocator = driver->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_cuda2_nccl_dynamic_symbols_deinitialize(&driver->nccl_symbols); + iree_hal_cuda2_dynamic_symbols_deinitialize(&driver->cuda_symbols); + iree_allocator_free(host_allocator, driver); + + IREE_TRACE_ZONE_END(z0); +} + +// Initializes the CUDA system. +static iree_status_t iree_hal_cuda2_init(iree_hal_cuda2_driver_t* driver) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + IREE_CURESULT_TO_STATUS(&driver->cuda_symbols, cuInit(0), "cuInit"); + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Populates device information from the given CUDA physical device handle. 
+// |out_device_info| must point to valid memory and additional data will be +// appended to |buffer_ptr| and the new pointer is returned. +static iree_status_t iree_hal_cuda2_populate_device_info( + CUdevice device, iree_hal_cuda2_dynamic_symbols_t* syms, + uint8_t* buffer_ptr, uint8_t** out_buffer_ptr, + iree_hal_device_info_t* out_device_info) { + *out_buffer_ptr = buffer_ptr; + + char device_name[IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH]; + IREE_CUDA_RETURN_IF_ERROR( + syms, cuDeviceGetName(device_name, sizeof(device_name), device), + "cuDeviceGetName"); + memset(out_device_info, 0, sizeof(*out_device_info)); + out_device_info->device_id = IREE_CUDEVICE_TO_DEVICE_ID(device); + + // This matches the output of `nvidia-smi -L`. + CUuuid device_uuid; + IREE_CUDA_RETURN_IF_ERROR(syms, cuDeviceGetUuid(&device_uuid, device), + "cuDeviceGetUuid"); + char device_path_str[4 + 36 + 1] = {0}; + snprintf(device_path_str, sizeof(device_path_str), + "GPU-" + "%02x%02x%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x%02x%02x%02x%02x", + (uint8_t)device_uuid.bytes[0], (uint8_t)device_uuid.bytes[1], + (uint8_t)device_uuid.bytes[2], (uint8_t)device_uuid.bytes[3], + (uint8_t)device_uuid.bytes[4], (uint8_t)device_uuid.bytes[5], + (uint8_t)device_uuid.bytes[6], (uint8_t)device_uuid.bytes[7], + (uint8_t)device_uuid.bytes[8], (uint8_t)device_uuid.bytes[9], + (uint8_t)device_uuid.bytes[10], (uint8_t)device_uuid.bytes[11], + (uint8_t)device_uuid.bytes[12], (uint8_t)device_uuid.bytes[13], + (uint8_t)device_uuid.bytes[14], (uint8_t)device_uuid.bytes[15]); + buffer_ptr += iree_string_view_append_to_buffer( + iree_make_string_view(device_path_str, + IREE_ARRAYSIZE(device_path_str) - 1), + &out_device_info->path, (char*)buffer_ptr); + + iree_string_view_t device_name_str = + iree_make_string_view(device_name, strlen(device_name)); + buffer_ptr += iree_string_view_append_to_buffer( + device_name_str, &out_device_info->name, (char*)buffer_ptr); + + *out_buffer_ptr = buffer_ptr; + 
return iree_ok_status(); +} + +// Returns true if the device meets all the required capabilities. +static bool iree_hal_cuda2_is_valid_device(iree_hal_cuda2_driver_t* driver, + CUdevice device) { + return true; +} + +static iree_status_t iree_hal_cuda2_driver_query_available_devices( + iree_hal_driver_t* base_driver, iree_allocator_t host_allocator, + iree_host_size_t* out_device_info_count, + iree_hal_device_info_t** out_device_infos) { + IREE_ASSERT_ARGUMENT(base_driver); + IREE_ASSERT_ARGUMENT(out_device_info_count); + IREE_ASSERT_ARGUMENT(out_device_infos); + iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + // Ensure CUDA is initialized before querying it. + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_cuda2_init(driver)); + + // Query the number of available CUDA devices. + int device_count = 0; + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(z0, &driver->cuda_symbols, + cuDeviceGetCount(&device_count), + "cuDeviceGetCount"); + + // Allocate the return infos and populate with the devices. 
+ iree_hal_device_info_t* device_infos = NULL; + iree_host_size_t total_size = + device_count * (sizeof(iree_hal_device_info_t) + /*path=*/(4 + 36) + + IREE_HAL_CUDA_MAX_DEVICE_NAME_LENGTH * sizeof(char)); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos); + + int valid_device_count = 0; + if (iree_status_is_ok(status)) { + uint8_t* buffer_ptr = + (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t); + for (int i = 0; i < device_count; ++i) { + CUdevice device = 0; + status = IREE_CURESULT_TO_STATUS(&driver->cuda_symbols, + cuDeviceGet(&device, i), "cuDeviceGet"); + if (!iree_status_is_ok(status)) break; + if (!iree_hal_cuda2_is_valid_device(driver, device)) continue; + status = iree_hal_cuda2_populate_device_info( + device, &driver->cuda_symbols, buffer_ptr, &buffer_ptr, + &device_infos[valid_device_count]); + if (!iree_status_is_ok(status)) break; + valid_device_count++; + } + } + if (iree_status_is_ok(status)) { + *out_device_info_count = valid_device_count; + *out_device_infos = device_infos; + } else { + iree_allocator_free(host_allocator, device_infos); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_cuda2_driver_dump_device_info( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_string_builder_t* builder) { + iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); + CUdevice device = IREE_DEVICE_ID_TO_CUDEVICE(device_id); + if (device_id == IREE_HAL_DEVICE_ID_DEFAULT) return iree_ok_status(); + // TODO: dump detailed device info. 
+ (void)driver; + (void)device; + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda2_driver_select_default_device( + iree_hal_driver_t* base_driver, iree_hal_cuda2_dynamic_symbols_t* syms, + int default_device_index, iree_allocator_t host_allocator, + CUdevice* out_device) { + iree_hal_device_info_t* device_infos = NULL; + iree_host_size_t device_count = 0; + IREE_RETURN_IF_ERROR(iree_hal_cuda2_driver_query_available_devices( + base_driver, host_allocator, &device_count, &device_infos)); + + iree_status_t status = iree_ok_status(); + if (device_count == 0) { + status = iree_make_status(IREE_STATUS_UNAVAILABLE, + "no compatible CUDA devices were found"); + } else if (default_device_index >= device_count) { + status = iree_make_status( + IREE_STATUS_NOT_FOUND, "default device %d not found (of %" PRIhsz + " enumerated)", default_device_index, device_count); + } else { + *out_device = IREE_DEVICE_ID_TO_CUDEVICE( + device_infos[default_device_index].device_id); + } + iree_allocator_free(host_allocator, device_infos); + + return status; +} + +static iree_status_t iree_hal_cuda2_driver_create_device_by_id( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_host_size_t param_count, const iree_string_pair_t* params, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(base_driver); + IREE_ASSERT_ARGUMENT(params); + IREE_ASSERT_ARGUMENT(out_device); + + iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + // Ensure CUDA is initialized before querying it. + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_cuda2_init(driver)); + + // Use either the specified device (enumerated earlier) or whatever default + // one was specified when the driver was created. 
+ CUdevice device = 0; + if (device_id == IREE_HAL_DEVICE_ID_DEFAULT) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda2_driver_select_default_device( + base_driver, &driver->cuda_symbols, + driver->default_device_index, host_allocator, &device)); + } else { + device = IREE_DEVICE_ID_TO_CUDEVICE(device_id); + } + (void)device; + + IREE_TRACE_ZONE_END(z0); + return iree_status_from_code(IREE_STATUS_UNIMPLEMENTED); +} + +static iree_status_t iree_hal_cuda2_driver_create_device_by_uuid( + iree_hal_driver_t* base_driver, iree_string_view_t driver_name, + const CUuuid* device_uuid, iree_host_size_t param_count, + const iree_string_pair_t* params, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); + + // Ensure CUDA is initialized before querying it. + IREE_RETURN_IF_ERROR(iree_hal_cuda2_init(driver)); + + // CUDA doesn't have an API to do this so we need to scan all devices to + // find the one with the matching UUID. 
+ int device_count = 0; + IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols, + cuDeviceGetCount(&device_count), + "cuDeviceGetCount"); + CUdevice device = 0; + bool found_device = false; + for (int i = 0; i < device_count; i++) { + IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols, cuDeviceGet(&device, i), + "cuDeviceGet"); + CUuuid query_uuid; + IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols, + cuDeviceGetUuid(&query_uuid, device), + "cuDeviceGetUuid"); + if (memcmp(&device_uuid->bytes[0], &query_uuid.bytes[0], + sizeof(device_uuid->bytes)) == 0) { + found_device = true; + break; + } + } + if (!found_device) { + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "CUDA device with UUID GPU-" + "%02x%02x%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x%02x%02x%02x%02x" + " not found", + (uint8_t)device_uuid->bytes[0], (uint8_t)device_uuid->bytes[1], + (uint8_t)device_uuid->bytes[2], (uint8_t)device_uuid->bytes[3], + (uint8_t)device_uuid->bytes[4], (uint8_t)device_uuid->bytes[5], + (uint8_t)device_uuid->bytes[6], (uint8_t)device_uuid->bytes[7], + (uint8_t)device_uuid->bytes[8], (uint8_t)device_uuid->bytes[9], + (uint8_t)device_uuid->bytes[10], (uint8_t)device_uuid->bytes[11], + (uint8_t)device_uuid->bytes[12], (uint8_t)device_uuid->bytes[13], + (uint8_t)device_uuid->bytes[14], (uint8_t)device_uuid->bytes[15]); + } + + iree_status_t status = iree_hal_cuda2_driver_create_device_by_id( + base_driver, IREE_CUDEVICE_TO_DEVICE_ID(device), param_count, params, + host_allocator, out_device); + + return status; +} + +static iree_status_t iree_hal_cuda2_driver_create_device_by_index( + iree_hal_driver_t* base_driver, iree_string_view_t driver_name, + int device_index, iree_host_size_t param_count, + const iree_string_pair_t* params, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + iree_hal_cuda2_driver_t* driver = iree_hal_cuda2_driver_cast(base_driver); + + // Ensure CUDA is initialized before querying it. 
+ IREE_RETURN_IF_ERROR(iree_hal_cuda2_init(driver)); + + // Query the number of available CUDA devices. + int device_count = 0; + IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols, + cuDeviceGetCount(&device_count), + "cuDeviceGetCount"); + if (device_index >= device_count) { + return iree_make_status(IREE_STATUS_NOT_FOUND, + "device %d not found (of %d enumerated)", + device_index, device_count); + } + + CUdevice device = 0; + IREE_CUDA_RETURN_IF_ERROR(&driver->cuda_symbols, + cuDeviceGet(&device, device_index), "cuDeviceGet"); + + iree_status_t status = iree_hal_cuda2_driver_create_device_by_id( + base_driver, IREE_CUDEVICE_TO_DEVICE_ID(device), param_count, params, + host_allocator, out_device); + + return status; +} + +static iree_status_t iree_hal_cuda2_driver_create_device_by_path( + iree_hal_driver_t* base_driver, iree_string_view_t driver_name, + iree_string_view_t device_path, iree_host_size_t param_count, + const iree_string_pair_t* params, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(base_driver); + IREE_ASSERT_ARGUMENT(params); + IREE_ASSERT_ARGUMENT(out_device); + + if (iree_string_view_is_empty(device_path)) { + return iree_hal_cuda2_driver_create_device_by_id( + base_driver, IREE_HAL_DEVICE_ID_DEFAULT, param_count, params, + host_allocator, out_device); + } + + if (iree_string_view_consume_prefix(&device_path, IREE_SV("GPU-"))) { + // UUID as returned by cuDeviceGetUuid. + CUuuid device_uuid; + if (!iree_string_view_parse_hex_bytes(device_path, + IREE_ARRAYSIZE(device_uuid.bytes), + (uint8_t*)device_uuid.bytes)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid GPU UUID: '%.*s'", (int)device_path.size, + device_path.data); + } + return iree_hal_cuda2_driver_create_device_by_uuid( + base_driver, driver_name, &device_uuid, param_count, params, + host_allocator, out_device); + } + + // Try to parse as a device index. 
+ int device_index = 0; + if (iree_string_view_atoi_int32(device_path, &device_index)) { + return iree_hal_cuda2_driver_create_device_by_index( + base_driver, driver_name, device_index, param_count, params, + host_allocator, out_device); + } + + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path"); +} + +static const iree_hal_driver_vtable_t iree_hal_cuda2_driver_vtable = { + .destroy = iree_hal_cuda2_driver_destroy, + .query_available_devices = iree_hal_cuda2_driver_query_available_devices, + .dump_device_info = iree_hal_cuda2_driver_dump_device_info, + .create_device_by_id = iree_hal_cuda2_driver_create_device_by_id, + .create_device_by_path = iree_hal_cuda2_driver_create_device_by_path, +}; diff --git a/experimental/cuda2/cuda_headers.h b/experimental/cuda2/cuda_headers.h new file mode 100644 index 000000000000..6e3a6219d66a --- /dev/null +++ b/experimental/cuda2/cuda_headers.h @@ -0,0 +1,13 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_EXPERIMENTAL_CUDA2_CUDA_HEADERS_H_ +#define IREE_EXPERIMENTAL_CUDA2_CUDA_HEADERS_H_ + +#include "cuda.h" // IWYU pragma: export +#include "third_party/nccl/nccl.h" // IWYU pragma: export + +#endif // IREE_EXPERIMENTAL_CUDA2_CUDA_HEADERS_H_ diff --git a/experimental/cuda2/dynamic_symbol_tables.h b/experimental/cuda2/dynamic_symbol_tables.h new file mode 100644 index 000000000000..f16031b432d1 --- /dev/null +++ b/experimental/cuda2/dynamic_symbol_tables.h @@ -0,0 +1,116 @@ +// Copyright 2021 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// CUDA symbols +//===----------------------------------------------------------------------===// +IREE_CU_PFN_DECL(cuCtxCreate, CUcontext*, unsigned int, CUdevice) +IREE_CU_PFN_DECL(cuCtxDestroy, CUcontext) +IREE_CU_PFN_DECL(cuDevicePrimaryCtxRetain, CUcontext*, CUdevice) +IREE_CU_PFN_DECL(cuDevicePrimaryCtxRelease, CUdevice) +IREE_CU_PFN_DECL(cuCtxSetCurrent, CUcontext) +IREE_CU_PFN_DECL(cuCtxPushCurrent, CUcontext) +IREE_CU_PFN_DECL(cuCtxPopCurrent, CUcontext*) +IREE_CU_PFN_DECL(cuDeviceGet, CUdevice*, int) +IREE_CU_PFN_DECL(cuDeviceGetCount, int*) +IREE_CU_PFN_DECL(cuDeviceGetName, char*, int, CUdevice) +IREE_CU_PFN_DECL(cuDeviceGetAttribute, int*, CUdevice_attribute, CUdevice) +IREE_CU_PFN_DECL(cuDeviceGetUuid, CUuuid*, CUdevice) +IREE_CU_PFN_DECL(cuEventCreate, CUevent*, unsigned int) +IREE_CU_PFN_DECL(cuEventDestroy, CUevent) +IREE_CU_PFN_DECL(cuEventElapsedTime, float*, CUevent, CUevent) +IREE_CU_PFN_DECL(cuEventQuery, CUevent) +IREE_CU_PFN_DECL(cuEventRecord, CUevent, CUstream) +IREE_CU_PFN_DECL(cuEventSynchronize, CUevent) +IREE_CU_PFN_DECL(cuGetErrorName, CUresult, const char**) +IREE_CU_PFN_DECL(cuGetErrorString, CUresult, const char**) +IREE_CU_PFN_DECL(cuGraphAddMemcpyNode, CUgraphNode*, CUgraph, + const CUgraphNode*, size_t, const CUDA_MEMCPY3D*, CUcontext) +IREE_CU_PFN_DECL(cuGraphAddMemsetNode, CUgraphNode*, CUgraph, + const CUgraphNode*, size_t, const CUDA_MEMSET_NODE_PARAMS*, + CUcontext) +IREE_CU_PFN_DECL(cuGraphAddKernelNode, CUgraphNode*, CUgraph, + const CUgraphNode*, size_t, const CUDA_KERNEL_NODE_PARAMS*) +IREE_CU_PFN_DECL(cuGraphCreate, CUgraph*, unsigned int) +IREE_CU_PFN_DECL(cuGraphDestroy, CUgraph) +IREE_CU_PFN_DECL(cuGraphExecDestroy, CUgraphExec) +IREE_CU_PFN_DECL(cuGraphGetNodes, CUgraph, CUgraphNode*, size_t*) +IREE_CU_PFN_DECL(cuGraphInstantiate, CUgraphExec*, CUgraph, CUgraphNode*, 
char*, + size_t) +IREE_CU_PFN_DECL(cuGraphLaunch, CUgraphExec, CUstream) +IREE_CU_PFN_DECL(cuInit, unsigned int) +IREE_CU_PFN_DECL(cuMemAllocManaged, CUdeviceptr*, size_t, unsigned int) +IREE_CU_PFN_DECL(cuMemPrefetchAsync, CUdeviceptr, size_t, CUdevice, CUstream) +IREE_CU_PFN_DECL(cuMemAlloc, CUdeviceptr*, size_t) +IREE_CU_PFN_DECL(cuMemFree, CUdeviceptr) +IREE_CU_PFN_DECL(cuMemFreeHost, void*) +IREE_CU_PFN_DECL(cuMemHostAlloc, void**, size_t, unsigned int) +IREE_CU_PFN_DECL(cuMemHostRegister, void*, size_t, unsigned int) +IREE_CU_PFN_DECL(cuMemHostUnregister, void*) +IREE_CU_PFN_DECL(cuMemHostGetDevicePointer, CUdeviceptr*, void*, unsigned int) +IREE_CU_PFN_DECL(cuModuleGetFunction, CUfunction*, CUmodule, const char*) +IREE_CU_PFN_DECL(cuModuleLoadDataEx, CUmodule*, const void*, unsigned int, + CUjit_option*, void**) +IREE_CU_PFN_DECL(cuModuleUnload, CUmodule) +IREE_CU_PFN_DECL(cuStreamCreate, CUstream*, unsigned int) +IREE_CU_PFN_DECL(cuStreamDestroy, CUstream) +IREE_CU_PFN_DECL(cuStreamSynchronize, CUstream) +IREE_CU_PFN_DECL(cuStreamWaitEvent, CUstream, CUevent, unsigned int) +IREE_CU_PFN_DECL(cuMemsetD32Async, unsigned long long, unsigned int, size_t, + CUstream) +IREE_CU_PFN_DECL(cuMemsetD16Async, unsigned long long, unsigned short, size_t, + CUstream) +IREE_CU_PFN_DECL(cuMemsetD8Async, unsigned long long, unsigned char, size_t, + CUstream) +IREE_CU_PFN_DECL(cuMemcpyAsync, CUdeviceptr, CUdeviceptr, size_t, CUstream) +IREE_CU_PFN_DECL(cuMemcpyHtoDAsync_v2, CUdeviceptr, const void*, size_t, + CUstream) +IREE_CU_PFN_DECL(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int) +IREE_CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, CUstream, void**, void**) + +//===----------------------------------------------------------------------===// +// NCCL symbols +//===----------------------------------------------------------------------===// 
+IREE_NCCL_PFN_DECL(ncclGetVersion, int*)
+IREE_NCCL_PFN_DECL(ncclGetUniqueId, ncclUniqueId*)
+IREE_NCCL_PFN_DECL(ncclCommInitRankConfig, ncclComm_t*, int, ncclUniqueId, int,
+                   ncclConfig_t*)
+IREE_NCCL_PFN_DECL(ncclCommInitRank, ncclComm_t*, int, ncclUniqueId, int)
+IREE_NCCL_PFN_DECL(ncclCommInitAll, ncclComm_t*, int, const int*)
+IREE_NCCL_PFN_DECL(ncclCommSplit, ncclComm_t, int, int, ncclComm_t*,
+                   ncclConfig_t*)
+IREE_NCCL_PFN_DECL(ncclCommFinalize, ncclComm_t)
+IREE_NCCL_PFN_DECL(ncclCommDestroy, ncclComm_t)
+IREE_NCCL_PFN_DECL(ncclCommAbort, ncclComm_t)
+IREE_NCCL_PFN_DECL_STR_RETURN(ncclGetErrorString, ncclResult_t)
+IREE_NCCL_PFN_DECL_STR_RETURN(ncclGetLastError, ncclComm_t)
+IREE_NCCL_PFN_DECL(ncclCommGetAsyncError, ncclComm_t, ncclResult_t*)
+IREE_NCCL_PFN_DECL(ncclCommCount, const ncclComm_t, int*)
+IREE_NCCL_PFN_DECL(ncclCommCuDevice, const ncclComm_t, int*)
+IREE_NCCL_PFN_DECL(ncclCommUserRank, const ncclComm_t, int*)
+IREE_NCCL_PFN_DECL(ncclRedOpCreatePreMulSum, ncclRedOp_t*, void*,
+                   ncclDataType_t, ncclScalarResidence_t, ncclComm_t)
+IREE_NCCL_PFN_DECL(ncclRedOpDestroy, ncclRedOp_t, ncclComm_t)
+IREE_NCCL_PFN_DECL(ncclReduce, const void*, void*, size_t, ncclDataType_t,
+                   ncclRedOp_t, int, ncclComm_t, cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclBcast, void*, size_t, ncclDataType_t, int, ncclComm_t,
+                   cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclBroadcast, const void*, void*, size_t, ncclDataType_t,
+                   int, ncclComm_t, cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclAllReduce, const void*, void*, size_t, ncclDataType_t,
+                   ncclRedOp_t, ncclComm_t, cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclReduceScatter, const void*, void*, size_t,
+                   ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclAllGather, const void*, void*, size_t, ncclDataType_t,
+                   ncclComm_t, cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclSend, const void*, size_t, ncclDataType_t, int,
+                   ncclComm_t, cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclRecv, void*, size_t, ncclDataType_t, int, ncclComm_t,
+                   cudaStream_t)
+IREE_NCCL_PFN_DECL(ncclGroupStart)
+IREE_NCCL_PFN_DECL(ncclGroupEnd)
diff --git a/experimental/cuda2/dynamic_symbols.c b/experimental/cuda2/dynamic_symbols.c
new file mode 100644
index 000000000000..6321bafa413b
--- /dev/null
+++ b/experimental/cuda2/dynamic_symbols.c
@@ -0,0 +1,229 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/dynamic_symbols.h"
+
+#include <string.h>
+
+#include "experimental/cuda2/status_util.h"
+#include "iree/base/assert.h"
+#include "iree/base/internal/dynamic_library.h"
+#include "iree/base/target_platform.h"
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// CUDA dynamic symbols
+//===----------------------------------------------------------------------===//
+
+// Candidate file names for the CUDA driver library, tried by the loader.
+static const char* iree_hal_cuda_dylib_names[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "nvcuda.dll",
+#else
+    "libcuda.so",
+#endif  // IREE_PLATFORM_WINDOWS
+};
+
+// Joins two adjacent string literals (used to build the "<symbol>_v2" name).
+#define IREE_CONCAT(A, B) A B
+
+// Resolves all CUDA dynamic symbols in `dynamic_symbol_tables.h`, prefer _v2
+// version if it exists.
+//
+// The base symbol is required: a failed lookup is returned to the caller.
+// The "_v2" symbol is optional: its lookup status is explicitly ignored and
+// the result pointer starts out as NULL so that a failed lookup can never
+// leave an indeterminate value to be tested or stored.
+static iree_status_t iree_hal_cuda2_dynamic_symbols_resolve_all(
+    iree_hal_cuda2_dynamic_symbols_t* syms) {
+#define IREE_CU_PFN_DECL(cuda_symbol_name, ...)                         \
+  {                                                                     \
+    static const char* name = #cuda_symbol_name;                        \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(            \
+        syms->dylib, name, (void**)&syms->cuda_symbol_name));           \
+    static const char* name_v2 = IREE_CONCAT(#cuda_symbol_name, "_v2"); \
+    void* fptr_v2 = NULL;                                               \
+    iree_status_ignore(iree_dynamic_library_lookup_symbol(              \
+        syms->dylib, name_v2, &fptr_v2));                               \
+    if (fptr_v2) syms->cuda_symbol_name = fptr_v2;                      \
+  }
+// Ignore NCCL symbols
+#define IREE_NCCL_PFN_DECL(nccl_symbol_name, ...)
+#define IREE_NCCL_PFN_DECL_STR_RETURN(nccl_symbol_name, ...)
+#include "experimental/cuda2/dynamic_symbol_tables.h"  // IWYU pragma: keep
+#undef IREE_CU_PFN_DECL
+#undef IREE_NCCL_PFN_DECL
+#undef IREE_NCCL_PFN_DECL_STR_RETURN
+  return iree_ok_status();
+}
+
+#undef IREE_CONCAT
+
+iree_status_t iree_hal_cuda2_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_cuda2_dynamic_symbols_t* out_syms) {
+  IREE_ASSERT_ARGUMENT(out_syms);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  memset(out_syms, 0, sizeof(*out_syms));
+  iree_status_t status = iree_dynamic_library_load_from_files(
+      IREE_ARRAYSIZE(iree_hal_cuda_dylib_names), iree_hal_cuda_dylib_names,
+      IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->dylib);
+  if (iree_status_is_not_found(status)) {
+    // Replace the generic not-found status with an actionable message.
+    iree_status_ignore(status);
+    status = iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "CUDA driver library 'libcuda.so'/'nvcuda.dll' not available; please "
+        "ensure installed and in dynamic library search path");
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda2_dynamic_symbols_resolve_all(out_syms);
+  }
+  if (!iree_status_is_ok(status)) {
+    // Release the library and zero |out_syms| on any failure.
+    iree_hal_cuda2_dynamic_symbols_deinitialize(out_syms);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_cuda2_dynamic_symbols_deinitialize(
+    iree_hal_cuda2_dynamic_symbols_t* syms) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_dynamic_library_release(syms->dylib);
+  memset(syms, 0, sizeof(*syms));
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+//===----------------------------------------------------------------------===//
+// NCCL dynamic symbols
+//===----------------------------------------------------------------------===//
+
+// Candidate file names for the NCCL library, tried by the loader.
+static const char* iree_hal_cuda_nccl_dylib_names[] = {
+#if defined(IREE_PLATFORM_WINDOWS)
+    "nccl.dll",
+#else
+    "libnccl.so",
+#endif  // IREE_PLATFORM_WINDOWS
+};
+
+// Resolves all NCCL dynamic symbols in `dynamic_symbol_tables.h`; fails if any
+// of the symbols cannot be found.
+static iree_status_t iree_hal_cuda2_nccl_dynamic_symbols_resolve_all(
+    iree_hal_cuda2_nccl_dynamic_symbols_t* syms) {
+#define IREE_NCCL_PFN_DECL(nccl_symbol_name, ...)                \
+  {                                                              \
+    static const char* name = #nccl_symbol_name;                 \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(     \
+        syms->dylib, name, (void**)&syms->nccl_symbol_name));    \
+  }
+#define IREE_NCCL_PFN_DECL_STR_RETURN(nccl_symbol_name, ...)     \
+  {                                                              \
+    static const char* name = #nccl_symbol_name;                 \
+    IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol(     \
+        syms->dylib, name, (void**)&syms->nccl_symbol_name));    \
+  }
+// Ignore CUDA symbols
+#define IREE_CU_PFN_DECL(cuda_symbol_name, ...)
+#include "experimental/cuda2/dynamic_symbol_tables.h"  // IWYU pragma: keep
+#undef IREE_NCCL_PFN_DECL
+#undef IREE_NCCL_PFN_DECL_STR_RETURN
+#undef IREE_CU_PFN_DECL
+  return iree_ok_status();
+}
+
+// Verifies that the loaded NCCL library reports exactly the version of the
+// NCCL headers this file was compiled against (NCCL_MAJOR/MINOR/PATCH).
+// Returns IREE_STATUS_UNAVAILABLE if `ncclGetVersion` cannot be found, if the
+// call fails, or if the versions differ.
+static iree_status_t iree_hal_cuda2_nccl_check_version(
+    iree_dynamic_library_t* nccl_library) {
+  ncclResult_t (*ncclGetVersion)(int*) = NULL;
+
+  iree_status_t status = iree_dynamic_library_lookup_symbol(
+      nccl_library, "ncclGetVersion", (void**)&ncclGetVersion);
+  if (!iree_status_is_ok(status)) {
+    iree_status_ignore(status);
+    return iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "ncclGetVersion symbol not found in dynamic library");
+  }
+
+  // Check the NCCL version compatibility.
+  int nccl_version = 0;
+  ncclResult_t result = ncclGetVersion(&nccl_version);
+  if (result != ncclSuccess) {
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "ncclGetVersion() failed with error %d", result);
+  }
+
+  // Decode the packed version code. Codes below 20000 use a
+  // major*1000 + minor*100 + patch layout; larger codes use
+  // major*10000 + minor*100 + patch.
+  int major = 0;
+  int minor = 0;
+  int patch = 0;
+  if (nccl_version < 20000) {
+    major = nccl_version / 1000;
+    minor = (nccl_version % 1000) / 100;
+  } else {
+    major = nccl_version / 10000;
+    minor = (nccl_version % 10000) / 100;
+  }
+  patch = nccl_version % 100;
+  if (major != NCCL_MAJOR || minor != NCCL_MINOR || patch != NCCL_PATCH) {
+    return iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "NCCL version is %d.%d.%d, but %d.%d.%d is required", major, minor,
+        patch, NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH);
+  }
+
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_cuda2_nccl_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    const iree_hal_cuda2_dynamic_symbols_t* cuda_library,
+    iree_hal_cuda2_nccl_dynamic_symbols_t* out_syms) {
+  // |cuda_library| is dereferenced right below, so validate it like |out_syms|.
+  IREE_ASSERT_ARGUMENT(cuda_library);
+  IREE_ASSERT_ARGUMENT(out_syms);
+  if (!cuda_library->dylib) {
+    return iree_make_status(
+        IREE_STATUS_FAILED_PRECONDITION,
+        "CUDA dynamic symbols must be resolved prior to loading NCCL symbols");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  memset(out_syms, 0, sizeof(*out_syms));
+  iree_status_t status = iree_dynamic_library_load_from_files(
+      IREE_ARRAYSIZE(iree_hal_cuda_nccl_dylib_names),
+      iree_hal_cuda_nccl_dylib_names, IREE_DYNAMIC_LIBRARY_FLAG_NONE,
+      host_allocator, &out_syms->dylib);
+  if (iree_status_is_not_found(status)) {
+    // Replace the generic not-found status with an actionable message.
+    iree_status_ignore(status);
+    status = iree_make_status(
+        IREE_STATUS_UNAVAILABLE,
+        "NCCL runtime library 'libnccl.so'/'nccl.dll' (version %d.%d.%d) not "
+        "available; please ensure installed and in dynamic library search path",
+        NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH);
+  }
+
+  if (iree_status_is_ok(status)) {
+    // Check the version first before resolving all symbols. This makes sure
+    // that we have the right version and all symbols are available at the
+    // time of resolving.
+    status = iree_hal_cuda2_nccl_check_version(out_syms->dylib);
+  }
+
+  // Resolve all symbols; this will fail if any required symbols are missing.
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_cuda2_nccl_dynamic_symbols_resolve_all(out_syms);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    // Release the library and clear the whole table: symbol resolution may
+    // have filled in some function pointers before failing, and those must
+    // not outlive the library they point into.
+    iree_dynamic_library_release(out_syms->dylib);
+    memset(out_syms, 0, sizeof(*out_syms));
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_cuda2_nccl_dynamic_symbols_deinitialize(
+    iree_hal_cuda2_nccl_dynamic_symbols_t* syms) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_dynamic_library_release(syms->dylib);
+  memset(syms, 0, sizeof(*syms));
+
+  IREE_TRACE_ZONE_END(z0);
+}
diff --git a/experimental/cuda2/dynamic_symbols.h b/experimental/cuda2/dynamic_symbols.h
new file mode 100644
index 000000000000..a9e739c8def4
--- /dev/null
+++ b/experimental/cuda2/dynamic_symbols.h
@@ -0,0 +1,97 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_CUDA2_DYNAMIC_SYMBOLS_H_
+#define IREE_EXPERIMENTAL_CUDA2_DYNAMIC_SYMBOLS_H_
+
+#include "experimental/cuda2/cuda_headers.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/dynamic_library.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// iree_dynamic_library_t allows dynamically loading a subset of the CUDA driver
+// API and the NCCL API. We load all the symbols in `dynamic_symbol_tables.h`
+// and fail if any of the symbols is not available. The function signatures
+// match the declarations in `cuda.h` and `nccl.h`.
+
+//===----------------------------------------------------------------------===//
+// CUDA dynamic symbols
+//===----------------------------------------------------------------------===//
+
+// CUDA driver API dynamic symbols.
+typedef struct iree_hal_cuda2_dynamic_symbols_t {
+  // The dynamic library handle.
+  iree_dynamic_library_t* dylib;
+
+  // Concrete CUDA symbols defined by including the `dynamic_symbol_tables.h`.
+  // Each table entry becomes one function pointer member returning CUresult.
+#define IREE_CU_PFN_DECL(cudaSymbolName, ...) \
+  CUresult (*cudaSymbolName)(__VA_ARGS__);
+// Ignore NCCL symbols
+#define IREE_NCCL_PFN_DECL(ncclSymbolName, ...)
+#define IREE_NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...)
+#include "experimental/cuda2/dynamic_symbol_tables.h"  // IWYU pragma: export
+#undef IREE_CU_PFN_DECL
+#undef IREE_NCCL_PFN_DECL
+#undef IREE_NCCL_PFN_DECL_STR_RETURN
+} iree_hal_cuda2_dynamic_symbols_t;
+
+// Initializes |out_syms| in-place with dynamically loaded CUDA symbols.
+// iree_hal_cuda2_dynamic_symbols_deinitialize must be used to release the
+// library resources.
+iree_status_t iree_hal_cuda2_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_cuda2_dynamic_symbols_t* out_syms);
+
+// Deinitializes |syms| by unloading the backing library. All function pointers
+// will be invalidated. They _may_ still work if there are other reasons the
+// library remains loaded so be careful.
+void iree_hal_cuda2_dynamic_symbols_deinitialize(
+    iree_hal_cuda2_dynamic_symbols_t* syms);
+
+//===----------------------------------------------------------------------===//
+// NCCL dynamic symbols
+//===----------------------------------------------------------------------===//
+
+// NCCL API dynamic symbols.
+typedef struct iree_hal_cuda2_nccl_dynamic_symbols_t {
+  // The dynamic library handle.
+  iree_dynamic_library_t* dylib;
+
+  // Concrete NCCL symbols defined by including the `dynamic_symbol_tables.h`.
+  // Regular entries return ncclResult_t; the *_STR_RETURN entries return a
+  // `const char*` instead.
+#define IREE_NCCL_PFN_DECL(ncclSymbolName, ...) \
+  ncclResult_t (*ncclSymbolName)(__VA_ARGS__);
+#define IREE_NCCL_PFN_DECL_STR_RETURN(ncclSymbolName, ...) \
+  const char* (*ncclSymbolName)(__VA_ARGS__);
+// Ignore CUDA symbols
+#define IREE_CU_PFN_DECL(cudaSymbolName, ...)
+#include "experimental/cuda2/dynamic_symbol_tables.h"  // IWYU pragma: export
+#undef IREE_NCCL_PFN_DECL
+#undef IREE_NCCL_PFN_DECL_STR_RETURN
+#undef IREE_CU_PFN_DECL
+} iree_hal_cuda2_nccl_dynamic_symbols_t;
+
+// Initializes |out_syms| in-place with dynamically loaded NCCL symbols.
+// iree_hal_cuda2_nccl_dynamic_symbols_deinitialize must be used to release the
+// library resources.
+iree_status_t iree_hal_cuda2_nccl_dynamic_symbols_initialize(
+    iree_allocator_t host_allocator,
+    const iree_hal_cuda2_dynamic_symbols_t* cuda_library,
+    iree_hal_cuda2_nccl_dynamic_symbols_t* out_syms);
+
+// Deinitializes |syms| by unloading the backing library. All function pointers
+// will be invalidated. They _may_ still work if there are other reasons the
+// library remains loaded so be careful.
+void iree_hal_cuda2_nccl_dynamic_symbols_deinitialize(
+    iree_hal_cuda2_nccl_dynamic_symbols_t* syms);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_CUDA2_DYNAMIC_SYMBOLS_H_
diff --git a/experimental/cuda2/dynamic_symbols_test.cc b/experimental/cuda2/dynamic_symbols_test.cc
new file mode 100644
index 000000000000..39e19b480f51
--- /dev/null
+++ b/experimental/cuda2/dynamic_symbols_test.cc
@@ -0,0 +1,84 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/dynamic_symbols.h"
+
+#include <iostream>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace cuda {
+namespace {
+
+// Asserts that a CUDA driver call returned CUDA_SUCCESS.
+#define CUDA_CHECK_ERRORS(expr)      \
+  {                                  \
+    CUresult status = expr;          \
+    ASSERT_EQ(CUDA_SUCCESS, status); \
+  }
+
+// Loads the CUDA symbols and issues a few basic driver calls through them.
+// Skipped when the CUDA driver library cannot be loaded.
+TEST(DynamicSymbolsTest, CreateFromSystemLoader) {
+  iree_hal_cuda2_dynamic_symbols_t symbols;
+  iree_status_t status = iree_hal_cuda2_dynamic_symbols_initialize(
+      iree_allocator_system(), &symbols);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_ignore(status);
+    std::cerr << "Symbols cannot be loaded, skipping test.";
+    GTEST_SKIP();
+  }
+
+  int device_count = 0;
+  CUDA_CHECK_ERRORS(symbols.cuInit(0));
+  CUDA_CHECK_ERRORS(symbols.cuDeviceGetCount(&device_count));
+  if (device_count > 0) {
+    CUdevice device;
+    CUDA_CHECK_ERRORS(symbols.cuDeviceGet(&device, /*ordinal=*/0));
+  }
+
+  iree_hal_cuda2_dynamic_symbols_deinitialize(&symbols);
+}
+
+// Asserts that an NCCL call returned ncclSuccess.
+#define NCCL_CHECK_ERRORS(expr)     \
+  {                                 \
+    ncclResult_t status = expr;     \
+    ASSERT_EQ(ncclSuccess, status); \
+  }
+
+// Loads the CUDA and NCCL symbols and checks the reported NCCL version.
+// Skipped when either library cannot be loaded.
+TEST(NCCLDynamicSymbolsTest, CreateFromSystemLoader) {
+  iree_hal_cuda2_dynamic_symbols_t cuda_symbols;
+  iree_status_t status = iree_hal_cuda2_dynamic_symbols_initialize(
+      iree_allocator_system(), &cuda_symbols);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_ignore(status);
+    std::cerr << "CUDA symbols cannot be loaded, skipping test.";
+    GTEST_SKIP();
+  }
+
+  iree_hal_cuda2_nccl_dynamic_symbols_t nccl_symbols;
+  status = iree_hal_cuda2_nccl_dynamic_symbols_initialize(
+      iree_allocator_system(), &cuda_symbols, &nccl_symbols);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_ignore(status);
+    std::cerr << "CUDA NCCL symbols cannot be loaded, skipping test.";
+    // The CUDA symbols were loaded successfully above; release them before
+    // skipping so the library handle is not leaked.
+    iree_hal_cuda2_dynamic_symbols_deinitialize(&cuda_symbols);
+    GTEST_SKIP();
+  }
+
+  int nccl_version = 0;
+  NCCL_CHECK_ERRORS(nccl_symbols.ncclGetVersion(&nccl_version));
+  ASSERT_EQ(NCCL_VERSION_CODE, nccl_version);
+  iree_hal_cuda2_nccl_dynamic_symbols_deinitialize(&nccl_symbols);
+  iree_hal_cuda2_dynamic_symbols_deinitialize(&cuda_symbols);
+}
+
+}  // namespace
+}  // namespace cuda
+}  // namespace hal
+}  // namespace iree
diff --git a/experimental/cuda2/registration/CMakeLists.txt b/experimental/cuda2/registration/CMakeLists.txt
new file mode 100644
index 000000000000..54372e960593
--- /dev/null
+++ b/experimental/cuda2/registration/CMakeLists.txt
@@ -0,0 +1,23 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+  NAME
+    registration
+  HDRS
+    "driver_module.h"
+  SRCS
+    "driver_module.c"
+  DEPS
+    iree::base
+    iree::base::core_headers
+    iree::base::tracing
+    iree::experimental::cuda2
+    iree::hal
+  DEFINES
+    "IREE_HAVE_HAL_CUDA2_DRIVER_MODULE=1"
+  PUBLIC
+)
diff --git a/experimental/cuda2/registration/driver_module.c b/experimental/cuda2/registration/driver_module.c
new file mode 100644
index 000000000000..4089ed0d1165
--- /dev/null
+++ b/experimental/cuda2/registration/driver_module.c
@@ -0,0 +1,106 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/registration/driver_module.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include "experimental/cuda2/api.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+
+IREE_FLAG(int32_t, cuda2_default_index, 0,
+          "Specifies the index of the default CUDA device to use");
+
+IREE_FLAG(bool, cuda2_default_index_from_mpi, true,
+          "Infers the default CUDA device index from the PMI_RANK or\n"
+          "OMPI_COMM_WORLD_LOCAL_RANK environment variables when set");
+
+// Reads environment variable |var_name| and parses it as an int32.
+// Returns false if the variable is unset, empty, or fails to parse.
+static bool iree_try_parse_env_i32(const char* var_name, int32_t* out_value) {
+  const char* var_value = getenv(var_name);
+  if (!var_value || strlen(var_value) == 0) return false;
+  return iree_string_view_atoi_int32(iree_make_cstring_view(var_value),
+                                     out_value);
+}
+
+// Tries to infer the device index using the local MPI rank from environment
+// variables; otherwise returns |default_index|.
+//
+// This makes it easy to use N devices on a single system when running via
+// `mpiexec`.
+static int32_t iree_hal_cuda2_infer_device_index_from_env(
+    int32_t default_index) {
+  // TODO: try more env vars from other implementations. This covers Intel/MS
+  // and OpenMPI today.
+  int32_t result = 0;
+  if (iree_try_parse_env_i32("PMI_RANK", &result) ||
+      iree_try_parse_env_i32("OMPI_COMM_WORLD_LOCAL_RANK", &result)) {
+    return result;
+  }
+  return default_index;
+}
+
+// Reports the single driver ("cuda2") provided by this factory.
+static iree_status_t iree_hal_cuda2_driver_factory_enumerate(
+    void* self, iree_host_size_t* out_driver_info_count,
+    const iree_hal_driver_info_t** out_driver_infos) {
+  IREE_ASSERT_ARGUMENT(out_driver_info_count);
+  IREE_ASSERT_ARGUMENT(out_driver_infos);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  static const iree_hal_driver_info_t driver_infos[1] = {{
+      .driver_name = IREE_SVL("cuda2"),
+      .full_name = IREE_SVL("next-gen NVIDIA CUDA HAL driver (via dylib)"),
+  }};
+  *out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
+  *out_driver_infos = driver_infos;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Creates the "cuda2" driver. The default device index comes from the
+// --cuda2_default_index flag and, when --cuda2_default_index_from_mpi is set,
+// may be overridden by the MPI rank environment variables.
+// Returns IREE_STATUS_UNAVAILABLE for any other |driver_name|.
+static iree_status_t iree_hal_cuda2_driver_factory_try_create(
+    void* self, iree_string_view_t driver_name, iree_allocator_t host_allocator,
+    iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(out_driver);
+
+  if (!iree_string_view_equal(driver_name, IREE_SV("cuda2"))) {
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "no driver '%.*s' is provided by this factory",
+                            (int)driver_name.size, driver_name.data);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_cuda2_driver_options_t driver_options;
+  iree_hal_cuda2_driver_options_initialize(&driver_options);
+
+  driver_options.default_device_index = FLAG_cuda2_default_index;
+  if (FLAG_cuda2_default_index_from_mpi) {
+    driver_options.default_device_index =
+        iree_hal_cuda2_infer_device_index_from_env(
+            driver_options.default_device_index);
+  }
+
+  iree_status_t status = iree_hal_cuda2_driver_create(
+      driver_name, &driver_options, host_allocator, out_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_cuda2_driver_module_register(iree_hal_driver_registry_t* registry) {
+  static const iree_hal_driver_factory_t factory = {
+      .self = NULL,
+      .enumerate = iree_hal_cuda2_driver_factory_enumerate,
+      .try_create = iree_hal_cuda2_driver_factory_try_create,
+  };
+  return iree_hal_driver_registry_register_factory(registry, &factory);
+}
diff --git a/experimental/cuda2/registration/driver_module.h b/experimental/cuda2/registration/driver_module.h
new file mode 100644
index 000000000000..c92643da78cd
--- /dev/null
+++ b/experimental/cuda2/registration/driver_module.h
@@ -0,0 +1,25 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_CUDA2_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_EXPERIMENTAL_CUDA2_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Registers the CUDA HAL driver to the given |registry|.
+IREE_API_EXPORT iree_status_t
+iree_hal_cuda2_driver_module_register(iree_hal_driver_registry_t* registry);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_CUDA2_REGISTRATION_DRIVER_MODULE_H_
diff --git a/experimental/cuda2/status_util.c b/experimental/cuda2/status_util.c
new file mode 100644
index 000000000000..c23be647579e
--- /dev/null
+++ b/experimental/cuda2/status_util.c
@@ -0,0 +1,185 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/cuda2/status_util.h"
+
+#include <string.h>
+
+#include "experimental/cuda2/dynamic_symbols.h"
+#include "iree/base/status.h"
+
+//===----------------------------------------------------------------------===//
+// CUDA result utilities
+//===----------------------------------------------------------------------===//
+
+// The list of CUDA error strings with their corresponding IREE error state
+// classification.
+//
+// Note that the list of errors is taken from `cudaError_enum` in cuda.h.
+// This is not an exhaustive list; we are just listing common ones here.
+//
+// The list is an X-macro: |IREE_CUDA_MAP_ERROR| is invoked once per entry with
+// the CUDA error name (a string literal) and the IREE status code.
+#define IREE_CUDA_ERROR_LIST(IREE_CUDA_MAP_ERROR)                              \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_VALUE",                              \
+                      IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_OUT_OF_MEMORY",                              \
+                      IREE_STATUS_RESOURCE_EXHAUSTED)                          \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_INITIALIZED", IREE_STATUS_INTERNAL)      \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_DEINITIALIZED", IREE_STATUS_INTERNAL)        \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_STUB_LIBRARY",                               \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NO_DEVICE", IREE_STATUS_FAILED_PRECONDITION) \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_DEVICE",                             \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_DEVICE_NOT_LICENSED",                        \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_IMAGE",                              \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_CONTEXT", IREE_STATUS_INTERNAL)      \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_MAP_FAILED", IREE_STATUS_INTERNAL)           \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_UNMAP_FAILED", IREE_STATUS_INTERNAL)         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_ALREADY_MAPPED", IREE_STATUS_ALREADY_EXISTS) \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NO_BINARY_FOR_GPU",                          \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_ALREADY_ACQUIRED",                           \
+                      IREE_STATUS_ALREADY_EXISTS)                              \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_MAPPED", IREE_STATUS_INTERNAL)           \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_MAPPED_AS_ARRAY", IREE_STATUS_INTERNAL)  \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_MAPPED_AS_POINTER",                      \
+                      IREE_STATUS_INTERNAL)                                    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_UNSUPPORTED_LIMIT",                          \
+                      IREE_STATUS_OUT_OF_RANGE)                                \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_CONTEXT_ALREADY_IN_USE",                     \
+                      IREE_STATUS_INTERNAL)                                    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_PTX",                                \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NVLINK_UNCORRECTABLE",                       \
+                      IREE_STATUS_DATA_LOSS)                                   \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_JIT_COMPILER_NOT_FOUND",                     \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_UNSUPPORTED_PTX_VERSION",                    \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_JIT_COMPILATION_DISABLED",                   \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY",                  \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_SOURCE",                             \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_FILE_NOT_FOUND", IREE_STATUS_NOT_FOUND)      \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND",             \
+                      IREE_STATUS_NOT_FOUND)                                   \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_SHARED_OBJECT_INIT_FAILED",                  \
+                      IREE_STATUS_INTERNAL)                                    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_OPERATING_SYSTEM", IREE_STATUS_INTERNAL)     \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_HANDLE",                             \
+                      IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_ILLEGAL_STATE", IREE_STATUS_INTERNAL)        \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_FOUND", IREE_STATUS_NOT_FOUND)           \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_READY", IREE_STATUS_UNAVAILABLE)         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_ILLEGAL_ADDRESS", IREE_STATUS_ABORTED)       \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES",                    \
+                      IREE_STATUS_RESOURCE_EXHAUSTED)                          \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_LAUNCH_TIMEOUT",                             \
+                      IREE_STATUS_DEADLINE_EXCEEDED)                           \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE",                     \
+                      IREE_STATUS_INTERNAL)                                    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_CONTEXT_IS_DESTROYED", IREE_STATUS_INTERNAL) \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_ASSERT", IREE_STATUS_ABORTED)                \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED",             \
+                      IREE_STATUS_INTERNAL)                                    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED",                 \
+                      IREE_STATUS_INTERNAL)                                    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_HARDWARE_STACK_ERROR", IREE_STATUS_ABORTED)  \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_ILLEGAL_INSTRUCTION", IREE_STATUS_ABORTED)   \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_MISALIGNED_ADDRESS", IREE_STATUS_ABORTED)    \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_ADDRESS_SPACE", IREE_STATUS_ABORTED) \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_INVALID_PC", IREE_STATUS_ABORTED)            \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_LAUNCH_FAILED", IREE_STATUS_ABORTED)         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE",               \
+                      IREE_STATUS_OUT_OF_RANGE)                                \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_PERMITTED",                              \
+                      IREE_STATUS_PERMISSION_DENIED)                           \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_NOT_SUPPORTED",                              \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_SYSTEM_NOT_READY", IREE_STATUS_UNAVAILABLE)  \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_SYSTEM_DRIVER_MISMATCH",                     \
+                      IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_TIMEOUT", IREE_STATUS_DEADLINE_EXCEEDED)     \
+  IREE_CUDA_MAP_ERROR("CUDA_ERROR_UNKNOWN", IREE_STATUS_UNKNOWN)
+
+// Converts CUDA |error_name| to the corresponding IREE status code.
+static iree_status_code_t iree_hal_cuda2_error_name_to_status_code(
+    const char* error_name) {
+  // Compare whole names: a prefix comparison would let a shorter list entry
+  // (e.g. "CUDA_ERROR_NOT_MAPPED") capture longer names that merely start
+  // with it, making the result depend on the order of the list.
+#define IREE_CUDA_ERROR_TO_IREE_STATUS(cuda_error, iree_status) \
+  if (strcmp(error_name, cuda_error) == 0) {                    \
+    return iree_status;                                         \
+  }
+  IREE_CUDA_ERROR_LIST(IREE_CUDA_ERROR_TO_IREE_STATUS)
+#undef IREE_CUDA_ERROR_TO_IREE_STATUS
+  // Names that are not in the list have no specific classification.
+  return IREE_STATUS_UNKNOWN;
+}
+
+#undef IREE_CUDA_ERROR_LIST
+
+iree_status_t iree_hal_cuda2_result_to_status(
+    const iree_hal_cuda2_dynamic_symbols_t* syms, CUresult result,
+    const char* file, uint32_t line) {
+  if (IREE_LIKELY(result == CUDA_SUCCESS)) return iree_ok_status();
+
+  // Fall back to generic text if the driver cannot describe |result|.
+  const char* error_name = NULL;
+  if (syms->cuGetErrorName(result, &error_name) != CUDA_SUCCESS) {
+    error_name = "CUDA_ERROR_UNKNOWN";
+  }
+
+  const char* error_string = NULL;
+  if (syms->cuGetErrorString(result, &error_string) != CUDA_SUCCESS) {
+    error_string = "unknown error";
+  }
+
+  return iree_make_status_with_location(
+      file, line, iree_hal_cuda2_error_name_to_status_code(error_name),
+      "CUDA error '%s' (%d): %s", error_name, result, error_string);
+}
+
+//===----------------------------------------------------------------------===//
+// NCCL result utilities
+//===----------------------------------------------------------------------===//
+
+iree_status_t iree_hal_nccl2_result_to_status(
+    const iree_hal_cuda2_nccl_dynamic_symbols_t* syms, ncclResult_t result,
+    const char* file, uint32_t line) {
+  iree_status_code_t code;
+
+  // Classify the NCCL result; anything unrecognized is reported as internal.
+  switch (result) {
+    case ncclSuccess:
+      return iree_ok_status();
+    case ncclUnhandledCudaError:
+      code = IREE_STATUS_FAILED_PRECONDITION;
+      break;
+    case ncclSystemError:
+      code = IREE_STATUS_INTERNAL;
+      break;
+    case ncclInternalError:
+      code = IREE_STATUS_INTERNAL;
+      break;
+    case ncclInvalidArgument:
+      code = IREE_STATUS_INVALID_ARGUMENT;
+      break;
+    case ncclInvalidUsage:
+      code = IREE_STATUS_FAILED_PRECONDITION;
+      break;
+    case ncclRemoteError:
+      code = IREE_STATUS_UNAVAILABLE;
+      break;
+    case ncclInProgress:
+      code = IREE_STATUS_DEFERRED;
+      break;
+    default:
+      code = IREE_STATUS_INTERNAL;
+      break;
+  }
+  return iree_make_status_with_location(file, line, code, "NCCL error %d: %s",
+                                        result,
+                                        syms->ncclGetErrorString(result));
+}
diff --git a/experimental/cuda2/status_util.h b/experimental/cuda2/status_util.h
new file mode 100644
index 000000000000..316f0104ef55
--- /dev/null
+++ b/experimental/cuda2/status_util.h
@@ -0,0 +1,108 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_CUDA2_STATUS_UTIL_H_
+#define IREE_EXPERIMENTAL_CUDA2_STATUS_UTIL_H_
+
+#include <stdint.h>
+
+#include "experimental/cuda2/dynamic_symbols.h"
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// CUDA result macros
+//===----------------------------------------------------------------------===//
+
+// Converts a CUresult to an iree_status_t.
+// Any trailing arguments are accepted but not used by the expansion.
+//
+// Usage:
+//   iree_status_t status = IREE_CURESULT_TO_STATUS(cuda_symbols,
+//                                                  cuDoThing(...));
+#define IREE_CURESULT_TO_STATUS(syms, expr, ...) \
+  iree_hal_cuda2_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// IREE_RETURN_IF_ERROR but implicitly converts the CUresult return value to
+// an iree_status_t.
+//
+// Usage:
+//   IREE_CUDA_RETURN_IF_ERROR(cuda_symbols, cuDoThing(...), "message");
+#define IREE_CUDA_RETURN_IF_ERROR(syms, expr, ...)                             \
+  IREE_RETURN_IF_ERROR(iree_hal_cuda2_result_to_status((syms), ((syms)->expr), \
+                                                       __FILE__, __LINE__),    \
+                       __VA_ARGS__)
+
+// IREE_RETURN_IF_ERROR but ends the current zone and implicitly converts the
+// CUresult return value to an iree_status_t.
+//
+// Usage:
+//   IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(zone_id, cuda_symbols,
+//                                          cuDoThing(...), "message");
+#define IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(zone_id, syms, expr, ...) \
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(                                     \
+      zone_id,                                                           \
+      iree_hal_cuda2_result_to_status((syms), ((syms)->expr), __FILE__,  \
+                                      __LINE__),                         \
+      __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the CUresult return value to an
+// iree_status_t.
+//
+// Usage:
+//   IREE_CUDA_IGNORE_ERROR(cuda_symbols, cuDoThing(...));
+#define IREE_CUDA_IGNORE_ERROR(syms, expr)                                  \
+  IREE_IGNORE_ERROR(iree_hal_cuda2_result_to_status((syms), ((syms)->expr), \
+                                                    __FILE__, __LINE__))
+
+// Converts a CUresult to an iree_status_t object.
+// |file| and |line| are attached to the returned status as its source
+// location; the macros in this header pass __FILE__ and __LINE__.
+iree_status_t iree_hal_cuda2_result_to_status(
+    const iree_hal_cuda2_dynamic_symbols_t* syms, CUresult result,
+    const char* file, uint32_t line);
+
+//===----------------------------------------------------------------------===//
+// NCCL result macros
+//===----------------------------------------------------------------------===//
+
+// Converts a ncclResult_t to an iree_status_t.
+// Any trailing arguments are accepted but not used by the expansion.
+//
+// Usage:
+//   iree_status_t status = IREE_NCCL_RESULT_TO_STATUS(nccl_symbols,
+//                                                     ncclDoThing(...));
+#define IREE_NCCL_RESULT_TO_STATUS(syms, expr, ...) \
+  iree_hal_nccl2_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// IREE_RETURN_IF_ERROR but implicitly converts the ncclResult_t return value to
+// an iree_status_t.
+//
+// Usage:
+//   IREE_NCCL_RETURN_IF_ERROR(nccl_symbols, ncclDoThing(...), "message");
+#define IREE_NCCL_RETURN_IF_ERROR(syms, expr, ...)                             \
+  IREE_RETURN_IF_ERROR(iree_hal_nccl2_result_to_status((syms), ((syms)->expr), \
+                                                       __FILE__, __LINE__),    \
+                       __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the ncclResult_t return value to
+// an iree_status_t.
+//
+// Usage:
+//   IREE_NCCL_IGNORE_ERROR(nccl_symbols, ncclDoThing(...));
+#define IREE_NCCL_IGNORE_ERROR(syms, expr)                                  \
+  IREE_IGNORE_ERROR(iree_hal_nccl2_result_to_status((syms), ((syms)->expr), \
+                                                    __FILE__, __LINE__))
+
+// Converts a ncclResult_t to an iree_status_t object.
+// |file| and |line| are attached to the returned status as its source
+// location; the macros in this header pass __FILE__ and __LINE__.
+iree_status_t iree_hal_nccl2_result_to_status(
+    const iree_hal_cuda2_nccl_dynamic_symbols_t* syms, ncclResult_t result,
+    const char* file, uint32_t line);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_CUDA2_STATUS_UTIL_H_