Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cuda] Switch cuda2 on and cuda1 off by default #16107

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,10 @@ option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL driv
# not cross compiling. Note: a CUDA-compatible GPU with drivers is still
# required to actually run CUDA workloads.
set(IREE_HAL_DRIVER_CUDA_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA1_DEFAULT OFF)
if(NOT IREE_CUDA_AVAILABLE OR CMAKE_CROSSCOMPILING)
set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA1_DEFAULT OFF)
endif()

# Vulkan support is enabled by default if the platform might support Vulkan.
Expand All @@ -262,7 +262,7 @@ if(NOT APPLE OR NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
endif()

option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
option(IREE_HAL_DRIVER_CUDA2 "Enables the 'cuda2' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA2_DEFAULT})
option(IREE_HAL_DRIVER_CUDA1 "Enables the 'cuda1' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA1_DEFAULT})
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
Expand Down Expand Up @@ -318,8 +318,8 @@ message(STATUS "IREE HAL drivers:")
if(IREE_HAL_DRIVER_CUDA)
message(STATUS " - cuda")
endif()
if(IREE_HAL_DRIVER_CUDA2)
message(STATUS " - cuda2")
if(IREE_HAL_DRIVER_CUDA1)
message(STATUS " - cuda1")
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
message(STATUS " - local-sync")
Expand Down
8 changes: 4 additions & 4 deletions runtime/src/iree/hal/drivers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ endif()

set(_INIT_INTERNAL_DEPS)
if(IREE_HAL_DRIVER_CUDA)
add_subdirectory(cuda)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda::registration)
endif()
if(IREE_HAL_DRIVER_CUDA2)
add_subdirectory(cuda2)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda2::registration)
endif()
if(IREE_HAL_DRIVER_CUDA1)
add_subdirectory(cuda)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda::registration)
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
add_subdirectory(local_sync)
list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::local_sync::registration)
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

