Skip to content

Conversation

Dan-Flores
Copy link
Contributor

@Dan-Flores Dan-Flores commented Aug 1, 2025

This PR enables custom_frame_mappings to be used in the Python VideoDecoder.

  • Implements read_custom_frame_mappings(). This parses a JSON str or JSON file to extract all_frames, is_key_frame, duration.
  • Updates the seek_mode if custom_frame_mappings is passed in to avoid the exact mode scan during initialization.
  • Tests are added to check that the frames are being decoded correctly.

Benchmarking

I wrote a short benchmarking script to test the initialization times of custom_frame_mappings versus exact.

custom_frame_mapping mode was quicker than exact mode if the ffprobe command used to generate the frame mapping json used the -show_entries option to reduce the JSON size.

With this optimization the performance improvement increased with longer videos.
Without this optimization, using the full ffprobe output, exact mode was faster.

The results on nasa_13013.mp4:

exact:
med = 3.38ms +- 2.54
custom_frame_mappings:
med = 3.02ms +- 0.86

The results on a generated video, mandelbrot_1920x1080_120s.mp4:

exact:
med = 29.51ms +- 7.56
custom_frame_mappings:
med = 16.69ms +- 9.38

The benchmarking code:

import subprocess
import torch
from time import perf_counter_ns

from torchcodec.decoders._video_decoder import VideoDecoder


def bench(f, *args, num_exp=100, warmup=0, **kwargs):

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()

def report_stats(times, unit="ms"):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}")
    return med

def main() -> None:
    """Benchmarks the init of VideoDecoder with different seek_modes"""
    resources_dir = "/Users/danielflores3/torchcodec/test/resources/"

    # video=resources_dir+"nasa_13013.mp4"
    video=resources_dir+"mandelbrot_1920x1080_120s.mp4"

    mappings_json = subprocess.run(
            [
                "ffprobe",
                "-i",
                f"{video}",
                "-select_streams",
                "0",
                "-show_frames",
                "-show_entries", "frame=pts,key_frame,duration",
                "-of",
                "json",
            ],
            check=True,
            capture_output=True,
            text=True,
        ).stdout
    
    print("exact:")
    report_stats(bench(VideoDecoder, source=video, seek_mode="exact", stream_index=0))

    print("custom_frame_mappings:")
    report_stats(bench(VideoDecoder, source=video, custom_frame_mappings=mappings_json, stream_index=0))


if __name__ == "__main__":
    main()
</details>

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 1, 2025
@NicolasHug
Copy link
Contributor

Thanks for the benchmarks! That's quite interesting. It'd be interesting to use the core API instead of the VideoDecoder in the benchmarks, so that we can measure the time for add_video_stream() specifically when passing the frame mapping as tensors.

If that's faster than add_video_stream(seek_mode="exact"), as we would expect, then it may be that what's slowing down the frame-mappings is the io or the json parsing parts?

@Dan-Flores
Copy link
Contributor Author

Dan-Flores commented Aug 1, 2025

... then it may be that what's slowing down the frame-mappings is the io or the json parsing parts?

I believe this is it - I updated the PR description with my updated benchmarking script and results. By reducing the JSON size to contain only the necessary information, the performance of custom_frame_mapping improves significantly.

Edit: Here's my alternative code for benchmarking the core API by passing the frame mapping as tensors. It shows a similar performance improvement to reducing the JSON, so this could be another viable design approach.

import subprocess
import torch
from time import perf_counter_ns

from torchcodec import _core
from torchcodec.decoders._video_decoder import VideoDecoder, read_custom_frame_mappings


def bench(f, *args, num_exp=100, warmup=0, **kwargs):

    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()

def report_stats(times, unit="ms"):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}")
    return med

