Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/torchcodec/_core/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ std::optional<double> StreamMetadata::getDurationSeconds(
return static_cast<double>(numFramesFromHeader.value()) /
averageFpsFromHeader.value();
}
if (durationSecondsFromContainer.has_value()) {
return durationSecondsFromContainer.value();
}
return std::nullopt;
default:
TORCH_CHECK(false, "Unknown SeekMode");
Expand Down Expand Up @@ -80,13 +83,13 @@ std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
numFramesFromContent.has_value(), "Missing numFramesFromContent");
return numFramesFromContent.value();
case SeekMode::approximate: {
auto durationSeconds = getDurationSeconds(seekMode);
if (numFramesFromHeader.has_value()) {
return numFramesFromHeader.value();
}
if (averageFpsFromHeader.has_value() &&
durationSecondsFromHeader.has_value()) {
if (averageFpsFromHeader.has_value() && durationSeconds.has_value()) {
return static_cast<int64_t>(
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
averageFpsFromHeader.value() * durationSeconds.value());
}
return std::nullopt;
}
Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/_core/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ struct StreamMetadata {
std::optional<double> averageFpsFromHeader;
std::optional<double> bitRate;

// Used as fallback in approximate mode when stream duration is unavailable.
std::optional<double> durationSecondsFromContainer;

// More accurate duration, obtained by scanning the file.
// These presentation timestamps are in time base.
std::optional<int64_t> beginStreamPtsFromContent;
Expand Down
43 changes: 23 additions & 20 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,26 @@ void SingleStreamDecoder::initializeDecoder() {
"Failed to find stream info: ",
getFFMPEGErrorStringFromErrorCode(status));

if (formatContext_->duration > 0) {
AVRational defaultTimeBase{1, AV_TIME_BASE};
containerMetadata_.durationSecondsFromHeader =
ptsToSeconds(formatContext_->duration, defaultTimeBase);
}

if (formatContext_->bit_rate > 0) {
containerMetadata_.bitRate = formatContext_->bit_rate;
}

int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
if (bestVideoStream >= 0) {
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
}

int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
if (bestAudioStream >= 0) {
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
}

for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
AVStream* avStream = formatContext_->streams[i];
StreamMetadata streamMetadata;
Expand Down Expand Up @@ -149,27 +169,10 @@ void SingleStreamDecoder::initializeDecoder() {
containerMetadata_.numAudioStreams++;
}

containerMetadata_.allStreamMetadata.push_back(streamMetadata);
}

if (formatContext_->duration > 0) {
AVRational defaultTimeBase{1, AV_TIME_BASE};
containerMetadata_.durationSecondsFromHeader =
ptsToSeconds(formatContext_->duration, defaultTimeBase);
}

if (formatContext_->bit_rate > 0) {
containerMetadata_.bitRate = formatContext_->bit_rate;
}

int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
if (bestVideoStream >= 0) {
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
}
streamMetadata.durationSecondsFromContainer =
containerMetadata_.durationSecondsFromHeader;

int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
if (bestAudioStream >= 0) {
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
}

if (seekMode_ == SeekMode::exact) {
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/_core/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class StreamMetadata:
from the actual frames if a :term:`scan` was performed. Otherwise we
fall back to ``duration_seconds_from_header``. If that value is also None,
we instead calculate the duration from ``num_frames_from_header`` and
``average_fps_from_header``.
``average_fps_from_header``. If all of those are unavailable, we fall back
to the container-level ``duration_seconds_from_header``.
"""
begin_stream_seconds: Optional[float]
"""Beginning of the stream, in seconds (float). Conceptually, this
Expand Down
11 changes: 0 additions & 11 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
SINE_MONO_S32,
SINE_MONO_S32_44100,
SINE_MONO_S32_8000,
supports_approximate_mode,
TEST_SRC_2_720P,
TEST_SRC_2_720P_H265,
TEST_SRC_2_720P_MPEG4,
Expand Down Expand Up @@ -1465,8 +1464,6 @@ def test_get_frames_at_tensor_indices(self):
def test_beta_cuda_interface_get_frame_at(
self, asset, contiguous_indices, seek_mode
):
if seek_mode == "approximate" and not supports_approximate_mode(asset):
pytest.skip("asset doesn't work with approximate mode")

if in_fbcode() and asset is AV1_VIDEO:
pytest.skip("AV1 CUDA not supported internally")
Expand Down Expand Up @@ -1513,8 +1510,6 @@ def test_beta_cuda_interface_get_frame_at(
def test_beta_cuda_interface_get_frames_at(
self, asset, contiguous_indices, seek_mode
):
if seek_mode == "approximate" and not supports_approximate_mode(asset):
pytest.skip("asset doesn't work with approximate mode")
if in_fbcode() and asset is AV1_VIDEO:
pytest.skip("AV1 CUDA not supported internally")

Expand Down Expand Up @@ -1558,8 +1553,6 @@ def test_beta_cuda_interface_get_frames_at(
)
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
if seek_mode == "approximate" and not supports_approximate_mode(asset):
pytest.skip("asset doesn't work with approximate mode")
if in_fbcode() and asset is AV1_VIDEO:
pytest.skip("AV1 CUDA not supported internally")

Expand Down Expand Up @@ -1600,8 +1593,6 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
)
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
if seek_mode == "approximate" and not supports_approximate_mode(asset):
pytest.skip("asset doesn't work with approximate mode")
if in_fbcode() and asset is AV1_VIDEO:
pytest.skip("AV1 CUDA not supported internally")

Expand Down Expand Up @@ -1643,8 +1634,6 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
)
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_beta_cuda_interface_backwards(self, asset, seek_mode):
if seek_mode == "approximate" and not supports_approximate_mode(asset):
pytest.skip("asset doesn't work with approximate mode")
if in_fbcode() and asset is AV1_VIDEO:
pytest.skip("AV1 CUDA not supported internally")

Expand Down
Loading