diff --git a/src/torchcodec/decoders/__init__.py b/src/torchcodec/decoders/__init__.py index 307f18f43..c11ce2c82 100644 --- a/src/torchcodec/decoders/__init__.py +++ b/src/torchcodec/decoders/__init__.py @@ -7,4 +7,6 @@ from ._core import VideoStreamMetadata from ._video_decoder import VideoDecoder # noqa +# from ._audio_decoder import AudioDecoder # Will be public when more stable + SimpleVideoDecoder = VideoDecoder diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py new file mode 100644 index 000000000..42229d976 --- /dev/null +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from typing import Literal, Optional, Tuple, Union + +from torch import Tensor + +from torchcodec.decoders import _core as core + +_ERROR_REPORTING_INSTRUCTIONS = """ +This should never happen. Please report an issue following the steps in +https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml. +""" + + +class AudioDecoder: + """A single-stream audio decoder. + + TODO docs + """ + + def __init__( + self, + source: Union[str, Path, bytes, Tensor], + *, + sample_rate: Optional[int] = None, + stream_index: Optional[int] = None, + seek_mode: Literal["exact", "approximate"] = "exact", + ): + if sample_rate is not None: + raise ValueError("TODO implement this") + + # TODO unify validation with VideoDecoder? + allowed_seek_modes = ("exact", "approximate") + if seek_mode not in allowed_seek_modes: + raise ValueError( + f"Invalid seek mode ({seek_mode}). " + f"Supported values are {', '.join(allowed_seek_modes)}." + ) + + if isinstance(source, str): + self._decoder = core.create_from_file(source, seek_mode) + elif isinstance(source, Path): + self._decoder = core.create_from_file(str(source), seek_mode) + elif isinstance(source, bytes): + self._decoder = core.create_from_bytes(source, seek_mode) + elif isinstance(source, Tensor): + self._decoder = core.create_from_tensor(source, seek_mode) + else: + raise TypeError( + f"Unknown source type: {type(source)}. " + "Supported types are str, Path, bytes and Tensor." + ) + + core.add_audio_stream(self._decoder, stream_index=stream_index) + + self.metadata, self.stream_index = _get_and_validate_stream_metadata( + self._decoder, stream_index + ) + + # if self.metadata.num_frames is None: + # raise ValueError( + # "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS + # ) + # self._num_frames = self.metadata.num_frames + + # if self.metadata.begin_stream_seconds is None: + # raise ValueError( + # "The minimum pts value in seconds is unknown. " + # + _ERROR_REPORTING_INSTRUCTIONS + # ) + # self._begin_stream_seconds = self.metadata.begin_stream_seconds + + # if self.metadata.end_stream_seconds is None: + # raise ValueError( + # "The maximum pts value in seconds is unknown. " + # + _ERROR_REPORTING_INSTRUCTIONS + # ) + # self._end_stream_seconds = self.metadata.end_stream_seconds + + # TODO we need to have a default for stop_seconds. + def get_samples_played_in_range( + self, start_seconds: float, stop_seconds: float + ) -> Tensor: + """ + TODO DOCS + """ + # if not start_seconds <= stop_seconds: + # raise ValueError( + # f"Invalid start seconds: {start_seconds}. 
It must be less than or equal to stop seconds ({stop_seconds})." + # ) + # if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: + # raise ValueError( + # f"Invalid start seconds: {start_seconds}. " + # f"It must be greater than or equal to {self._begin_stream_seconds} " + # f"and less than or equal to {self._end_stream_seconds}." + # ) + # if not stop_seconds <= self._end_stream_seconds: + # raise ValueError( + # f"Invalid stop seconds: {stop_seconds}. " + # f"It must be less than or equal to {self._end_stream_seconds}." + # ) + + frames, *_ = core.get_frames_by_pts_in_range( + self._decoder, + start_seconds=start_seconds, + stop_seconds=stop_seconds, + ) + # TODO need to return view on this to account for samples instead of + # frames + return frames + + +def _get_and_validate_stream_metadata( + decoder: Tensor, + stream_index: Optional[int] = None, +) -> Tuple[core.AudioStreamMetadata, int]: + + # TODO should this still be called `get_video_metadata`? + container_metadata = core.get_video_metadata(decoder) + + if stream_index is None: + best_stream_index = container_metadata.best_audio_stream_index + if best_stream_index is None: + raise ValueError( + "The best audio stream is unknown and there is no specified stream. " + + _ERROR_REPORTING_INSTRUCTIONS + ) + stream_index = best_stream_index + + # This should be logically true because of the above conditions, but type checker + # is not clever enough. + assert stream_index is not None + + stream_metadata = container_metadata.streams[stream_index] + return (stream_metadata, stream_index) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 688a249d5..2acd219bd 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -4,7 +4,9 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Torch REQUIRED) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") +# TODO Put back normal flags +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}") find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) function(make_torchcodec_library library_name ffmpeg_target) diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp index 5de4a4624..3ad6fdded 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp @@ -60,6 +60,26 @@ int64_t getDuration(const AVFrame* frame) { #endif } +int getNumChannels(const AVFrame* avFrame) { +#if LIBAVFILTER_VERSION_MAJOR > 8 || \ + (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) + return avFrame->ch_layout.nb_channels; +#else + return av_get_channel_layout_nb_channels(avFrame->channel_layout); +#endif +} + +int getNumChannels(const UniqueAVCodecContext& avCodecContext) { +// Not sure about the exactness of the version bounds, but as long as this +// compiles we're fine.
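+// (For reference: FFmpeg 5.1, which introduced the AVChannelLayout / ch_layout +// API, ships libavfilter 8.44, so this bound should approximate that release.)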
+#if LIBAVFILTER_VERSION_MAJOR > 8 || \ + (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) + return avCodecContext->ch_layout.nb_channels; +#else + return avCodecContext->channels; +#endif +} + AVIOBytesContext::AVIOBytesContext( const void* data, size_t data_size, diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index deabae52d..957c7f840 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -139,6 +139,9 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode); int64_t getDuration(const UniqueAVFrame& frame); int64_t getDuration(const AVFrame* frame); +int getNumChannels(const AVFrame* avFrame); +int getNumChannels(const UniqueAVCodecContext& avCodecContext); + // Returns true if sws_scale can handle unaligned data. bool canSwsScaleHandleUnalignedData(); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 116379eb6..4c612b901 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -156,6 +156,9 @@ void VideoDecoder::initializeDecoder() { "Our stream index, " + std::to_string(i) + ", does not match AVStream's index, " + std::to_string(avStream->index) + "."); + + // TODO figure out audio metadata + streamMetadata.streamIndex = i; streamMetadata.mediaType = avStream->codecpar->codec_type; streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id); @@ -171,14 +174,24 @@ void VideoDecoder::initializeDecoder() { av_q2d(avStream->time_base) * avStream->duration; } - double fps = av_q2d(avStream->r_frame_rate); - if (fps > 0) { - streamMetadata.averageFps = fps; - } - if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { + double fps = av_q2d(avStream->r_frame_rate); + if (fps > 0) { + streamMetadata.averageFps = fps; + } containerMetadata_.numVideoStreams++; } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + int numSamplesPerFrame = avStream->codecpar->frame_size; + int sampleRate = avStream->codecpar->sample_rate; + if (numSamplesPerFrame > 0 && sampleRate > 0) { + // This should allow the approximate mode to do its magic. + // fps is numFrames / duration where + // - duration = numSamplesTotal / sampleRate and + // - numSamplesTotal = numSamplesPerFrame * numFrames + // so fps = numFrames * sampleRate / (numSamplesPerFrame * numFrames) + streamMetadata.averageFps = + static_cast<double>(sampleRate) / numSamplesPerFrame; + } containerMetadata_.numAudioStreams++; } @@ -418,69 +431,99 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions( } } -void VideoDecoder::addVideoStream( +void VideoDecoder::addStream( int streamIndex, + AVMediaType mediaType, const VideoStreamOptions& videoStreamOptions) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, "Can only add one single stream."); + TORCH_CHECK( + mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO, + "Can only add video or audio streams."); TORCH_CHECK(formatContext_.get() != nullptr); AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; activeStreamIndex_ = av_find_best_stream( - formatContext_.get(), AVMEDIA_TYPE_VIDEO, streamIndex, -1, &avCodec, 0); + formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0); + if (activeStreamIndex_ < 0) { - throw std::invalid_argument("No valid stream found in input file."); + throw std::invalid_argument( + "No valid stream found in input file. 
Is " + + std::to_string(streamIndex) + " of the desired media type?"); } + TORCH_CHECK(avCodec != nullptr); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; streamInfo.streamIndex = activeStreamIndex_; streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base; streamInfo.stream = formatContext_->streams[activeStreamIndex_]; + streamInfo.avMediaType = mediaType; - if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) { - throw std::invalid_argument( - "Stream with index " + std::to_string(activeStreamIndex_) + - " is not a video stream."); - } - - if (videoStreamOptions.device.type() == torch::kCUDA) { + // This should never happen, checking just to be safe. + TORCH_CHECK( + streamInfo.stream->codecpar->codec_type == mediaType, + "FFmpeg found stream with index ", + activeStreamIndex_, + " which is of the wrong media type."); + + // TODO_CODE_QUALITY this is meh to have that in the middle + if (mediaType == AVMEDIA_TYPE_VIDEO && + videoStreamOptions.device.type() == torch::kCUDA) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( findCudaCodec( videoStreamOptions.device, streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } + // TODO_FRAME_SIZE_APPROXIMATE_MODE + // For audio, we raise if seek_mode="approximate" and if the number of + // samples per frame is unknown (frame_size field of codec params). But that's + // quite limitting. Ultimately, the most common type of call will be to decode + // an entire file from start to end (possibly with some offsets for start and + // end). And for that, we shouldn't [need to] force the user to scan, because + // all this entails is a single call to seek(start) (if at all) and then just + // a bunch of consecutive calls to getNextFrame(). Maybe there should be a + // third seek mode for audio, e.g. seek_mode="contiguous" where we don't scan, + // and only allow calls to getFramesPlayedAt(). StreamMetadata& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; if (seekMode_ == SeekMode::approximate && !streamMetadata.averageFps.has_value()) { - throw std::runtime_error( - "Seek mode is approximate, but stream " + - std::to_string(activeStreamIndex_) + - " does not have an average fps in its metadata."); + std::string errMsg = "Seek mode is approximate, but stream " + + std::to_string(activeStreamIndex_) + "does not have "; + if (mediaType == AVMEDIA_TYPE_VIDEO) { + errMsg += "an average fps in its metadata."; + } else { + errMsg += "a constant number of samples per frame."; + } + throw std::runtime_error(errMsg); } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); TORCH_CHECK(codecContext != nullptr); - codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0); + codecContext->thread_count = + videoStreamOptions.ffmpegThreadCount.value_or(0); // TODO VIDEO ONLY? streamInfo.codecContext.reset(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); TORCH_CHECK_EQ(retVal, AVSUCCESS); - if (videoStreamOptions.device.type() == torch::kCPU) { - // No more initialization needed for CPU. - } else if (videoStreamOptions.device.type() == torch::kCUDA) { - initializeContextOnCuda(videoStreamOptions.device, codecContext); - } else { - TORCH_CHECK( - false, "Invalid device type: " + videoStreamOptions.device.str()); + // TODO_CODE_QUALITY meh again + if (mediaType == AVMEDIA_TYPE_VIDEO) { + if (videoStreamOptions.device.type() == torch::kCPU) { + // No more initialization needed for CPU. 
+ } else if (videoStreamOptions.device.type() == torch::kCUDA) { + initializeContextOnCuda(videoStreamOptions.device, codecContext); + } else { + TORCH_CHECK( + false, "Invalid device type: " + videoStreamOptions.device.str()); + } + streamInfo.videoStreamOptions = videoStreamOptions; } - streamInfo.videoStreamOptions = videoStreamOptions; retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); if (retVal < AVSUCCESS) { @@ -488,14 +531,8 @@ void VideoDecoder::addVideoStream( } codecContext->time_base = streamInfo.stream->time_base; - - containerMetadata_.allStreamMetadata[activeStreamIndex_].width = - codecContext->width; - containerMetadata_.allStreamMetadata[activeStreamIndex_].height = - codecContext->height; - auto codedId = codecContext->codec_id; containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = - std::string(avcodec_get_name(codedId)); + std::string(avcodec_get_name(codecContext->codec_id)); // We will only need packets from the active stream, so we tell FFmpeg to // discard packets from the other streams. Note that av_read_frame() may still @@ -506,6 +543,18 @@ void VideoDecoder::addVideoStream( formatContext_->streams[i]->discard = AVDISCARD_ALL; } } +} + +void VideoDecoder::addVideoStream( + int streamIndex, + const VideoStreamOptions& videoStreamOptions) { + addStream(streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions); + + auto& streamInfo = streamInfos_[activeStreamIndex_]; + containerMetadata_.allStreamMetadata[activeStreamIndex_].width = + streamInfo.codecContext->width; + containerMetadata_.allStreamMetadata[activeStreamIndex_].height = + streamInfo.codecContext->height; // By default, we want to use swscale for color conversion because it is // faster. However, it has width requirements, so we may need to fall back @@ -514,7 +563,7 @@ void VideoDecoder::addVideoStream( // swscale's width requirements to be violated. We don't expose the ability to // choose color conversion library publicly; we only use this ability // internally. 
- int width = videoStreamOptions.width.value_or(codecContext->width); + int width = videoStreamOptions.width.value_or(streamInfo.codecContext->width); // swscale requires widths to be multiples of 32: // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements @@ -527,13 +576,25 @@ videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); } +void VideoDecoder::addAudioStream(int streamIndex) { + addStream(streamIndex, AVMEDIA_TYPE_AUDIO); + + // See TODO_FRAME_SIZE_BATCH_TENSOR_ALLOCATION + auto& streamInfo = streamInfos_[activeStreamIndex_]; + TORCH_CHECK( + streamInfo.codecContext->frame_size > 0, + "No support for variable framerate yet."); + containerMetadata_.allStreamMetadata[activeStreamIndex_].sampleRate = + streamInfo.codecContext->sample_rate; +} + // -------------------------------------------------------------------------- // HIGH-LEVEL DECODING ENTRY-POINTS // -------------------------------------------------------------------------- VideoDecoder::FrameOutput VideoDecoder::getNextFrame() { auto output = getNextFrameInternal(); - output.data = maybePermuteHWC2CHW(output.data); + output.data = maybePermuteOutputTensor(output.data); return output; } @@ -547,16 +608,15 @@ } VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); auto frameOutput = getFrameAtIndexInternal(frameIndex); - frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); + frameOutput.data = maybePermuteOutputTensor(frameOutput.data); return frameOutput; } VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( int64_t frameIndex, std::optional<torch::Tensor> preAllocatedOutputTensor) { - validateActiveStream(); - const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; @@ -569,7 +629,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( const std::vector<int64_t>& frameIndices) { - validateActiveStream(); + validateActiveStream(AVMEDIA_TYPE_VIDEO); auto indicesAreSorted = std::is_sorted(frameIndices.begin(), frameIndices.end()); @@ -592,10 +652,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput frameBatchOutput( - frameIndices.size(), videoStreamOptions, streamMetadata); + + FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(frameIndices.size()); auto previousIndexInVideo = -1; for (size_t f = 0; f < frameIndices.size(); ++f) { @@ -622,17 +680,16 @@ } previousIndexInVideo = indexInVideo; } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data); return frameBatchOutput; } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { - validateActiveStream(); + validateActiveStream(AVMEDIA_TYPE_VIDEO); const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; - const auto& streamInfo = streamInfos_[activeStreamIndex_]; int64_t numFrames = 
getNumFrames(streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); @@ -644,9 +701,8 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { step > 0, "Step must be greater than 0; is " + std::to_string(step)); int64_t numOutputFrames = std::ceil((stop - start) / double(step)); - const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput frameBatchOutput( - numOutputFrames, videoStreamOptions, streamMetadata); + + FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numOutputFrames); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput frameOutput = @@ -654,11 +710,12 @@ frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data); return frameBatchOutput; } VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); @@ -693,13 +750,13 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) { // Convert the frame to tensor. FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream); - frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); + frameOutput.data = maybePermuteOutputTensor(frameOutput.data); return frameOutput; } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt( const std::vector<double>& timestamps) { - validateActiveStream(); + validateActiveStream(AVMEDIA_TYPE_VIDEO); const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; @@ -732,17 +789,28 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( double stopSeconds) { validateActiveStream(); - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + // Because we currently never seek with audio streams, we prevent users from + // calling this method twice. We could allow multiple calls in the future. + // Assuming 2 consecutive calls: + // ``` + // getFramesPlayedInRange(startSeconds1, stopSeconds1); + // getFramesPlayedInRange(startSeconds2, stopSeconds2); + // ``` + // We would need to seek back to 0 iff startSeconds2 <= stopSeconds1. This + // logic is not implemented for now, so we just error. + + TORCH_CHECK( + streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO || + !alreadyCalledGetFramesPlayedInRange_, + "Can only decode once with audio stream. Re-create a decoder object if needed."); + alreadyCalledGetFramesPlayedInRange_ = true; + TORCH_CHECK( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + std::to_string(stopSeconds) + ")."); - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - const auto& videoStreamOptions = streamInfo.videoStreamOptions; - // Special case needed to implement a half-open range. At first glance, this // may seem unnecessary, as our search for stopFrame can return the end, and // we don't include stopFrameIndex in our output. 
However, consider the @@ -761,11 +829,13 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( // values of the intervals will map to the same frame indices below. Hence, we // need this special case below. if (startSeconds == stopSeconds) { - FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(0); + frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data); return frameBatchOutput; } + const auto& streamMetadata = + containerMetadata_.allStreamMetadata[activeStreamIndex_]; double minSeconds = getMinSeconds(streamMetadata); double maxSeconds = getMaxSeconds(streamMetadata); TORCH_CHECK( @@ -796,15 +866,15 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds); int64_t numFrames = stopFrameIndex - startFrameIndex; - FrameBatchOutput frameBatchOutput( - numFrames, videoStreamOptions, streamMetadata); + FrameBatchOutput frameBatchOutput = makeFrameBatchOutput(numFrames); + for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput frameOutput = getFrameAtIndexInternal(i, frameBatchOutput.data[f]); frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + frameBatchOutput.data = maybePermuteOutputTensor(frameBatchOutput.data); return frameBatchOutput; } @@ -843,8 +913,12 @@ I P P P I P P P I P P I P P I P (2) is more efficient than (1) if there is an I frame between x and y. */ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const { - int64_t lastDecodedAvFramePts = - streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; + const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); + if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { + return true; + } + + int64_t lastDecodedAvFramePts = streamInfo.lastDecodedAvFramePts; if (targetPts < lastDecodedAvFramePts) { // We can never skip a seek if we are seeking backwards. return false; @@ -877,6 +951,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { streamInfo.discardFramesBeforePts = desiredPts; decodeStats_.numSeeksAttempted++; + if (canWeAvoidSeeking(desiredPts)) { decodeStats_.numSeeksSkipped++; return; @@ -972,6 +1047,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame( if (ffmpegStatus == AVERROR_EOF) { // End of file reached. We must drain the codec by sending a nullptr // packet. + ffmpegStatus = avcodec_send_packet( streamInfo.codecContext.get(), /*avpkt=*/nullptr); @@ -1047,13 +1123,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( AVFrame* avFrame = avFrameStream.avFrame.get(); frameOutput.streamIndex = streamIndex; auto& streamInfo = streamInfos_[streamIndex]; - TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO); frameOutput.ptsSeconds = ptsToSeconds( avFrame->pts, formatContext_->streams[streamIndex]->time_base); frameOutput.durationSeconds = ptsToSeconds( getDuration(avFrame), formatContext_->streams[streamIndex]->time_base); - // TODO: we should fold preAllocatedOutputTensor into AVFrameStream. 
- if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { + if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { + convertAudioAVFrameToFrameOutputOnCPU( + avFrameStream, frameOutput, preAllocatedOutputTensor); + } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { convertAVFrameToFrameOutputOnCPU( avFrameStream, frameOutput, preAllocatedOutputTensor); } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { @@ -1229,6 +1306,48 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8}); } +void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( + VideoDecoder::AVFrameStream& avFrameStream, + FrameOutput& frameOutput, + std::optional<torch::Tensor> preAllocatedOutputTensor) { + const AVFrame* avFrame = avFrameStream.avFrame.get(); + + auto numSamples = avFrame->nb_samples; // per channel + auto numChannels = getNumChannels(avFrame); + + // TODO: dtype should be format-dependent + // TODO_CODE_QUALITY rename data to something else + torch::Tensor data; + if (preAllocatedOutputTensor.has_value()) { + data = preAllocatedOutputTensor.value(); + } else { + data = torch::empty({numChannels, numSamples}, torch::kFloat32); + } + + AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); + // TODO Implement all formats + switch (format) { + case AV_SAMPLE_FMT_FLTP: { + uint8_t* outputChannelData = static_cast<uint8_t*>(data.data_ptr()); + auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < numChannels; + ++channel, outputChannelData += numBytesPerChannel) { + memcpy( + outputChannelData, + avFrame->extended_data[channel], + numBytesPerChannel); + } + break; + } + default: + TORCH_CHECK( + false, + "Unsupported audio format (yet!): ", + av_get_sample_fmt_name(format)); + } + frameOutput.data = data; +} + // -------------------------------------------------------------------------- // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- @@ -1247,6 +1366,41 @@ VideoDecoder::FrameBatchOutput::FrameBatchOutput( height, width, videoStreamOptions.device, numFrames); } +VideoDecoder::FrameBatchOutput::FrameBatchOutput( + int64_t numFrames, + int64_t numChannels, + int64_t numSamples) + : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), + durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { + // TODO handle dtypes other than float + auto tensorOptions = torch::TensorOptions() + .dtype(torch::kFloat32) + .layout(torch::kStrided) + .device(torch::kCPU); + data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions); +} + +VideoDecoder::FrameBatchOutput VideoDecoder::makeFrameBatchOutput( + int64_t numFrames) { + const auto& streamInfo = streamInfos_[activeStreamIndex_]; + if (streamInfo.avMediaType == AVMEDIA_TYPE_VIDEO) { + const auto& videoStreamOptions = streamInfo.videoStreamOptions; + const auto& streamMetadata = + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + return FrameBatchOutput(numFrames, videoStreamOptions, streamMetadata); + } else { + // TODO_FRAME_SIZE_BATCH_TENSOR_ALLOCATION + // We asserted that frame_size is non-zero when we added the stream, but it + // may not always be the case. + // When it's 0, we can't pre-allocate the output tensor as we don't know the + // number of samples per channel, and it may be non-constant. We'll have to + // find a way to make the batch-APIs work without pre-allocation. 
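+ // Note: frame_size is the number of samples per channel in each decoded + // frame, so the pre-allocated batch below is (numFrames, numChannels, frame_size).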
+ int64_t numSamples = streamInfo.codecContext->frame_size; + int64_t numChannels = getNumChannels(streamInfo.codecContext); + return FrameBatchOutput(numFrames, numChannels, numSamples); + } +} + torch::Tensor allocateEmptyHWCTensor( int height, int width, @@ -1268,6 +1422,17 @@ torch::Tensor allocateEmptyHWCTensor( } } +torch::Tensor VideoDecoder::maybePermuteOutputTensor( + torch::Tensor& outputTensor) { + if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) { + return maybePermuteHWC2CHW(outputTensor); + } else { + // No need to do anything for audio. We always return (numChannels, + // numSamples) or (numFrames, numChannels, numSamples) + return outputTensor; + } +} + // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so. // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D // or 4D. @@ -1493,8 +1658,8 @@ int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex( return upperBound - 1 - keyFrames.begin(); } -int64_t VideoDecoder::secondsToIndexLowerBound(double seconds) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; +int64_t VideoDecoder::secondsToIndexLowerBound(double seconds) const { + auto& streamInfo = streamInfos_.at(activeStreamIndex_); switch (seekMode_) { case SeekMode::exact: { auto frame = std::lower_bound( @@ -1509,7 +1674,7 @@ } case SeekMode::approximate: { auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + containerMetadata_.allStreamMetadata.at(activeStreamIndex_); return std::floor(seconds * streamMetadata.averageFps.value()); } default: @@ -1598,7 +1763,8 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { // VALIDATION UTILS // -------------------------------------------------------------------------- -void VideoDecoder::validateActiveStream() { +void VideoDecoder::validateActiveStream( + std::optional<AVMediaType> avMediaType) { auto errorMsg = "Provided stream index=" + std::to_string(activeStreamIndex_) + " was not previously added."; @@ -1612,6 +1778,12 @@ "Invalid stream index=" + std::to_string(activeStreamIndex_) + "; valid indices are in the range [0, " + std::to_string(allStreamMetadataSize) + ")."); + + if (avMediaType.has_value()) { + TORCH_CHECK( + streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value(), + "The method you called doesn't support the media type (audio or video)"); + } } void VideoDecoder::validateScannedAllStreams(const std::string& msg) { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index e71973851..2a9479b83 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -77,6 +77,9 @@ class VideoDecoder { // Video-only fields derived from the AVCodecContext. 
std::optional<int64_t> width; std::optional<int64_t> height; + + // Audio-only fields + std::optional<int64_t> sampleRate; }; struct ContainerMetadata { @@ -139,9 +142,7 @@ void addVideoStream( int streamIndex, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); - void addAudioStreamDecoder( - int streamIndex, - const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); + void addAudioStream(int streamIndex); // -------------------------------------------------------------------------- // DECODING AND SEEKING APIs // -------------------------------------------------------------------------- @@ -168,6 +169,10 @@ int64_t numFrames, const VideoStreamOptions& videoStreamOptions, const StreamMetadata& streamMetadata); + explicit FrameBatchOutput( + int64_t numFrames, + int64_t numChannels, + int64_t numSamples); }; // Places the cursor at the first frame on or after the position in seconds. @@ -322,6 +327,8 @@ class VideoDecoder { struct StreamInfo { int streamIndex = -1; AVStream* stream = nullptr; + AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN; + AVRational timeBase = {}; UniqueAVCodecContext codecContext; @@ -370,6 +377,7 @@ FrameOutput getNextFrameInternal( std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + torch::Tensor maybePermuteOutputTensor(torch::Tensor& outputTensor); torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor); FrameOutput convertAVFrameToFrameOutput( @@ -381,12 +389,19 @@ FrameOutput& frameOutput, std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + void convertAudioAVFrameToFrameOutputOnCPU( + AVFrameStream& avFrameStream, + FrameOutput& frameOutput, + std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame); int convertAVFrameToTensorUsingSwsScale( const AVFrame* avFrame, torch::Tensor& outputTensor); + FrameBatchOutput makeFrameBatchOutput(int64_t numFrames); + // -------------------------------------------------------------------------- // COLOR CONVERSION LIBRARIES HANDLERS CREATION // -------------------------------------------------------------------------- @@ -414,7 +429,7 @@ const std::vector<int64_t>& keyFrames, int64_t pts) const; - int64_t secondsToIndexLowerBound(double seconds); + int64_t secondsToIndexLowerBound(double seconds) const; int64_t secondsToIndexUpperBound(double seconds); @@ -424,6 +439,11 @@ // STREAM AND METADATA APIS // -------------------------------------------------------------------------- + void addStream( + int streamIndex, + AVMediaType mediaType, + const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); + // Returns the "best" stream index for a given media type. The "best" is // determined by various heuristics in FFMPEG. // See @@ -441,7 +461,8 @@ // VALIDATION UTILS // -------------------------------------------------------------------------- - void validateActiveStream(); + void validateActiveStream( + std::optional<AVMediaType> avMediaType = std::nullopt); void validateScannedAllStreams(const std::string& msg); void validateFrameIndex( const StreamMetadata& streamMetadata, @@ -468,6 +489,7 @@ bool scannedAllStreams_ = false; // Tracks that we've already been initialized. 
bool initialized_ = false; + bool alreadyCalledGetFramesPlayedInRange_ = false; }; // -------------------------------------------------------------------------- diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index d0fefcd9f..4a5efa60b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -34,6 +34,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()"); m.def( "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()"); + m.def( + "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)"); m.def( @@ -222,7 +224,21 @@ void _add_video_stream( videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions); } +void add_audio_stream( + at::Tensor& decoder, + std::optional<int64_t> stream_index) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + videoDecoder->addAudioStream(stream_index.value_or(-1)); +} + void seek_to_pts(at::Tensor& decoder, double seconds) { + // TODO we should prevent more than one call to this op for audio streams, for + // the same reasons we do so for getFramesPlayedInRange(). But we can't + // implement the logic here, because we don't know media type (audio vs + // video). We also can't do it within setCursorPtsInSeconds because it's used + // by all other decoding methods. + // This isn't un-doable, just not easy with the API we currently have. + auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr()); videoDecoder->setCursorPtsInSeconds(seconds); } @@ -476,6 +492,16 @@ std::string get_stream_json_metadata( if (streamMetadata.averageFps.has_value()) { map["averageFps"] = std::to_string(*streamMetadata.averageFps); } + if (streamMetadata.sampleRate.has_value()) { + map["sampleRate"] = std::to_string(*streamMetadata.sampleRate); + } + if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { + map["mediaType"] = "\"video\""; + } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { + map["mediaType"] = "\"audio\""; + } else { + map["mediaType"] = "\"other\""; + } return mapToJson(map); } @@ -521,6 +547,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); + m.impl("add_audio_stream", &add_audio_stream); m.impl("get_next_frame", &get_next_frame); m.impl("_get_key_frame_indices", &_get_key_frame_indices); m.impl("get_json_metadata", &get_json_metadata); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 8bdd05cdd..a3cc821ad 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -55,6 +55,10 @@ void _add_video_stream( std::optional<std::string_view> device = std::nullopt, std::optional<std::string_view> color_conversion_library = std::nullopt); +void add_audio_stream( + at::Tensor& decoder, + std::optional<int64_t> stream_index = std::nullopt); + // Seek to a particular presentation timestamp in the video in seconds. 
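+// Note: for audio streams, backwards seeks currently have no effect (see the +// TODO in VideoDecoderOps.cpp); only forward seeking behaves as expected.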
void seek_to_pts(at::Tensor& decoder, double seconds); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index d39d3d237..8d9236044 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -6,6 +6,7 @@ from ._metadata import ( + AudioStreamMetadata, get_video_metadata, get_video_metadata_from_header, VideoMetadata, @@ -15,6 +16,7 @@ _add_video_stream, _get_key_frame_indices, _test_frame_pts_equality, + add_audio_stream, add_video_stream, create_from_bytes, create_from_file, diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index 1ec4a87c2..6932f7362 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -144,6 +144,102 @@ def __repr__(self): return s +@dataclass +class AudioStreamMetadata: + # TODO do we expose the notion of frame here, like in fps? It's technically + # valid, but potentially is an FFmpeg-specific concept for audio + # TODO Need sample rate and format + sample_rate: Optional[int] + duration_seconds_from_header: Optional[float] + bit_rate: Optional[float] + num_frames_from_header: Optional[int] + num_frames_from_content: Optional[int] + begin_stream_seconds_from_content: Optional[float] + end_stream_seconds_from_content: Optional[float] + codec: Optional[str] + average_fps_from_header: Optional[float] + stream_index: int + + @property + def num_frames(self) -> Optional[int]: + """Number of frames in the stream. This corresponds to + ``num_frames_from_content`` if a :term:`scan` was made, otherwise it + corresponds to ``num_frames_from_header``. + """ + if self.num_frames_from_content is not None: + return self.num_frames_from_content + else: + return self.num_frames_from_header + + @property + def duration_seconds(self) -> Optional[float]: + """Duration of the stream in seconds. We try to calculate the duration + from the actual frames if a :term:`scan` was performed. Otherwise we + fall back to ``duration_seconds_from_header``. + """ + if ( + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None + ): + return self.duration_seconds_from_header + return ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + + @property + def average_fps(self) -> Optional[float]: + """Average fps of the stream. If a :term:`scan` was performed, this is + computed from the number of frames and the duration of the stream. + Otherwise we fall back to ``average_fps_from_header``. + """ + if ( + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None + or self.num_frames is None + ): + return self.average_fps_from_header + return self.num_frames / ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + + @property + def begin_stream_seconds(self) -> float: + """Beginning of the stream, in seconds (float). Conceptually, this + corresponds to the first frame's :term:`pts`. If + ``begin_stream_seconds_from_content`` is not None, then it is returned. + Otherwise, this value is 0. + """ + if self.begin_stream_seconds_from_content is None: + return 0 + else: + return self.begin_stream_seconds_from_content + + @property + def end_stream_seconds(self) -> Optional[float]: + """End of the stream, in seconds (float or None). + Conceptually, this corresponds to last_frame.pts + last_frame.duration. 
If ``end_stream_seconds_from_content`` is not None, then that value is + returned. Otherwise, returns ``duration_seconds``. + """ + if self.end_stream_seconds_from_content is None: + return self.duration_seconds + else: + return self.end_stream_seconds_from_content + + def __repr__(self): + # Overridden because properties are not printed by default. + s = self.__class__.__name__ + ":\n" + spaces = " " + s += f"{spaces}num_frames: {self.num_frames}\n" + s += f"{spaces}duration_seconds: {self.duration_seconds}\n" + s += f"{spaces}average_fps: {self.average_fps}\n" + for field in dataclasses.fields(self): + s += f"{spaces}{field.name}: {getattr(self, field.name)}\n" + return s + + @dataclass class VideoMetadata: duration_seconds_from_header: Optional[float] @@ -151,7 +247,7 @@ class VideoMetadata: best_video_stream_index: Optional[int] best_audio_stream_index: Optional[int] - streams: List[VideoStreamMetadata] + streams: List[Union[VideoStreamMetadata, AudioStreamMetadata]] @property def duration_seconds(self) -> Optional[float]: @@ -165,7 +261,9 @@ def bit_rate(self) -> Optional[float]: def best_video_stream(self) -> VideoStreamMetadata: if self.best_video_stream_index is None: raise ValueError("The best video stream is unknown.") - return self.streams[self.best_video_stream_index] + metadata = self.streams[self.best_video_stream_index] + assert isinstance(metadata, VideoStreamMetadata) # mypy <3 + return metadata def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata: @@ -176,28 +274,37 @@ """ container_dict = json.loads(_get_container_json_metadata(decoder)) - streams_metadata = [] + streams_metadata: List[Union[VideoStreamMetadata, AudioStreamMetadata]] = [] for stream_index in range(container_dict["numStreams"]): stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index)) - streams_metadata.append( - VideoStreamMetadata( - duration_seconds_from_header=stream_dict.get("durationSeconds"), - bit_rate=stream_dict.get("bitRate"), - num_frames_from_header=stream_dict.get("numFrames"), - num_frames_from_content=stream_dict.get("numFramesFromScan"), - begin_stream_seconds_from_content=stream_dict.get( - "minPtsSecondsFromScan" - ), - end_stream_seconds_from_content=stream_dict.get( - "maxPtsSecondsFromScan" - ), - codec=stream_dict.get("codec"), - width=stream_dict.get("width"), - height=stream_dict.get("height"), - average_fps_from_header=stream_dict.get("averageFps"), - stream_index=stream_index, - ) + common_meta = dict( + duration_seconds_from_header=stream_dict.get("durationSeconds"), + bit_rate=stream_dict.get("bitRate"), + num_frames_from_header=stream_dict.get("numFrames"), + num_frames_from_content=stream_dict.get("numFramesFromScan"), + begin_stream_seconds_from_content=stream_dict.get("minPtsSecondsFromScan"), + end_stream_seconds_from_content=stream_dict.get("maxPtsSecondsFromScan"), + codec=stream_dict.get("codec"), + average_fps_from_header=stream_dict.get("averageFps"), + stream_index=stream_index, ) + if stream_dict["mediaType"] == "audio": + streams_metadata.append( + AudioStreamMetadata( + sample_rate=stream_dict.get("sampleRate"), + **common_meta, + ) + ) + else: + # TODO we're adding a VideoStreamMetadata for all non-audio streams, + # including streams like subtitles, which makes little sense. 
+ streams_metadata.append( + VideoStreamMetadata( + width=stream_dict.get("width"), + height=stream_dict.get("height"), + **common_meta, + ) + ) return VideoMetadata( duration_seconds_from_header=container_dict.get("durationSeconds"), diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 40216304a..190384684 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -69,6 +69,7 @@ def load_torchcodec_extension(): ) add_video_stream = torch.ops.torchcodec_ns.add_video_stream.default _add_video_stream = torch.ops.torchcodec_ns._add_video_stream.default +add_audio_stream = torch.ops.torchcodec_ns.add_audio_stream.default seek_to_pts = torch.ops.torchcodec_ns.seek_to_pts.default get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default @@ -150,6 +151,15 @@ def add_video_stream_abstract( return +@register_fake("torchcodec_ns::add_audio_stream") +def add_audio_stream_abstract( + decoder: torch.Tensor, + *, + stream_index: Optional[int] = None, +) -> None: + return + + @register_fake("torchcodec_ns::seek_to_pts") def seek_abstract(decoder: torch.Tensor, seconds: float) -> None: return diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 8e91efb71..aa6c33d61 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import os +from functools import partial os.environ["TORCH_LOGS"] = "output_code" import json @@ -18,6 +19,7 @@ from torchcodec.decoders._core import ( _add_video_stream, _test_frame_pts_equality, + add_audio_stream, add_video_stream, create_from_bytes, create_from_file, @@ -398,7 +400,7 @@ def test_audio_get_json_metadata(self): decoder = create_from_file(str(NASA_AUDIO.path)) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) - assert metadata_dict["durationSeconds"] == pytest.approx(13.25, abs=0.01) + assert metadata_dict["durationSeconds"] == pytest.approx(13.013, abs=0.01) def test_get_ffmpeg_version(self): ffmpeg_dict = get_ffmpeg_library_versions() @@ -618,6 +620,112 @@ def test_cuda_decoder(self): duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 ) + @pytest.mark.parametrize( + "method", + ( + partial(get_frame_at_index, frame_index=4), + partial(get_frames_at_indices, frame_indices=[4, 5]), + partial(get_frames_in_range, start=4, stop=5), + partial(get_frame_at_pts, seconds=2), + partial(get_frames_by_pts, timestamps=[0, 1.5]), + ), + ) + def test_audio_bad_method(self, method): + decoder = create_from_file(str(NASA_AUDIO.path)) + add_audio_stream(decoder) + with pytest.raises( + RuntimeError, match="The method you called doesn't support the media type" + ): + method(decoder) + + @pytest.mark.parametrize( + "start_seconds, stop_seconds", + ( + # Beginning to end + (0, 13.05), + # At frame boundaries. Frame duration is exactly 0.064 seconds for + # NASA_AUDIO. Need an artificial -1e-5 on the upper bound to align the + # reference_frames with the frames returned by the decoder, where + # the interval is half-open. 
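+ # (0.064 * 4 = 0.256s and 0.064 * 20 = 1.28s, i.e. frames 4 through 19.)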
+ (0.064 * 4, 0.064 * 20 - 1e-5), + # Not at frame boundaries + (2, 4), + ), + ) + def test_audio_get_frames_by_pts_in_range(self, start_seconds, stop_seconds): + decoder = create_from_file(str(NASA_AUDIO.path)) + add_audio_stream(decoder) + + reference_frames = NASA_AUDIO.get_frame_data_by_range( + start=NASA_AUDIO.pts_to_frame_index(start_seconds), + stop=NASA_AUDIO.pts_to_frame_index(stop_seconds) + 1, + ) + frames, _, _ = get_frames_by_pts_in_range( + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + assert_frames_equal(frames, reference_frames) + + def test_audio_get_frames_by_pts_in_range_multiple_calls(self): + decoder = create_from_file(str(NASA_AUDIO.path)) + add_audio_stream(decoder) + + get_frames_by_pts_in_range(decoder, start_seconds=0, stop_seconds=1) + with pytest.raises( + RuntimeError, match="Can only decode once with audio stream" + ): + get_frames_by_pts_in_range(decoder, start_seconds=0, stop_seconds=1) + + def test_audio_seek_and_next(self): + decoder = create_from_file(str(NASA_AUDIO.path)) + add_audio_stream(decoder) + + pts = 2 + # Need +1 because we're not at frame boundaries + reference_frame = NASA_AUDIO.get_frame_data_by_index( + NASA_AUDIO.pts_to_frame_index(pts) + 1 + ) + seek_to_pts(decoder, pts) + frame, _, _ = get_next_frame(decoder) + assert_frames_equal(frame, reference_frame) + + # Seeking forward is OK + pts = 4 + reference_frame = NASA_AUDIO.get_frame_data_by_index( + NASA_AUDIO.pts_to_frame_index(pts) + 1 + ) + seek_to_pts(decoder, pts) + frame, _, _ = get_next_frame(decoder) + assert_frames_equal(frame, reference_frame) + + # Seeking backwards doesn't error, but it's wrong. See TODO in + # `seek_to_pts` op. + prev_pts = pts + pts = 1 + seek_to_pts(decoder, pts) + frame, _, _ = get_next_frame(decoder) + # the decoder actually didn't seek, so the frame we're getting is just + # the "next" one without seeking. This assertion exists to illustrate + # what currently happens, but it's obviously *wrong*. + reference_frame = NASA_AUDIO.get_frame_data_by_index( + NASA_AUDIO.pts_to_frame_index(prev_pts) + 2 + ) + assert_frames_equal(frame, reference_frame) + + # def test_audio_seek_and_next_backwards(self): + # decoder = create_from_file(str(NASA_AUDIO.path)) + # add_audio_stream(decoder) + + # for pts in (4.5, 2): + # # Need +1 because we're not at frame boundaries + # reference_frame = NASA_AUDIO.get_frame_data_by_index(NASA_AUDIO.pts_to_frame_index(pts) + 1) + # seek_to_pts(decoder, pts) + # frame, _, _ = get_next_frame(decoder) + # # assert_frames_equal(frame, reference_frame) + # reference_frame = NASA_AUDIO.get_frame_data_by_index(NASA_AUDIO.pts_to_frame_index(4.5) + 2) + # assert_frames_equal(frame, reference_frame) if __name__ == "__main__": pytest.main() diff --git a/test/resources/nasa_13013.mp4.stream4.all_frames.pt b/test/resources/nasa_13013.mp4.stream4.all_frames.pt new file mode 100644 index 000000000..c3c3aa8a2 Binary files /dev/null and b/test/resources/nasa_13013.mp4.stream4.all_frames.pt differ diff --git a/test/utils.py b/test/utils.py index 857273fb7..b8ded387a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -25,8 +25,17 @@ def cpu_and_cuda(): return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) -def get_ffmpeg_major_version(): - return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0]) +def assert_frames_equal(*args, **kwargs): + frame = args[0] + # This heuristic will work until we start returning uint8 audio frames... 
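+ # (Decoded video frames are uint8 tensors, while decoded audio frames are + # currently always float32.)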
+ if frame.dtype == torch.uint8: + return assert_video_frames_equal(*args, **kwargs) + else: + return assert_audio_frames_equal(*args, **kwargs) + + +def assert_audio_frames_equal(*args, **kwargs): + torch.testing.assert_close(*args, **kwargs) # For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit # equality. On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does # not guarantee bit-for-bit equality across systems and architectures, so we # also cannot. We currently use Linux on x86_64 as our reference system. -def assert_frames_equal(*args, **kwargs): +def assert_video_frames_equal(*args, **kwargs): if sys.platform == "linux": if args[0].device.type == "cuda": atol = 2 @@ -72,6 +81,10 @@ def assert_tensor_close_on_at_least(actual_tensor, ref_tensor, *, percentage, at ) +def get_ffmpeg_major_version(): + return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0]) + + def in_fbcode() -> bool: return os.environ.get("IN_FBCODE_TORCHCODEC") == "1" @@ -89,11 +102,6 @@ def _get_file_path(filename: str) -> pathlib.Path: return pathlib.Path(__file__).parent / "resources" / filename -def _load_tensor_from_file(filename: str) -> torch.Tensor: - file_path = _get_file_path(filename) - return torch.load(file_path, weights_only=True).permute(2, 0, 1) - - @dataclass class TestFrameInfo: pts_seconds: float @@ -120,12 +128,7 @@ def to_tensor(self) -> torch.Tensor: def get_frame_data_by_index( self, idx: int, *, stream_index: Optional[int] = None ) -> torch.Tensor: - if stream_index is None: - stream_index = self.default_stream_index - - return _load_tensor_from_file( - f"{self.filename}.stream{stream_index}.frame{idx:06d}.pt" - ) + raise NotImplementedError("Override in child classes") def get_frame_data_by_range( self, @@ -202,6 +205,17 @@ class TestVideoStreamInfo: class TestVideo(TestContainerFile): stream_infos: Dict[int, TestVideoStreamInfo] + def get_frame_data_by_index( + self, idx: int, *, stream_index: Optional[int] = None + ) -> torch.Tensor: + if stream_index is None: + stream_index = self.default_stream_index + + file_path = _get_file_path( + f"{self.filename}.stream{stream_index}.frame{idx:06d}.pt" + ) + return torch.load(file_path, weights_only=True).permute(2, 0, 1) + @property def width(self) -> int: return self.stream_infos[self.default_stream_index].width @@ -298,10 +312,68 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: }, ) -# When we start actually decoding audio-only files, we'll probably need to define -# a TestAudio class with audio specific values. Until then, we only need a filename. -NASA_AUDIO = TestContainerFile( - filename="nasa_13013.mp4.audio.mp3", default_stream_index=0, frames={} + +@dataclass +class TestAudioStreamInfo: + frame_rate: float + + +@dataclass +class TestAudio(TestContainerFile): + + stream_infos: Dict[int, TestAudioStreamInfo] + _reference_frames: tuple[torch.Tensor, ...] = tuple() + + # Storing each individual frame is too expensive for audio, because there's + # a massive overhead in the binary format saved by pytorch. Saving all the + # frames in a single file uses 1.6MB while saving all frames in individual + # files uses 302MB (yes). + # So we store the reference frames in a single file, and load/cache those + # when the TestAudio instance is created. + def __post_init__(self): + # We hard-code the default stream index, see TODO below. 
+ file_path = _get_file_path( + f"{self.filename}.stream{self.default_stream_index}.all_frames.pt" + ) + t = torch.load(file_path, weights_only=True) + + # These are hard-coded values assuming stream 4 of nasa_13013.mp4. Each + # of the 204 frames contains 1024 samples. + # TODO make this more generic + assert t.shape == (2, 204 * 1024) + self._reference_frames = torch.chunk(t, chunks=204, dim=1) + + def get_frame_data_by_index( + self, idx: int, *, stream_index: Optional[int] = None + ) -> torch.Tensor: + if stream_index is not None and stream_index != self.default_stream_index: + # TODO address this, the fix should be to let _reference_frames be a + # dict[tuple[torch.Tensor]] where keys are stream indices, and load + # all of those indices in __post_init__. + raise ValueError( + "Can only use default stream index with TestAudio for now." + ) + + return self._reference_frames[idx] + + def pts_to_frame_index(self, pts_seconds: float) -> int: + # These are hard-coded values assuming stream 4 of nasa_13013.mp4. Each + # of the 204 frames contains 1024 samples. + # TODO make this more generic + frame_duration_seconds = 0.064 + return int(pts_seconds // frame_duration_seconds) + + # TODO: this shouldn't be named chw. Also values are hard-coded + @property + def empty_chw_tensor(self) -> torch.Tensor: + return torch.empty([0, 2, 1024], dtype=torch.float32) + + +NASA_AUDIO = TestAudio( + filename="nasa_13013.mp4", + default_stream_index=4, + frames={}, # TODO + stream_infos={4: TestAudioStreamInfo(frame_rate=16_000)}, ) H265_VIDEO = TestVideo(