def main() -> None:
    """Benchmarks the init of VideoDecoder with different seek_modes"""
    resources_dir = "/Users/danielflores3/torchcodec/test/resources/"

    # video=resources_dir+"nasa_13013.mp4"
    video=resources_dir+"mandelbrot_1920x1080_120s.mp4"

    mappings_json = subprocess.run(
            [
                "ffprobe",
                "-i",
                f"{video}",
                "-select_streams",
                "0",
                "-show_frames",
                "-show_entries", "frame=pts,key_frame,duration",
                "-of",
                "json",
            ],
            check=True,
            capture_output=True,
            text=True,
        ).stdout
    
    # Get tensors of each, Pass into core function directly
    custom_frame_mappings_data = read_custom_frame_mappings(mappings_json)
    args = [("exact", None), ("custom_frame_mappings", custom_frame_mappings_data)]
    for seek_mode, frame_mappings_data in args:
        # benchmark speed up in add_video_stream
        print(f"{seek_mode=}")
        report_stats(bench(init_add_stream, 
            video,
            seek_mode,
            frame_mappings_data,
        ))

def init_add_stream(video, seek_mode, frame_mappings_data):
    decoder = _core.create_from_file(str(video), seek_mode)
    _core.add_video_stream( 
            decoder,
            stream_index=0,
            dimension_order="NCHW",
            num_threads=1,
            device="cpu",
            custom_frame_mappings=frame_mappings_data,
        )

if __name__ == "__main__":
    main()

@Dan-Flores Dan-Flores force-pushed the init_with_frame_mappings branch from 897e15d to 0d661f5 Compare August 1, 2025 21:21
@Dan-Flores Dan-Flores force-pushed the init_with_frame_mappings branch from e220d14 to be9466b Compare August 20, 2025 12:11
@Dan-Flores Dan-Flores changed the title [wip] Update VideoDecoder init Add custom_frame_mappings to VideoDecoder init Aug 20, 2025
@pytest.mark.skipif(
get_ffmpeg_major_version() in (4, 5),
reason="ffprobe isn't accurate on ffmpeg 4 and 5",
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ffprobe does include the fields pts and duration in older versions, using the name pkt_pts and pkt_duration, so we can enable these tests by falling back to the pts... keys.

@Dan-Flores Dan-Flores marked this pull request as ready for review August 20, 2025 16:12
custom_frame_mappings_data = (
read_custom_frame_mappings(custom_frame_mappings)
if custom_frame_mappings is not None
else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ever valid for seek_mode == "custom_frame_mappings" and custom_frame_mappings_data is None to both be true? If no, then we should probably raise an error here.

Actually, shouldn't line 93 mean that custom_frame_mappings cannot be None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, that case would not be valid.
Currently, creating a VideoDecoder with seek_mode = custom_frame_mappings and custom_frame_mappings=None will throw an error in add_video_stream, which is called on line 121.

Do you think it would be helpful to error sooner?

except json.JSONDecodeError:
raise ValueError(
"Invalid custom frame mappings. "
"It should be a valid JSON string or a JSON file object."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our type annotation, we only say we accept string-like objects, not file-like objects. In other parts of the code, I've indicated a file-like object with io.RawIOBase and io.BufferedReader: https://github.com/pytorch/torchcodec/blob/c3eea9f9f42d3d51f9c53ba19be6c11cc88c21dd/src/torchcodec/decoders/_video_decoder.py#L74-L77

For our doc string, however, we just say "file-like object" since those actual types are probably meaningless to most users: https://github.com/pytorch/torchcodec/blob/c3eea9f9f42d3d51f9c53ba19be6c11cc88c21dd/src/torchcodec/decoders/_video_decoder.py#L26

Which introduces a good point: since this a top-level API that we expect users to call, we need a docstring. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the references, I'll add a docstring and use the term file-like object for uniformity.

device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
seek_mode: Literal["exact", "approximate", "custom_frame_mappings"] = "exact",
custom_frame_mappings: Optional[Union[bytes, bytearray, str]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two values should also be explained in the docstring.

"It should be a valid JSON string or a JSON file object."
)
# These keys are prefixed with "pkt_" in ffmpeg 4 and ffmpeg 5
pts_key = "pkt_pts" if "pts" not in input_data["frames"][0] else "pts"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, but it would be helpful to users to do explicit error checks on our assumptions about the data:

  1. Assert that input_data is non-empty.
  2. Assert that "frames" is a valid key.
  3. Assert that one of the valid pts and duration keys exist. (Rather than assuming it must be the other one.)

If any of the above isn't true, we'll still fail, but it will just be with KeyError exceptions that the cause of may not be obvious to users.

stream_index=0,
device=device,
custom_frame_mappings=f.read(),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the comments I left in read_custom_frame_mappings(), I think we should also test for valid JSON that does not follow the schema that FFmpeg produces.

test/utils.py Outdated
def create_custom_frame_mappings(self, stream_index: int) -> None:
result = json.loads(self.generate_custom_frame_mappings(stream_index))
# These keys are prefixed with "pkt_" in ffmpeg 4 and ffmpeg 5
pts_key = "pkt_pts" if "pts" not in result["frames"][0] else "pts"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's simpler to just call read_custom_frame_mappings() here. I believe it's accomplishing the same thing, and as that function evolves, so to will this utils function need to evolve.

@Dan-Flores Dan-Flores marked this pull request as draft August 21, 2025 14:32
},
...
]
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also say something about the relationship with seek_mode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also add that in order to accomodate for different FFmpeg versions we also allow pkt_pts and pkt_duration ?

def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
invalid_json_path = tmp_path / "invalid_json"
with open(invalid_json_path, "w+") as f:
f.write("""'{"invalid": "json"'""")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this is not obviously invalid - it's invalid because it's missing the closing }, but I had to read really closely to catch that! Part of the reason I had difficulty is that there's so much quoting going on. I think it might be better to do something like:

  f.write("garbage input")

