From 8e49959df2bd5cc8991bc92c732d2378c1927e4f Mon Sep 17 00:00:00 2001
From: Banit Agrawal
Date: Wed, 4 Oct 2023 19:52:15 -0700
Subject: [PATCH] [CUDA Host Allocator] Add support for cudaHostRegister
 (#108488)

Summary: This diff adds another option for creating CUDA pinned memory, using
cudaHostRegister, to avoid the large lock wait times seen with cudaHostAlloc.

Differential Revision: D45843715
---
 aten/src/ATen/cuda/CachingHostAllocator.cpp | 98 ++++++++++++++++++++-
 c10/cuda/CUDAAllocatorConfig.cpp            | 50 ++++++++++-
 c10/cuda/CUDAAllocatorConfig.h              | 26 +++++-
 docs/source/notes/cuda.rst                  | 12 +++
 test/test_cuda.py                           |  7 ++
 torch/utils/hipify/cuda_to_hip_mappings.py  |  1 +
 6 files changed, 188 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp
index a6646779cfbec..36531b6412771 100644
--- a/aten/src/ATen/cuda/CachingHostAllocator.cpp
+++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp
@@ -4,6 +4,8 @@
 #include 
 #include 
 #include 
+#include 
+#include 
 #include 
 #include 
 
@@ -174,11 +176,18 @@ class CUDAHostAllocator {
   }
 
     // Round up the allocation to the nearest power of two to improve reuse.
+    size_t roundSize = c10::llvm::PowerOf2Ceil(size);
     void* ptr = nullptr;
-    C10_CUDA_CHECK(cudaHostAlloc(
-        &ptr, c10::llvm::PowerOf2Ceil(size), cudaHostAllocDefault));
+    if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
+            pinned_use_cuda_host_register()) {
+      allocWithCudaHostRegister(&ptr, roundSize);
+    } else {
+      // Use cudaHostAlloc for allocating pinned memory (global lock in driver)
+      C10_CUDA_CHECK(cudaHostAlloc(&ptr, roundSize, cudaHostAllocDefault));
+    }
+
     auto block = new Block();
-    block->size_ = c10::llvm::PowerOf2Ceil(size);
+    block->size_ = roundSize;
     block->ptr_ = ptr;
     block->allocated_ = true;
 
@@ -279,7 +288,14 @@ class CUDAHostAllocator {
     for (auto* block : blocks_to_remove) {
       blocks_.erase(block);
       ptr_to_block_.erase(block->ptr_);
-      AT_CUDA_CHECK(cudaFreeHost(block->ptr_));
+      if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
+              pinned_use_cuda_host_register()) {
+        void* ptr = block->ptr_;
+        AT_CUDA_CHECK(cudaHostUnregister(ptr));
+        free(ptr);
+      } else {
+        AT_CUDA_CHECK(cudaFreeHost(block->ptr_));
+      }
       delete block;
     }
   }
@@ -343,6 +359,80 @@ class CUDAHostAllocator {
     }
   }
 
+  TaskThreadPool* getThreadPool() {
+    static TaskThreadPool* pool = new TaskThreadPool(
+        c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
+            pinned_max_register_threads());
+    return pool;
+  }
+
+  void mapPagesForRegister(
+      const void* ptr,
+      size_t size,
+      size_t i,
+      size_t numThreads,
+      size_t pageSize) {
+    uintptr_t start = (uintptr_t)ptr + (size * i / numThreads);
+    uintptr_t end = (uintptr_t)start + (size / numThreads);
+    if (i == (numThreads - 1)) {
+      end = (uintptr_t)ptr + size;
+    }
+
+    // pre-fault/map the pages by setting the first byte of each page
+    uintptr_t alignedStart =
+        (((uintptr_t)start + pageSize - 1) & ~(pageSize - 1));
+    for (uintptr_t p = alignedStart; p < ((uintptr_t)end); p += pageSize) {
+      memset((void*)p, 0, 1);
+    }
+  }
+
+  void registerPages(const void* ptr, size_t size) {
+    AT_CUDA_CHECK(
+        cudaHostRegister((void*)ptr, (size_t)size, cudaHostRegisterDefault));
+
+    // If the host and device pointers don't match, give a warning and exit
+    void* devptr;
+    AT_CUDA_CHECK(cudaHostGetDevicePointer(&devptr, (void*)ptr, 0));
+    TORCH_CHECK(
+        (void*)devptr == (void*)ptr,
+        "Host and device pointers don't match with cudaHostRegister. "
" + "Please dont use this feature by setting " + "PYTORCH_PINNED_ALLOC_CONF=use_cuda_host_register:False (default)", + ""); + } + + inline void allocWithCudaHostRegister(void** ptr, size_t roundSize) { + // Here we do regular allocation, pre-fault/map the pages, and then do + // cudaHostRegister with GPU mapping flags to lock the pages, so we + // can minimize the cost for the cuda global lock. + *ptr = malloc(roundSize); + + // Parallelize the mapping/registering of pages to reduce wall time + size_t pageSize = (1 << 12); // 4kB pages + size_t numMapThreads = c10::cuda::CUDACachingAllocator:: + CUDAAllocatorConfig::pinned_num_register_threads(); + if ((numMapThreads > 1) && (roundSize >= (pageSize * numMapThreads))) { + auto* pool = getThreadPool(); + for (size_t i = 0; i < numMapThreads; i++) { + pool->run(std::bind( + &CUDAHostAllocator::mapPagesForRegister, + this, + *ptr, + roundSize, + i, // thread task-id + numMapThreads, + pageSize)); + } + pool->waitWorkComplete(); + } else { + // Map pages in the same thread + mapPagesForRegister(*ptr, roundSize, 0, 1, pageSize); + } + + // Register the mapped pages using cudaHostRegister + registerPages(*ptr, roundSize); + } + EventPool event_pool_; alignas(64) std::mutex blocks_mutex_; diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 2d4bee6db69c6..9b463fcb2a86b 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -13,8 +13,10 @@ constexpr size_t kRoundUpPowerOfTwoIntervals = 16; CUDAAllocatorConfig::CUDAAllocatorConfig() : m_max_split_size(std::numeric_limits::max()), m_garbage_collection_threshold(0), + m_pinned_num_register_threads(1), m_expandable_segments(false), - m_release_lock_on_cudamalloc(false) { + m_release_lock_on_cudamalloc(false), + m_pinned_use_cuda_host_register(false) { m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); } @@ -270,6 +272,12 @@ void CUDAAllocatorConfig::parseArgs(const char* env) { i < config.size() && (config[i] == "True" || config[i] == "False"), "Expected a single True/False argument for release_lock_on_cudamalloc"); m_release_lock_on_cudamalloc = (config[i] == "True"); + } else if (config[i].compare("pinned_use_cuda_host_register") == 0) { + i = parsePinnedUseCudaHostRegister(config, i); + used_native_specific_option = true; + } else if (config[i].compare("pinned_num_register_threads") == 0) { + i = parsePinnedNumRegisterThreads(config, i); + used_native_specific_option = true; } else { TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", config[i]); } @@ -286,6 +294,46 @@ void CUDAAllocatorConfig::parseArgs(const char* env) { } } +size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_cuda_host_register"); + m_pinned_use_cuda_host_register = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_cuda_host_register value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + size_t val2 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "Number of register threads has to be power of 2 ", + ""); + auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); + TORCH_CHECK( + val2 <= maxThreads, + 
"Number of register threads should be less than or equal to " + + std::to_string(maxThreads), + ""); + m_pinned_num_register_threads = val2; + } else { + TORCH_CHECK( + false, "Error, expecting pinned_num_register_threads value", ""); + } + return i; +} + // General caching allocator utilities void setAllocatorSettings(const std::string& env) { CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 58f056f5d9ebd..fd2b973ccad24 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -15,7 +15,7 @@ namespace cuda { namespace CUDACachingAllocator { // Environment config parser -class CUDAAllocatorConfig { +class C10_CUDA_API CUDAAllocatorConfig { public: static size_t max_split_size() { return instance().m_max_split_size; @@ -39,6 +39,22 @@ class CUDAAllocatorConfig { return instance().m_release_lock_on_cudamalloc; } + /** Pinned memory allocator settings */ + static bool pinned_use_cuda_host_register() { + return instance().m_pinned_use_cuda_host_register; + } + + static size_t pinned_num_register_threads() { + return instance().m_pinned_num_register_threads; + } + + static size_t pinned_max_register_threads() { + // Based on the benchmark results, we see better allocation performance + // with 8 threads. However on future systems, we may need more threads + // and limiting this to 128 threads. + return 128; + } + // This is used to round-up allocation size to nearest power of 2 divisions. // More description below in function roundup_power2_next_division // As ane example, if we want 4 divisions between 2's power, this can be done @@ -76,12 +92,20 @@ class CUDAAllocatorConfig { const std::vector& config, size_t i, bool& used_cudaMallocAsync); + size_t parsePinnedUseCudaHostRegister( + const std::vector& config, + size_t i); + size_t parsePinnedNumRegisterThreads( + const std::vector& config, + size_t i); std::atomic m_max_split_size; std::vector m_roundup_power2_divisions; std::atomic m_garbage_collection_threshold; + std::atomic m_pinned_num_register_threads; std::atomic m_expandable_segments; std::atomic m_release_lock_on_cudamalloc; + std::atomic m_pinned_use_cuda_host_register; }; // General caching allocator utilities diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index f6ac16ae5c1e1..0b49f58153651 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -470,6 +470,18 @@ Available options: appended to the end of the segment. This process does not create as many slivers of unusable memory, so it is more likely to succeed at finding this memory. + `pinned_use_cuda_host_register` option is a boolean flag that determines whether to + use the CUDA API's cudaHostRegister function for allocating pinned memory instead + of the default cudaHostAlloc. When set to True, the memory is allocated using regular + malloc and then pages are mapped to the memory before calling cudaHostRegister. + This pre-mapping of pages helps reduce the lock time during the execution + of cudaHostRegister. + + `pinned_num_register_threads` option is only valid when pinned_use_cuda_host_register + is set to True. By default, one thread is used to map the pages. This option allows + using more threads to parallelize the page mapping operations to reduce the overall + allocation time of pinned memory. A good value for this option is 8 based on + benchmarking results. .. 
 .. note::
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 70f94e585ee90..b1429d1bbed0f 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -76,6 +76,13 @@ def tearDown(self):
         del self.autocast_lists
         super().tearDown()
 
+    def test_pinned_memory_with_cudaregister(self):
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
+        t = torch.ones(20)
+        self.assertFalse(t.is_pinned())
+        pinned_t = torch.ones(1 << 16).pin_memory()
+        self.assertTrue(pinned_t.is_pinned())
+
     def test_cudart_register(self):
         t = torch.ones(20)
         self.assertFalse(t.is_pinned())
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index c812ce4425b3b..382cf08280e9f 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -8610,6 +8610,7 @@
         ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)),
         ("cuda::CUDAAllocatorConfig", ("hip::HIPAllocatorConfig", API_C10)),
         ("CUDAAllocatorConfig", ("HIPAllocatorConfig", API_C10)),
+        ("pinned_use_cuda_host_register", ("pinned_use_hip_host_register", API_C10)),
         ("c10::cuda::CUDAAllocator", ("c10::hip::HIPAllocator", API_C10)),
         ("cuda::CUDAAllocator", ("hip::HIPAllocator", API_C10)),
         ("CUDAAllocator", ("HIPAllocator", API_C10)),
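
Usage sketch (illustrative, not part of the diff): the snippet below mirrors the new test above and shows how the two allocator options introduced by this patch would be exercised from Python. The tensor size and thread count are arbitrary choices; the environment variable has to be set before the caching host allocator is first used, since the allocator config is parsed once at initialization.

    import os

    # Opt into the cudaHostRegister path and use 8 threads for page mapping.
    # Set this before the first pinned allocation in the process.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
        "pinned_use_cuda_host_register:True,pinned_num_register_threads:8"
    )

    import torch

    # With the flag enabled, this pinned allocation goes through malloc +
    # page pre-faulting + cudaHostRegister instead of cudaHostAlloc.
    pinned_t = torch.ones(1 << 16).pin_memory()
    assert pinned_t.is_pinned()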