From d211fffebe8d2343d51bf58bdccd6ec4f9ea1608 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 8 Oct 2025 15:59:16 +0100 Subject: [PATCH] BETA CUDA interface: Fix CUDA context initialization --- src/torchcodec/_core/BetaCudaDeviceInterface.cpp | 5 +---- src/torchcodec/_core/CUDACommon.cpp | 8 ++++++++ src/torchcodec/_core/CUDACommon.h | 2 ++ src/torchcodec/_core/CudaDeviceInterface.cpp | 9 ++++----- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 8ae6a3959..78fa8d635 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -203,10 +203,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); - // Initialize CUDA context with a dummy tensor - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); - + initializeCudaContextWithPytorch(device_); nppCtx_ = getNppStreamContext(device_); } diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index ee2f63c28..4f3664031 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -23,6 +23,14 @@ PerGpuCache g_cached_npp_ctxs( } // namespace +void initializeCudaContextWithPytorch(const torch::Device& device) { + // It is important for pytorch itself to create the cuda context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the cuda context. 
+ torch::Tensor dummyTensorForCudaInitialization = torch::zeros( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); +} + /* clang-format off */ // Note: [YUV -> RGB Color Conversion, color space and color range] // diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index b4c081885..b935cd4bf 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -22,6 +22,8 @@ extern "C" { namespace facebook::torchcodec { +void initializeCudaContextWithPytorch(const torch::Device& device); + // Unique pointer type for NPP stream context using UniqueNppContext = std::unique_ptr<NppStreamContext>; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index e8df0a608..2d09864a1 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -129,11 +129,10 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); - // It is important for pytorch itself to create the cuda context. If ffmpeg - // creates the context it may not be compatible with pytorch. - // This is a dummy tensor to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + initializeCudaContextWithPytorch(device_); + + // TODO rename this, this is a hardware device context, not a CUDA context! + // See https://github.com/meta-pytorch/torchcodec/issues/924 ctx_ = getCudaContext(device_); nppCtx_ = getNppStreamContext(device_); }