That is, something that doesn't even look like JSON at a glance.

@Dan-Flores Dan-Flores marked this pull request as ready for review August 22, 2025 14:20
Copy link
Contributor

@scotts scotts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! There's a few small things we can address before merging.

)


def read_custom_frame_mappings(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit

Suggested change
def read_custom_frame_mappings(
def _read_custom_frame_mappings(

This makes it more obvious that it is private. It is already private even without the leading underscore, because it's only exposed within _video_decoder.py which itself has the underscore. But explicitly having an underscore makes it even more obvious, and consistent with the other private functions defined in this file

num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
seek_mode: Literal["exact", "approximate", "custom_frame_mappings"] = "exact",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC we decided not to expose custom_frame_mappings at the public level, and just force users to rely on the default seek_mode (approximate) if they pass custom_frames_mapping? That's consistent with the existing (correct) check below:

        if custom_frame_mappings is not None and seek_mode == "approximate":
            raise ValueError(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe @NicolasHug is right here. I completely forgot about that.


if not pts_key or not duration_key or not key_frame_present:
raise ValueError(
"Invalid custom frame mappings. The 'pts', 'duration', and 'key_frame' keys are required in the frame metadata."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably want to indicate that pkt_pts and pkt_duration are also OK?

(float(frame[pts_key]), frame["key_frame"], float(frame[duration_key]))
for frame in input_data["frames"]
]
all_frames, is_key_frame, duration = map(Tensor, zip(*frame_data))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be torch.tensor, not Tensor which is a type

Also in general, list comprehensions are preferred to map (but we can keep it as-is)

Comment on lines 470 to 472
assert (
len(all_frames) == len(is_key_frame) == len(duration)
), "Mismatched lengths in frame index data"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For user-facing assertions we prefer if + ValueError rather than assert which will lead to an AssertionError. The ValueError is more explicitly about bad user-input.

decoder = VideoDecoder(asset.path)
decoder.get_frame_at(10)

def setup_frame_mappings(tmp_path: str, file: bool, stream_index: int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit tmp_path isn't a str, it's Path, otherwise we wouldn't be able to apply / to it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point - since type annotations are not used elsewhere in this file I will remove them from this function.

stream_index=stream_index,
device=device,
custom_frame_mappings=custom_frame_mappings,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: everything below can probably be out of the context manager. It's preferable to end the CM's scope as soon as it's not needed.

@Dan-Flores Dan-Flores merged commit f9d80d4 into meta-pytorch:main Aug 25, 2025
47 checks passed
@Dan-Flores Dan-Flores deleted the init_with_frame_mappings branch August 25, 2025 14:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants