diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 8ae6a3959..78fa8d635 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -203,10 +203,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); - // Initialize CUDA context with a dummy tensor - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); - + initializeCudaContextWithPytorch(device_); nppCtx_ = getNppStreamContext(device_); } diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index ee2f63c28..4f3664031 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -23,6 +23,14 @@ PerGpuCache g_cached_npp_ctxs( } // namespace +void initializeCudaContextWithPytorch(const torch::Device& device) { + // It is important for pytorch itself to create the cuda context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the cuda context. + torch::Tensor dummyTensorForCudaInitialization = torch::zeros( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); +} + /* clang-format off */ // Note: [YUV -> RGB Color Conversion, color space and color range] // diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index b4c081885..b935cd4bf 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -22,6 +22,8 @@ extern "C" { namespace facebook::torchcodec { +void initializeCudaContextWithPytorch(const torch::Device& device); + // Unique pointer type for NPP stream context using UniqueNppContext = std::unique_ptr; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index e8df0a608..2d09864a1 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -129,11 +129,10 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); - // It is important for pytorch itself to create the cuda context. If ffmpeg - // creates the context it may not be compatible with pytorch. - // This is a dummy tensor to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + initializeCudaContextWithPytorch(device_); + + // TODO rename this, this is a hardware device context, not a CUDA context! + // See https://github.com/meta-pytorch/torchcodec/issues/924 ctx_ = getCudaContext(device_); nppCtx_ = getNppStreamContext(device_); }