-
Notifications
You must be signed in to change notification settings - Fork 556
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cuda] Implement basics for a CUDA HAL driver rewrite (#13942)
This commit starts a CUDA HAL driver rewrite under `experimental/`. We create a new `cuda2/` directory to host the new code to avoid interrupting the current CodeGen development. This commit just brings in the basics for boot up a new HAL driver, including dynamic symbols management, error status management, and IREE HAL driver implementation. Most of the code is directly copied from existing HAL driver, with noticeable changes: * Split CUDA and NCCL dynamic symbols into separate structures for better organziation and allowing optionality. * Fleshed out CUDA error to IREE status conversions. * Better organized code blocks and improved error messages and various comments. Building this commmit with `-DIREE_EXTERNAL_HAL_DRIVERS=cuda2`, we can have `tools/iree-run-module --dump_devices` showing `cuda2` devices, in parallel to the existing CUDA one. Progress towards #13245
- Loading branch information
1 parent
2544efe
commit 5c38bcc
Showing
14 changed files
with
1,556 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# Set the root for package namespacing to the current directory. | ||
set(IREE_PACKAGE_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}") | ||
set(IREE_PACKAGE_ROOT_PREFIX "iree/experimental/cuda2") | ||
|
||
iree_add_all_subdirs() | ||
|
||
iree_cc_library( | ||
NAME | ||
cuda2 | ||
HDRS | ||
"api.h" | ||
SRCS | ||
"api.h" | ||
"cuda_driver.c" | ||
DEPS | ||
::dynamic_symbols | ||
iree::base | ||
iree::base::core_headers | ||
iree::base::tracing | ||
iree::hal | ||
iree::schemas::cuda_executable_def_c_fbs | ||
PUBLIC | ||
) | ||
|
||
iree_cc_library( | ||
NAME | ||
dynamic_symbols | ||
HDRS | ||
"dynamic_symbols.h" | ||
"status_util.h" | ||
TEXTUAL_HDRS | ||
"dynamic_symbol_tables.h" | ||
SRCS | ||
"cuda_headers.h" | ||
"dynamic_symbols.c" | ||
"status_util.c" | ||
DEPS | ||
iree::base | ||
iree::base::core_headers | ||
iree::base::internal::dynamic_library | ||
iree::base::tracing | ||
iree_cuda::headers | ||
PUBLIC | ||
) | ||
|
||
iree_cc_test( | ||
NAME | ||
dynamic_symbols_test | ||
SRCS | ||
"dynamic_symbols_test.cc" | ||
DEPS | ||
::dynamic_symbols | ||
iree::base | ||
iree::testing::gtest | ||
iree::testing::gtest_main | ||
LABELS | ||
"driver=cuda2" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
// See iree/base/api.h for documentation on the API conventions used. | ||
|
||
#ifndef IREE_EXPERIMENTAL_CUDA2_API_H_ | ||
#define IREE_EXPERIMENTAL_CUDA2_API_H_ | ||
|
||
#include "iree/base/api.h" | ||
#include "iree/hal/api.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif // __cplusplus | ||
|
||
//===----------------------------------------------------------------------===// | ||
// iree_hal_cuda2_driver_t | ||
//===----------------------------------------------------------------------===// | ||
|
||
// CUDA HAL driver creation options. | ||
typedef struct iree_hal_cuda2_driver_options_t { | ||
// The index of the default CUDA device to use within the list of available | ||
// devices. | ||
int default_device_index; | ||
} iree_hal_cuda2_driver_options_t; | ||
|
||
// Initializes the given |out_options| with default driver creation options. | ||
IREE_API_EXPORT void iree_hal_cuda2_driver_options_initialize( | ||
iree_hal_cuda2_driver_options_t* out_options); | ||
|
||
// Creates a CUDA HAL driver with the given |options|, from which CUDA devices | ||
// can be enumerated and created with specific parameters. | ||
// | ||
// |out_driver| must be released by the caller (see iree_hal_driver_release). | ||
IREE_API_EXPORT iree_status_t iree_hal_cuda2_driver_create( | ||
iree_string_view_t identifier, | ||
const iree_hal_cuda2_driver_options_t* options, | ||
iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); | ||
|
||
#ifdef __cplusplus | ||
} // extern "C" | ||
#endif // __cplusplus | ||
|
||
#endif // IREE_EXPERIMENTAL_CUDA2_API_H_ |
Oops, something went wrong.