Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
}

torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();

// Create a CUDA event and record it on the AVFrame's CUDA stream. That's the
// NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
// The NPP stream is then made to wait on this event before the NPP functions
// run, to ensure NVDEC has finished decoding the frame before the NPP
// color-conversion starts.
// Note that our code is generic and assumes that the NVDEC's stream can be
// arbitrary, but unfortunately we know it's hardcoded to be the default
// stream by FFmpeg:
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
TORCH_CHECK(
hwFramesCtx->device_ctx != nullptr,
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
auto cudaDeviceCtx =
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
at::cuda::CUDAEvent nvdecDoneEvent;
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
nvdecDoneEvent.record(nvdecStream);

// Don't start NPP work before NVDEC is done decoding the frame!
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
nvdecDoneEvent.block(nppStream);

// Create the NPP context if we haven't yet.
nppCtx_->hStream = nppStream.stream();
cudaError_t err =
cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
TORCH_CHECK(
Expand Down
Loading