diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index f0b4cbfcc..095c82e27 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -35,16 +35,20 @@ static bool g_cuda_beta = registerDeviceInterface( static int CUDAAPI pfnSequenceCallback(void* pUserData, CUVIDEOFORMAT* videoFormat) { - BetaCudaDeviceInterface* decoder = - static_cast(pUserData); + auto decoder = static_cast(pUserData); return decoder->streamPropertyChange(videoFormat); } static int CUDAAPI -pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* pPicParams) { - BetaCudaDeviceInterface* decoder = - static_cast(pUserData); - return decoder->frameReadyForDecoding(pPicParams); +pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* picParams) { + auto decoder = static_cast(pUserData); + return decoder->frameReadyForDecoding(picParams); +} + +static int CUDAAPI +pfnDisplayPictureCallback(void* pUserData, CUVIDPARSERDISPINFO* dispInfo) { + auto decoder = static_cast(pUserData); + return decoder->frameReadyInDisplayOrder(dispInfo); } static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { @@ -142,7 +146,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { // TODONVDEC P0: we probably need to free the frames that have been decoded by - // NVDEC but not yet "mapped" - i.e. those that are still in frameBuffer_? + // NVDEC but not yet "mapped" - i.e. those that are still in readyFrames_? 
if (decoder_) { NVDECCache::getCache(device_.index()) @@ -218,7 +222,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) { parserParams.pUserData = this; parserParams.pfnSequenceCallback = pfnSequenceCallback; parserParams.pfnDecodePicture = pfnDecodePictureCallback; - parserParams.pfnDisplayPicture = nullptr; + parserParams.pfnDisplayPicture = pfnDisplayPictureCallback; CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams); TORCH_CHECK( @@ -274,10 +278,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) { cuvidPacket.flags = CUVID_PKT_TIMESTAMP; cuvidPacket.timestamp = packet->pts; - // Like DALI: store packet PTS in queue to later assign to frames as they - // come out - packetsPtsQueue.push(packet->pts); - } else { // End of stream packet cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM; @@ -329,70 +329,38 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) { // ready to be decoded, i.e. the parser received all the necessary packets for a // given frame. It means we can send that frame to be decoded by the hardware // NVDEC decoder by calling cuvidDecodePicture which is non-blocking. -int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) { +int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* picParams) { if (isFlushing_) { return 0; } - TORCH_CHECK(pPicParams != nullptr, "Invalid picture parameters"); + TORCH_CHECK(picParams != nullptr, "Invalid picture parameters"); TORCH_CHECK(decoder_, "Decoder not initialized before picture decode"); // Send frame to be decoded by NVDEC - non-blocking call. - CUresult result = cuvidDecodePicture(*decoder_.get(), pPicParams); - if (result != CUDA_SUCCESS) { - return 0; // Yes, you're reading that right, 0 mean error. - } + CUresult result = cuvidDecodePicture(*decoder_.get(), picParams); - // The frame was sent to be decoded on the NVDEC hardware. 
Now we store some - // relevant info into our frame buffer so that we can retrieve the decoded - // frame later when receiveFrame() is called. - // Importantly we need to 'guess' the PTS of that frame. The heuristic we use - // (like in DALI) is that the frames are ready to be decoded in the same order - // as the packets were sent to the parser. So we assign the PTS of the frame - // by popping the PTS of the oldest packet in our packetsPtsQueue (note: - // oldest doesn't necessarily mean lowest PTS!). + // Yes, you're reading that right, 0 means error, 1 means success + return (result == CUDA_SUCCESS); +} - TORCH_CHECK( - // TODONVDEC P0 the queue may be empty, handle that. - !packetsPtsQueue.empty(), - "PTS queue is empty when decoding a frame"); - int64_t guessedPts = packetsPtsQueue.front(); - packetsPtsQueue.pop(); - - // Field values taken from DALI - CUVIDPARSERDISPINFO dispInfo = {}; - dispInfo.picture_index = pPicParams->CurrPicIdx; - dispInfo.progressive_frame = !pPicParams->field_pic_flag; - dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1; - dispInfo.repeat_first_field = 0; - dispInfo.timestamp = guessedPts; - - FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot(); - slot->dispInfo = dispInfo; - slot->guessedPts = guessedPts; - slot->occupied = true; - - return 1; +int BetaCudaDeviceInterface::frameReadyInDisplayOrder( + CUVIDPARSERDISPINFO* dispInfo) { + readyFrames_.push(*dispInfo); + return 1; // success } -// Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded -// frame with the exact desired PTS in our frame buffer. This logic is only -// valid in exact seek_mode, for now. -int BetaCudaDeviceInterface::receiveFrame( - UniqueAVFrame& avFrame, - int64_t desiredPts) { - FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts(desiredPts); - if (slot == nullptr) { +// Moral equivalent of avcodec_receive_frame(). 
+int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) { + if (readyFrames_.empty()) { // No frame found, instruct caller to try again later after sending more // packets. return AVERROR(EAGAIN); } - - slot->occupied = false; - slot->guessedPts = -1; + CUVIDPARSERDISPINFO dispInfo = readyFrames_.front(); + readyFrames_.pop(); CUVIDPROCPARAMS procParams = {}; - CUVIDPARSERDISPINFO dispInfo = slot->dispInfo; procParams.progressive_frame = dispInfo.progressive_frame; procParams.top_field_first = dispInfo.top_field_first; procParams.unpaired_field = dispInfo.repeat_first_field < 0; @@ -452,7 +420,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame( avFrame->width = width; avFrame->height = height; avFrame->format = AV_PIX_FMT_CUDA; - avFrame->pts = dispInfo.timestamp; // == guessedPts + avFrame->pts = dispInfo.timestamp; // TODONVDEC P0: Zero division error!!! // TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the @@ -518,13 +486,8 @@ void BetaCudaDeviceInterface::flush() { isFlushing_ = false; - for (auto& slot : frameBuffer_) { - slot.occupied = false; - slot.guessedPts = -1; - } - - std::queue empty; - packetsPtsQueue.swap(empty); + std::queue emptyQueue; + std::swap(readyFrames_, emptyQueue); eofSent_ = false; } @@ -544,26 +507,4 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( avFrame, frameOutput, preAllocatedOutputTensor); } -BetaCudaDeviceInterface::FrameBuffer::Slot* -BetaCudaDeviceInterface::FrameBuffer::findEmptySlot() { - for (auto& slot : frameBuffer_) { - if (!slot.occupied) { - return &slot; - } - } - frameBuffer_.emplace_back(); - return &frameBuffer_.back(); -} - -BetaCudaDeviceInterface::FrameBuffer::Slot* -BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts( - int64_t desiredPts) { - for (auto& slot : frameBuffer_) { - if (slot.occupied && slot.guessedPts == desiredPts) { - return &slot; - } - } - return nullptr; -} - } // namespace facebook::torchcodec diff --git 
a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index b19112f0d..d5f436b14 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -50,51 +50,18 @@ class BetaCudaDeviceInterface : public DeviceInterface { } int sendPacket(ReferenceAVPacket& packet) override; - int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override; + int receiveFrame(UniqueAVFrame& avFrame) override; void flush() override; // NVDEC callback functions (must be public for C callbacks) int streamPropertyChange(CUVIDEOFORMAT* videoFormat); - int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams); + int frameReadyForDecoding(CUVIDPICPARAMS* picParams); + int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo); private: // Apply bitstream filter, modifies packet in-place void applyBSF(ReferenceAVPacket& packet); - class FrameBuffer { - public: - struct Slot { - CUVIDPARSERDISPINFO dispInfo; - int64_t guessedPts; - bool occupied = false; - - Slot() : guessedPts(-1), occupied(false) { - std::memset(&dispInfo, 0, sizeof(dispInfo)); - } - }; - - // TODONVDEC P1: init size should probably be min_num_decode_surfaces from - // video format - FrameBuffer() : frameBuffer_(4) {} - - ~FrameBuffer() = default; - - Slot* findEmptySlot(); - Slot* findFrameWithExactPts(int64_t desiredPts); - - // Iterator support for range-based for loops - auto begin() { - return frameBuffer_.begin(); - } - - auto end() { - return frameBuffer_.end(); - } - - private: - std::vector frameBuffer_; - }; - UniqueAVFrame convertCudaFrameToAVFrame( CUdeviceptr framePtr, unsigned int pitch, @@ -104,9 +71,7 @@ class BetaCudaDeviceInterface : public DeviceInterface { UniqueCUvideodecoder decoder_; CUVIDEOFORMAT videoFormat_ = {}; - FrameBuffer frameBuffer_; - - std::queue packetsPtsQueue; + std::queue readyFrames_; bool eofSent_ = false; @@ -125,3 +90,92 @@ class BetaCudaDeviceInterface : public DeviceInterface { }; } 
// namespace facebook::torchcodec + +/* clang-format off */ +// Note: [General design, sendPacket, receiveFrame, frame ordering and NVCUVID callbacks] +// +// At a high level, this decoding interface mimics the FFmpeg send/receive +// architecture: +// - sendPacket(AVPacket) sends an AVPacket from the FFmpeg demuxer to the +// NVCUVID parser. +// - receiveFrame(AVFrame) is a non-blocking call: +// - if a frame is ready **in display order**, it must return it. By display +// order, we mean that receiveFrame() must return frames with increasing pts +// values when called successively. +// - if no frame is ready, it must return AVERROR(EAGAIN) to indicate the +// caller should send more packets. +// +// The rest of this note assumes you have a reasonable level of familiarity with +// the sendPacket/receiveFrame calling pattern. If you don't, look up the core +// decoding loop in SingleStreamDecoder. +// +// The frame re-ordering problem: +// Depending on the codec and on the encoding parameters, a packet from a video +// stream may contain exactly one frame, more than one frame, or a fraction of a +// frame. And, there may be non-linear frame dependencies because of B-frames, +// which need both past *and* future frames to be decoded. Consider the +// following stream, with frames presented in display order: I0 B1 P2 B3 P4 ... +// - I0 is an I-frame (also key frame, can be decoded independently) +// - B1 is a B-frame (bi-directional) which needs both I0 and P2 to be decoded +// - P2 is a P-frame (predicted frame) which only needs I0 to be decoded. +// +// Because B1 needs both I0 and P2 to be properly decoded, the decode order +// (packet order), defined by the encoder, must be: I0 P2 B1 P4 B3 ... which is +// different from the display order. +// +// SendPacket(AVPacket)'s job is just to pass down the packet to the NVCUVID +// parser by calling cuvidParseVideoData(packet).
When +// cuvidParseVideoData(packet) is called, it may trigger callbacks, +// particularly: +// - streamPropertyChange(videoFormat): triggered once at the start of the +// stream, and possibly later if the stream properties change (e.g. +// resolution). +// - frameReadyForDecoding(picParams): triggered **in decode order** when the +// parser has accumulated enough data to decode a frame. We send that frame to +// the NVDEC hardware for **async** decoding. +// - frameReadyInDisplayOrder(dispInfo): triggered **in display order** when a +// frame is ready to be "displayed" (returned). At that point, the parser also +// gives us the pts of that frame. We store (a reference to) that frame in a +// FIFO queue: readyFrames_. +// +// When receiveFrame(AVFrame) is called, if readyFrames_ is not empty, we pop +// the front of the queue, which is the next frame in display order, and map it +// to an AVFrame by calling cuvidMapVideoFrame(). If readyFrames_ is empty we +// return EAGAIN to indicate the caller should send more packets. +// +// There is potentially a small inefficiency due to the callback design: in +// order for us to know that a frame is ready in display order, we need the +// frameReadyInDisplayOrder callback to be triggered. This can only happen +// within cuvidParseVideoData(packet) in sendPacket(). This means there may be +// the following sequence of calls: +// +// sendPacket(relevantAVPacket) +// cuvidParseVideoData(relevantAVPacket) +// frameReadyForDecoding() +// cuvidDecodePicture() Send frame to NVDEC for async decoding +// +// receiveFrame() -> EAGAIN Frame is potentially already decoded +// and could be returned, but we don't +// know because frameReadyInDisplayOrder +// hasn't been triggered yet. We'll only +// know after sending another, +// potentially irrelevant packet. +// +// sendPacket(irrelevantAVPacket) +// cuvidParseVideoData(irrelevantAVPacket) +// frameReadyInDisplayOrder() Only now do we know that our target +// frame is ready.
+// +// receiveFrame() return target frame +// +// How much this matters in practice is unclear, but probably negligible in +// general. Particularly when frames are decoded consecutively anyway, the +// "irrelevantPacket" is actually relevant for a future target frame. +// +// Note that the alternative is to *not* rely on the frameReadyInDisplayOrder +// callback. It's technically possible, but it would mean we now have to solve +// two hard, *codec-dependent* problems that the callback was solving for us: +// - we have to guess the frame's pts ourselves +// - we have to re-order the frames ourselves to preserve display order. +// +/* clang-format on */ diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 08d94fddc..b7d5ef07a 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -101,9 +101,7 @@ class DeviceInterface { // Moral equivalent of avcodec_receive_frame() // Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready, // AVERROR_EOF if end of stream, or other AVERROR on failure - virtual int receiveFrame( - [[maybe_unused]] UniqueAVFrame& avFrame, - [[maybe_unused]] int64_t desiredPts) { + virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) { TORCH_CHECK( false, "Send/receive packet decoding not implemented for this device interface"); diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 62be222f6..afa852e48 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1189,7 +1189,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // avcodec_send_packet. This would make the decoding loop even more generic. 
while (true) { if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->receiveFrame(avFrame, cursor_); + status = deviceInterface_->receiveFrame(avFrame); } else { status = avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 2e7a3250d..5e058fd0b 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -155,12 +155,6 @@ def __init__( device_variant = device_split[2] device = ":".join(device_split[0:2]) - # TODONVDEC P0 Support approximate mode. Not ideal to validate that here - # either, but validating this at a lower level forces to add yet another - # (temprorary) validation API to the device inteface - if device_variant == "beta" and seek_mode != "exact": - raise ValueError("Seek mode must be exact for BETA CUDA interface.") - core.add_video_stream( self._decoder, stream_index=stream_index, diff --git a/test/test_decoders.py b/test/test_decoders.py index c4cd55dd7..7ffc67566 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1417,9 +1417,14 @@ def test_get_frames_at_tensor_indices(self): @needs_cuda @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) @pytest.mark.parametrize("contiguous_indices", (True, False)) - def test_beta_cuda_interface_get_frame_at(self, asset, contiguous_indices): - ref_decoder = VideoDecoder(asset.path, device="cuda") - beta_decoder = VideoDecoder(asset.path, device="cuda:0:beta") + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_beta_cuda_interface_get_frame_at( + self, asset, contiguous_indices, seek_mode + ): + ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) + beta_decoder = VideoDecoder( + asset.path, device="cuda:0:beta", seek_mode=seek_mode + ) assert ref_decoder.metadata == beta_decoder.metadata @@ -1442,9 +1447,14 @@ def 
test_beta_cuda_interface_get_frame_at(self, asset, contiguous_indices): @needs_cuda @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) @pytest.mark.parametrize("contiguous_indices", (True, False)) - def test_beta_cuda_interface_get_frames_at(self, asset, contiguous_indices): - ref_decoder = VideoDecoder(asset.path, device="cuda") - beta_decoder = VideoDecoder(asset.path, device="cuda:0:beta") + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_beta_cuda_interface_get_frames_at( + self, asset, contiguous_indices, seek_mode + ): + ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) + beta_decoder = VideoDecoder( + asset.path, device="cuda:0:beta", seek_mode=seek_mode + ) assert ref_decoder.metadata == beta_decoder.metadata @@ -1465,16 +1475,78 @@ def test_beta_cuda_interface_get_frames_at(self, asset, contiguous_indices): beta_frames.duration_seconds, ref_frames.duration_seconds ) + @needs_cuda + @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): + ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) + beta_decoder = VideoDecoder( + asset.path, device="cuda:0:beta", seek_mode=seek_mode + ) + + assert ref_decoder.metadata == beta_decoder.metadata + + timestamps = torch.linspace( + 0, ref_decoder.metadata.duration_seconds - 1e-4, steps=10 + ) + for pts in timestamps: + ref_frame = ref_decoder.get_frame_played_at(pts) + beta_frame = beta_decoder.get_frame_played_at(pts) + torch.testing.assert_close(beta_frame.data, ref_frame.data, rtol=0, atol=0) + + assert beta_frame.pts_seconds == ref_frame.pts_seconds + assert beta_frame.duration_seconds == ref_frame.duration_seconds + + @needs_cuda + @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) + 
@pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): + ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) + beta_decoder = VideoDecoder( + asset.path, device="cuda:0:beta", seek_mode=seek_mode + ) + + assert ref_decoder.metadata == beta_decoder.metadata + + timestamps = torch.linspace( + 0, ref_decoder.metadata.duration_seconds - 1e-4, steps=10 + ).tolist() + + ref_frames = ref_decoder.get_frames_played_at(timestamps) + beta_frames = beta_decoder.get_frames_played_at(timestamps) + torch.testing.assert_close(beta_frames.data, ref_frames.data, rtol=0, atol=0) + torch.testing.assert_close(beta_frames.pts_seconds, ref_frames.pts_seconds) + torch.testing.assert_close( + beta_frames.duration_seconds, ref_frames.duration_seconds + ) + + @needs_cuda + @pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE)) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_beta_cuda_interface_backwards(self, asset, seek_mode): + + ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) + beta_decoder = VideoDecoder( + asset.path, device="cuda:0:beta", seek_mode=seek_mode + ) + + assert ref_decoder.metadata == beta_decoder.metadata + + for frame_index in [0, 100, 10, 50, 20, 200, 150, 389]: + frame_index = min(frame_index, len(ref_decoder) - 1) + ref_frame = ref_decoder.get_frame_at(frame_index) + beta_frame = beta_decoder.get_frame_at(frame_index) + torch.testing.assert_close(beta_frame.data, ref_frame.data, rtol=0, atol=0) + + assert beta_frame.pts_seconds == ref_frame.pts_seconds + assert beta_frame.duration_seconds == ref_frame.duration_seconds + @needs_cuda def test_beta_cuda_interface_error(self): with pytest.raises(RuntimeError, match="Can only do H264 for now"): VideoDecoder(AV1_VIDEO.path, device="cuda:0:beta") with pytest.raises(RuntimeError, match="Can only do H264 for now"): 
VideoDecoder(H265_VIDEO.path, device="cuda:0:beta") - with pytest.raises( - ValueError, match="Seek mode must be exact for BETA CUDA interface." - ): - VideoDecoder(NASA_VIDEO.path, device="cuda:0:beta", seek_mode="approximate") with pytest.raises(RuntimeError, match="Unsupported device"): VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")