iree_hal_cts_test_suite(
DRIVER_NAME
cuda
cuda1
DRIVER_REGISTRATION_HDR
"runtime/src/iree/hal/drivers/cuda/registration/driver_module.h"
DRIVER_REGISTRATION_FN
Expand All @@ -28,7 +28,7 @@ iree_hal_cts_test_suite(
# Variant test suite using graph command buffers (--cuda_use_streams=0)
iree_hal_cts_test_suite(
DRIVER_NAME
cuda
cuda1
VARIANT_SUFFIX
graph
DRIVER_REGISTRATION_HDR
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/registration/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ iree_runtime_cc_library(
"driver_module.h",
],
defines = [
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1",
"IREE_HAVE_HAL_CUDA1_DRIVER_MODULE=1",
],
tags = ["driver=cuda"],
tags = ["driver=cuda1"],
deps = [
"//runtime/src/iree/base",
"//runtime/src/iree/base/internal:flags",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ iree_cc_library(
iree::hal
iree::hal::drivers::cuda
DEFINES
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1"
"IREE_HAVE_HAL_CUDA1_DRIVER_MODULE=1"
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ static iree_status_t iree_hal_cuda_driver_factory_enumerate(
const iree_hal_driver_info_t** out_driver_infos) {
// NOTE: we could query supported cuda versions or featuresets here.
static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_name = iree_string_view_literal("cuda"),
.full_name = iree_string_view_literal("CUDA (dynamic)"),
.driver_name = iree_string_view_literal("cuda1"),
.full_name = iree_string_view_literal("deprecated CUDA (dynamic)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
Expand All @@ -79,7 +79,7 @@ static iree_status_t iree_hal_cuda_driver_factory_try_create(
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
if (!iree_string_view_equal(driver_name, IREE_SV("cuda"))) {
if (!iree_string_view_equal(driver_name, IREE_SV("cuda1"))) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver '%.*s' is provided by this factory",
(int)driver_name.size, driver_name.data);
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda2/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
cuda
VARIANT_SUFFIX
graph
DRIVER_REGISTRATION_HDR
Expand All @@ -31,7 +31,7 @@ iree_hal_cts_test_suite(

iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
cuda
VARIANT_SUFFIX
stream
DRIVER_REGISTRATION_HDR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ iree_runtime_cc_library(
"driver_module.h",
],
defines = [
"IREE_HAVE_HAL_CUDA2_DRIVER_MODULE=1",
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1",
],
tags = ["driver=cuda"],
deps = [
"//runtime/src/iree/base",
"//runtime/src/iree/base/internal:flags",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ iree_cc_library(
iree::hal
iree::hal::drivers::cuda2
DEFINES
"IREE_HAVE_HAL_CUDA2_DRIVER_MODULE=1"
"IREE_HAVE_HAL_CUDA_DRIVER_MODULE=1"
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "iree/hal/drivers/cuda2/api.h"

IREE_FLAG(
bool, cuda2_use_streams, false,
bool, cuda2_use_streams, true,
"Use CUDA streams (instead of graphs) for executing command buffers.");

IREE_FLAG(bool, cuda2_allow_inline_execution, false,
Expand Down Expand Up @@ -70,8 +70,8 @@ static iree_status_t iree_hal_cuda2_driver_factory_enumerate(
IREE_TRACE_ZONE_BEGIN(z0);

static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_name = IREE_SVL("cuda2"),
.full_name = IREE_SVL("next-gen NVIDIA CUDA HAL driver (via dylib)"),
.driver_name = IREE_SVL("cuda"),
.full_name = IREE_SVL("NVIDIA CUDA HAL driver (via dylib)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
Expand All @@ -85,7 +85,7 @@ static iree_status_t iree_hal_cuda2_driver_factory_try_create(
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);

if (!iree_string_view_equal(driver_name, IREE_SV("cuda2"))) {
if (!iree_string_view_equal(driver_name, IREE_SV("cuda"))) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver '%.*s' is provided by this factory",
(int)driver_name.size, driver_name.data);
Expand Down
16 changes: 8 additions & 8 deletions runtime/src/iree/hal/drivers/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
#include "iree/hal/drivers/init.h"

#if defined(IREE_HAVE_HAL_CUDA_DRIVER_MODULE)
#include "iree/hal/drivers/cuda/registration/driver_module.h"
#include "iree/hal/drivers/cuda2/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_CUDA2_DRIVER_MODULE)
#include "iree/hal/drivers/cuda2/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA2_DRIVER_MODULE
#if defined(IREE_HAVE_HAL_CUDA1_DRIVER_MODULE)
#include "iree/hal/drivers/cuda/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA1_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_LOCAL_SYNC_DRIVER_MODULE)
#include "iree/hal/drivers/local_sync/registration/driver_module.h"
Expand Down Expand Up @@ -47,13 +47,13 @@ iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) {

#if defined(IREE_HAVE_HAL_CUDA_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda_driver_module_register(registry));
z0, iree_hal_cuda2_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_CUDA2_DRIVER_MODULE)
#if defined(IREE_HAVE_HAL_CUDA1_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA2_DRIVER_MODULE
z0, iree_hal_cuda_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA1_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_LOCAL_SYNC_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ iree_check_single_backend_test_suite(
# TODO(#13984): memset emulation required for graphs.
"--iree-stream-emulate-memset",
],
driver = "cuda",
driver = "cuda1",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=false"],
tags = [
Expand Down Expand Up @@ -499,7 +499,7 @@ iree_check_single_backend_test_suite(
include = ["*.mlir"],
exclude = [],
),
driver = "cuda",
driver = "cuda1",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=true"],
tags = [
Expand Down Expand Up @@ -589,7 +589,7 @@ iree_check_single_backend_test_suite(
"--iree-stream-emulate-memset",
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda2_use_streams=false"],
tags = [
Expand All @@ -609,7 +609,7 @@ iree_check_single_backend_test_suite(
compiler_flags = [
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda2_use_streams=true"],
tags = [
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda"
"cuda1"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
INPUT_TYPE
Expand Down Expand Up @@ -453,7 +453,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda"
"cuda1"
INPUT_TYPE
"stablehlo"
RUNNER_ARGS
Expand Down Expand Up @@ -534,7 +534,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
"--iree-hal-cuda-enable-legacy-sync=false"
Expand Down Expand Up @@ -618,7 +618,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/tosa_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ iree_check_single_backend_test_suite(
"--iree-stream-emulate-memset",
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda2_use_streams=false"],
tags = [
Expand All @@ -323,7 +323,7 @@ iree_check_single_backend_test_suite(
compiler_flags = [
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda2_use_streams=true"],
tags = [
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/tosa_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
"--iree-hal-cuda-enable-legacy-sync=false"
Expand Down Expand Up @@ -332,7 +332,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
Expand Down
Loading