From 756060cc534f1d6e234a6ba32166e941e3aaa490 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Oct 2025 14:49:23 +0100 Subject: [PATCH 1/6] Address some TODOs --- src/torchcodec/_core/BetaCudaDeviceInterface.cpp | 1 - src/torchcodec/_core/NVDECCache.h | 7 ++----- src/torchcodec/decoders/_video_decoder.py | 3 --- test/utils.py | 6 +++--- 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 78fa8d635..779aaa5ce 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -221,7 +221,6 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { } if (videoParser_) { - // TODONVDEC P2: consider caching this? Does DALI do that? cuvidDestroyVideoParser(videoParser_); videoParser_ = nullptr; } diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h index 17fc99902..25dd4adf0 100644 --- a/src/torchcodec/_core/NVDECCache.h +++ b/src/torchcodec/_core/NVDECCache.h @@ -68,11 +68,8 @@ class NVDECCache { CacheKey(const CacheKey&) = default; CacheKey& operator=(const CacheKey&) = default; - // TODONVDEC P2: we only implement operator< which is enough for std::map, - // but: - // - we should consider using std::unordered_map - // - we should consider a more sophisticated and potentially less strict - // cache key comparison logic + // TODONVDEC P2: consider a more sophisticated and potentially + // less strict cache key comparison logic bool operator<(const CacheKey& other) const { return std::tie( codecType, diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 729fd4727..31473ddf0 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -145,9 +145,6 @@ def __init__( # If device looks like "cuda:0:beta", make it "cuda:0" and set # device_variant to "beta" - # TODONVDEC P2 Consider alternative ways of exposing custom device - # variants, and if we want this new decoder backend to be a "device - # variant" at all. device_variant = "default" if device is not None: device_split = device.split(":") diff --git a/test/utils.py b/test/utils.py index 7c91f307c..a4eb0a35c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -750,7 +750,7 @@ def sample_format(self) -> str: def supports_approximate_mode(asset: TestVideo) -> bool: - # TODONVDEC P2: open an issue about his. That's actually not related to - # NVDEC at all, those don't support approximate mode because they don't set - # a duration. CPU decoder fails too! + # Those are missing the `duration` field so they fail in approximate mode (on all devices). + # TODO: we should address this, see + # https://github.com/meta-pytorch/torchcodec/issues/945 return asset not in (AV1_VIDEO, TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8) From aa382fc85d6b5ac3ff600225fcde939a37fde1a0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Oct 2025 17:52:20 +0100 Subject: [PATCH 2/6] remove more --- test/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/utils.py b/test/utils.py index a4eb0a35c..7fb00ab59 100644 --- a/test/utils.py +++ b/test/utils.py @@ -41,9 +41,6 @@ def unsplit_device_str(device_str: str) -> str: # It is used: # - before calling `.to(device)` where device can't be "cuda:0:beta" # - before calling add_video_stream(device=device, device_variant=device_variant) - # - # TODONVDEC P2: Find a less clunky way to test the BETA CUDA interface. It - # will ultimately depend on how we want to publicly expose it. if device_str == "cuda:0:beta": return "cuda", "beta" else: From 4723ecd370c684d6d252bd21b919a7ad5370625d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 14 Oct 2025 18:04:56 +0100 Subject: [PATCH 3/6] more --- src/torchcodec/_core/NVDECCache.cpp | 2 +- src/torchcodec/_core/NVDECCache.h | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp index 87ab5b0dc..6bc708e9c 100644 --- a/src/torchcodec/_core/NVDECCache.cpp +++ b/src/torchcodec/_core/NVDECCache.cpp @@ -27,7 +27,7 @@ NVDECCache& NVDECCache::getCache(int deviceIndex) { deviceIndex); static NVDECCache cacheInstances[MAX_CUDA_GPUS]; if (deviceIndex == -1) { - // TODO NVDEC P3: Unify with existing getNonNegativeDeviceIndex() + // TODONVDEC P3: Unify with existing getNonNegativeDeviceIndex() TORCH_CHECK( cudaGetDevice(&deviceIndex) == cudaSuccess, "Failed to get current CUDA device."); diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h index 25dd4adf0..5ba58fb35 100644 --- a/src/torchcodec/_core/NVDECCache.h +++ b/src/torchcodec/_core/NVDECCache.h @@ -68,8 +68,6 @@ class NVDECCache { CacheKey(const CacheKey&) = default; CacheKey& operator=(const CacheKey&) = default; - // TODONVDEC P2: consider a more sophisticated and potentially - // less strict cache key comparison logic bool operator<(const CacheKey& other) const { return std::tie( codecType, From 633c4b3f592d3ed090846080c626b0d59551655e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Oct 2025 13:04:10 +0100 Subject: [PATCH 4/6] WIP --- .../_core/BetaCudaDeviceInterface.cpp | 9 ++++--- src/torchcodec/_core/CUDACommon.cpp | 26 ++++++++++++++----- src/torchcodec/_core/CUDACommon.h | 7 ++++- src/torchcodec/_core/Cache.h | 24 +++-------------- src/torchcodec/_core/CudaDeviceInterface.cpp | 9 +++---- src/torchcodec/_core/NVDECCache.cpp | 16 +++--------- src/torchcodec/_core/NVDECCache.h | 3 ++- 7 files changed, 43 insertions(+), 51 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 779aaa5ce..adf6add28 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -216,8 +216,8 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { // unclear. flush(); unmapPreviousFrame(); - NVDECCache::getCache(device_.index()) - .returnDecoder(&videoFormat_, std::move(decoder_)); + NVDECCache::getCache(device_).returnDecoder( + &videoFormat_, std::move(decoder_)); } if (videoParser_) { @@ -361,11 +361,12 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) { } if (!decoder_) { - decoder_ = NVDECCache::getCache(device_.index()).getDecoder(videoFormat); + decoder_ = NVDECCache::getCache(device_).getDecoder(videoFormat); if (!decoder_) { // TODONVDEC P2: consider re-configuring an existing decoder instead of - // re-creating one. See docs, see DALI. + // re-creating one. See docs, see DALI. Re-configuration doesn't seem to + // be enabled in DALI by default. decoder_ = createDecoder(videoFormat); } diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index 4f3664031..2698d1291 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -10,9 +10,6 @@ namespace facebook::torchcodec { namespace { -// Pytorch can only handle up to 128 GPUs. -// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 -const int MAX_CUDA_GPUS = 128; // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. // Set to a positive number to have a cache of that size. const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; @@ -249,7 +246,7 @@ torch::Tensor convertNV12FrameToRGB( } UniqueNppContext getNppStreamContext(const torch::Device& device) { - torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device); if (nppCtx) { @@ -266,13 +263,13 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) { nppCtx = std::make_unique(); cudaDeviceProp prop{}; - cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex); + cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex); TORCH_CHECK( err == cudaSuccess, "cudaGetDeviceProperties failed: ", cudaGetErrorString(err)); - nppCtx->nCudaDeviceId = nonNegativeDeviceIndex; + nppCtx->nCudaDeviceId = deviceIndex; nppCtx->nMultiProcessorCount = prop.multiProcessorCount; nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor; nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock; @@ -312,4 +309,21 @@ void validatePreAllocatedTensorShape( } } +int getDeviceIndex(const torch::Device& device) { + // PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA + // libraries use int. So we use int, too. + int deviceIndex = static_cast(device.index()); + TORCH_CHECK( + deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS, + "Invalid device index = ", + deviceIndex); + + if (deviceIndex == -1) { + TORCH_CHECK( + cudaGetDevice(&deviceIndex) == cudaSuccess, + "Failed to get current CUDA device."); + } + return deviceIndex; +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index b935cd4bf..588f60e49 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -11,7 +11,6 @@ #include #include -#include "src/torchcodec/_core/Cache.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/Frame.h" @@ -22,6 +21,10 @@ extern "C" { namespace facebook::torchcodec { +// Pytorch can only handle up to 128 GPUs. +// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 +constexpr int MAX_CUDA_GPUS = 128; + void initializeCudaContextWithPytorch(const torch::Device& device); // Unique pointer type for NPP stream context @@ -43,4 +46,6 @@ void validatePreAllocatedTensorShape( const std::optional& preAllocatedOutputTensor, const UniqueAVFrame& avFrame); +int getDeviceIndex(const torch::Device& device); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Cache.h b/src/torchcodec/_core/Cache.h index 7b088a145..475455a20 100644 --- a/src/torchcodec/_core/Cache.h +++ b/src/torchcodec/_core/Cache.h @@ -9,6 +9,7 @@ #include #include #include +#include "src/torchcodec/_core/CUDACommon.h" namespace facebook::torchcodec { @@ -95,30 +96,11 @@ class PerGpuCache { std::vector>> cache_; }; -// Note: this function is inline for convenience, not performance. Because the -// rest of this file is template functions, they must all be defined in this -// header. This function is not a template function, and should, in principle, -// be defined in a .cpp file to preserve the One Definition Rule. That's -// annoying for such a small amount of code, so we just inline it. If this file -// grows, and there are more such functions, we should break them out into a -// .cpp file. -inline torch::DeviceIndex getNonNegativeDeviceIndex( - const torch::Device& device) { - torch::DeviceIndex deviceIndex = device.index(); - // For single GPU machines libtorch returns -1 for the device index. So for - // that case we set the device index to 0. That's used in per-gpu cache - // implementation and during initialization of CUDA and FFmpeg contexts - // which require non negative indices. - deviceIndex = std::max(deviceIndex, 0); - TORCH_CHECK(deviceIndex >= 0, "Device index out of range"); - return deviceIndex; -} - template bool PerGpuCache::addIfCacheHasCapacity( const torch::Device& device, element_type&& obj) { - torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); TORCH_CHECK( static_cast(deviceIndex) < cache_.size(), "Device index out of range"); @@ -128,7 +110,7 @@ bool PerGpuCache::addIfCacheHasCapacity( template typename PerGpuCache::element_type PerGpuCache::get( const torch::Device& device) { - torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); TORCH_CHECK( static_cast(deviceIndex) < cache_.size(), "Device index out of range"); diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index aea2b2d9a..aee1ecd07 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -32,9 +32,6 @@ static bool g_cuda = registerDeviceInterface( // from // the cache. If the cache is empty we create a new cuda context. -// Pytorch can only handle up to 128 GPUs. -// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 -const int MAX_CUDA_GPUS = 128; // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. // Set to a positive number to have a cache of that size. const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; @@ -54,7 +51,7 @@ int getFlagsAVHardwareDeviceContextCreate() { UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device); if (hardwareDeviceCtx) { @@ -68,9 +65,9 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { // So we ensure the deviceIndex is not negative. // We set the device because we may be called from a different thread than // the one that initialized the cuda context. - cudaSetDevice(nonNegativeDeviceIndex); + cudaSetDevice(deviceIndex); AVBufferRef* hardwareDeviceCtxRaw = nullptr; - std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex); + std::string deviceOrdinal = std::to_string(deviceIndex); int err = av_hwdevice_ctx_create( &hardwareDeviceCtxRaw, diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp index 6bc708e9c..477302842 100644 --- a/src/torchcodec/_core/NVDECCache.cpp +++ b/src/torchcodec/_core/NVDECCache.cpp @@ -7,6 +7,7 @@ #include #include +#include "src/torchcodec/_core/CUDACommon.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/NVDECCache.h" @@ -19,19 +20,10 @@ extern "C" { namespace facebook::torchcodec { -NVDECCache& NVDECCache::getCache(int deviceIndex) { - const int MAX_CUDA_GPUS = 128; - TORCH_CHECK( - deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS, - "Invalid device index = ", - deviceIndex); +NVDECCache& NVDECCache::getCache(const torch::Device& device) { static NVDECCache cacheInstances[MAX_CUDA_GPUS]; - if (deviceIndex == -1) { - // TODONVDEC P3: Unify with existing getNonNegativeDeviceIndex() - TORCH_CHECK( - cudaGetDevice(&deviceIndex) == cudaSuccess, - "Failed to get current CUDA device."); - } + + int deviceIndex = getDeviceIndex(device); return cacheInstances[deviceIndex]; } diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h index 5ba58fb35..b248ebc68 100644 --- a/src/torchcodec/_core/NVDECCache.h +++ b/src/torchcodec/_core/NVDECCache.h @@ -11,6 +11,7 @@ #include #include +#include #include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" #include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" @@ -36,7 +37,7 @@ using UniqueCUvideodecoder = // per GPU device, and it is accessed through the static getCache() method. class NVDECCache { public: - static NVDECCache& getCache(int deviceIndex); + static NVDECCache& getCache(const torch::Device& device); // Get decoder from cache - returns nullptr if none available UniqueCUvideodecoder getDecoder(CUVIDEOFORMAT* videoFormat); From 85d58fbb4c00f0244ea737a563a9963d1cd6e782 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Oct 2025 13:07:25 +0100 Subject: [PATCH 5/6] Create common g_cached_npp_ctxs --- src/torchcodec/_core/CUDACommon.cpp | 1 + src/torchcodec/_core/NVDECCache.cpp | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index 2698d1291..4532e3c76 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/CUDACommon.h" +#include "src/torchcodec/_core/Cache.h" // for PerGpuCache namespace facebook::torchcodec { diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp index 477302842..302433cd4 100644 --- a/src/torchcodec/_core/NVDECCache.cpp +++ b/src/torchcodec/_core/NVDECCache.cpp @@ -22,9 +22,7 @@ namespace facebook::torchcodec { NVDECCache& NVDECCache::getCache(const torch::Device& device) { static NVDECCache cacheInstances[MAX_CUDA_GPUS]; - - int deviceIndex = getDeviceIndex(device); - return cacheInstances[deviceIndex]; + return cacheInstances[getDeviceIndex(device)]; } UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) { From a24501cf802916adf2428fee8ae0de0b4128fb23 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 15 Oct 2025 13:14:00 +0100 Subject: [PATCH 6/6] Avoid circular dep --- src/torchcodec/_core/Cache.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/Cache.h b/src/torchcodec/_core/Cache.h index 475455a20..b2c93e8ea 100644 --- a/src/torchcodec/_core/Cache.h +++ b/src/torchcodec/_core/Cache.h @@ -9,7 +9,6 @@ #include #include #include -#include "src/torchcodec/_core/CUDACommon.h" namespace facebook::torchcodec { @@ -96,6 +95,11 @@ class PerGpuCache { std::vector>> cache_; }; +// Forward declaration of getDeviceIndex which exists in CUDACommon.h +// This avoids circular dependency between Cache.h and CUDACommon.cpp which also +// needs to include Cache.h +int getDeviceIndex(const torch::Device& device); + template bool PerGpuCache::addIfCacheHasCapacity( const torch::Device& device,