Skip to content

Commit

Permalink
[cuda] Try to switch cuda2 on as the default
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst committed Jan 12, 2024
1 parent c27ed41 commit 84396e3
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 36 deletions.
12 changes: 6 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL driv
# not cross compiling. Note: a CUDA-compatible GPU with drivers is still
# required to actually run CUDA workloads.
set(IREE_HAL_DRIVER_CUDA_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA2_DEFAULT ${IREE_HAL_DRIVER_DEFAULTS})
if(NOT IREE_CUDA_AVAILABLE OR CMAKE_CROSSCOMPILING)
set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF)
set(IREE_HAL_DRIVER_CUDA2_DEFAULT OFF)
Expand All @@ -259,8 +259,8 @@ if(NOT APPLE OR NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(IREE_HAL_DRIVER_METAL_DEFAULT OFF)
endif()

option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
option(IREE_HAL_DRIVER_CUDA2 "Enables the 'cuda2' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA2_DEFAULT})
option(IREE_HAL_DRIVER_CUDA1 "Enables the 'cuda' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA_DEFAULT})
option(IREE_HAL_DRIVER_CUDA "Enables the 'cuda2' runtime HAL driver" ${IREE_HAL_DRIVER_CUDA2_DEFAULT})
option(IREE_HAL_DRIVER_LOCAL_SYNC "Enables the 'local-sync' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_LOCAL_TASK "Enables the 'local-task' runtime HAL driver" ${IREE_HAL_DRIVER_DEFAULTS})
option(IREE_HAL_DRIVER_VULKAN "Enables the 'vulkan' runtime HAL driver" ${IREE_HAL_DRIVER_VULKAN_DEFAULT})
Expand Down Expand Up @@ -313,12 +313,12 @@ if(IREE_BUILD_COMPILER)
endif()

message(STATUS "IREE HAL drivers:")
if(IREE_HAL_DRIVER_CUDA1)
message(STATUS " - cuda1")
endif()
if(IREE_HAL_DRIVER_CUDA)
message(STATUS " - cuda")
endif()
if(IREE_HAL_DRIVER_CUDA2)
message(STATUS " - cuda2")
endif()
if(IREE_HAL_DRIVER_LOCAL_SYNC)
message(STATUS " - local-sync")
endif()
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

iree_hal_cts_test_suite(
DRIVER_NAME
cuda
cuda1
DRIVER_REGISTRATION_HDR
"runtime/src/iree/hal/drivers/cuda/registration/driver_module.h"
DRIVER_REGISTRATION_FN
Expand All @@ -28,7 +28,7 @@ iree_hal_cts_test_suite(
# Variant test suite using graph command buffers (--cuda_use_streams=0)
iree_hal_cts_test_suite(
DRIVER_NAME
cuda
cuda1
VARIANT_SUFFIX
graph
DRIVER_REGISTRATION_HDR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ static iree_status_t iree_hal_cuda_driver_factory_enumerate(
const iree_hal_driver_info_t** out_driver_infos) {
// NOTE: we could query supported cuda versions or featuresets here.
static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_name = iree_string_view_literal("cuda"),
.full_name = iree_string_view_literal("CUDA (dynamic)"),
.driver_name = iree_string_view_literal("cuda1"),
.full_name = iree_string_view_literal("deprecated CUDA HAL driver (dynamic)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
Expand All @@ -79,7 +79,7 @@ static iree_status_t iree_hal_cuda_driver_factory_try_create(
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);
*out_driver = NULL;
if (!iree_string_view_equal(driver_name, IREE_SV("cuda"))) {
if (!iree_string_view_equal(driver_name, IREE_SV("cuda1"))) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver '%.*s' is provided by this factory",
(int)driver_name.size, driver_name.data);
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda2/cts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
cuda
VARIANT_SUFFIX
graph
DRIVER_REGISTRATION_HDR
Expand All @@ -31,7 +31,7 @@ iree_hal_cts_test_suite(

iree_hal_cts_test_suite(
DRIVER_NAME
cuda2
cuda
VARIANT_SUFFIX
stream
DRIVER_REGISTRATION_HDR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ static iree_status_t iree_hal_cuda2_driver_factory_enumerate(
IREE_TRACE_ZONE_BEGIN(z0);

static const iree_hal_driver_info_t driver_infos[1] = {{
.driver_name = IREE_SVL("cuda2"),
.full_name = IREE_SVL("next-gen NVIDIA CUDA HAL driver (via dylib)"),
.driver_name = IREE_SVL("cuda"),
.full_name = IREE_SVL("NVIDIA CUDA HAL driver (via dylib)"),
}};
*out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
*out_driver_infos = driver_infos;
Expand All @@ -85,7 +85,7 @@ static iree_status_t iree_hal_cuda2_driver_factory_try_create(
iree_hal_driver_t** out_driver) {
IREE_ASSERT_ARGUMENT(out_driver);

if (!iree_string_view_equal(driver_name, IREE_SV("cuda2"))) {
if (!iree_string_view_equal(driver_name, IREE_SV("cuda"))) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver '%.*s' is provided by this factory",
(int)driver_name.size, driver_name.data);
Expand Down
16 changes: 8 additions & 8 deletions runtime/src/iree/hal/drivers/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
#include "iree/hal/drivers/init.h"

#if defined(IREE_HAVE_HAL_CUDA_DRIVER_MODULE)
#include "iree/hal/drivers/cuda/registration/driver_module.h"
#include "iree/hal/drivers/cuda2/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_CUDA2_DRIVER_MODULE)
#include "iree/hal/drivers/cuda2/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA2_DRIVER_MODULE
#if defined(IREE_HAVE_HAL_CUDA1_DRIVER_MODULE)
#include "iree/hal/drivers/cuda/registration/driver_module.h"
#endif // IREE_HAVE_HAL_CUDA1_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_LOCAL_SYNC_DRIVER_MODULE)
#include "iree/hal/drivers/local_sync/registration/driver_module.h"
Expand Down Expand Up @@ -47,13 +47,13 @@ iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) {

#if defined(IREE_HAVE_HAL_CUDA_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda_driver_module_register(registry));
z0, iree_hal_cuda2_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_CUDA2_DRIVER_MODULE)
#if defined(IREE_HAVE_HAL_CUDA1_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_cuda2_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA2_DRIVER_MODULE
z0, iree_hal_cuda_driver_module_register(registry));
#endif // IREE_HAVE_HAL_CUDA_DRIVER_MODULE

#if defined(IREE_HAVE_HAL_LOCAL_SYNC_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ iree_check_single_backend_test_suite(
# TODO(#13984): memset emulation required for graphs.
"--iree-stream-emulate-memset",
],
driver = "cuda",
driver = "cuda1",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=false"],
tags = [
Expand Down Expand Up @@ -493,7 +493,7 @@ iree_check_single_backend_test_suite(
include = ["*.mlir"],
exclude = [],
),
driver = "cuda",
driver = "cuda1",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=true"],
tags = [
Expand Down Expand Up @@ -582,7 +582,7 @@ iree_check_single_backend_test_suite(
"--iree-stream-emulate-memset",
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda2_use_streams=false"],
tags = [
Expand All @@ -602,7 +602,7 @@ iree_check_single_backend_test_suite(
compiler_flags = [
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda2_use_streams=true"],
tags = [
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda"
"cuda1"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
INPUT_TYPE
Expand Down Expand Up @@ -447,7 +447,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda"
"cuda1"
INPUT_TYPE
"stablehlo"
RUNNER_ARGS
Expand Down Expand Up @@ -527,7 +527,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
"--iree-hal-cuda-enable-legacy-sync=false"
Expand Down Expand Up @@ -610,7 +610,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/tosa_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ iree_check_single_backend_test_suite(
"--iree-stream-emulate-memset",
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda2_use_streams=false"],
tags = [
Expand All @@ -323,7 +323,7 @@ iree_check_single_backend_test_suite(
compiler_flags = [
"--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda2",
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda2_use_streams=true"],
tags = [
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/tosa_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
"--iree-hal-cuda-enable-legacy-sync=false"
Expand Down Expand Up @@ -332,7 +332,7 @@ iree_check_single_backend_test_suite(
TARGET_BACKEND
"cuda"
DRIVER
"cuda2"
"cuda"
COMPILER_FLAGS
"--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
Expand Down

0 comments on commit 84396e3

Please sign in to comment.