From 9d441c33f01c13c68ef1d873e56120e1b2b92dc6 Mon Sep 17 00:00:00 2001
From: Ahmad Sharif
Date: Wed, 4 Sep 2024 10:27:10 -0700
Subject: [PATCH] Add benchmark for batch decoding.

---
Review notes (free text placed here, after the "---" separator, is not part
of the commit message):
* The line structure of this patch was rebuilt from a whitespace-collapsed
  copy. Blank context lines were re-inserted to satisfy the hunk line
  counts; verify them against the target file before applying.
* The added class was revised during review, so the post-image blob id on
  the "index" line no longer matches the resulting file.

 benchmarks/decoders/benchmark_decoders.py | 57 ++++++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 2 deletions(-)

diff --git a/benchmarks/decoders/benchmark_decoders.py b/benchmarks/decoders/benchmark_decoders.py
index 9bf5733e1..fb4cf0e66 100644
--- a/benchmarks/decoders/benchmark_decoders.py
+++ b/benchmarks/decoders/benchmark_decoders.py
@@ -7,6 +7,7 @@
 import abc
 import argparse
 import importlib
+import json
 import os
 import timeit
 
@@ -17,7 +18,10 @@ from torchcodec.decoders._core import (
     add_video_stream,
     create_from_file,
+    get_frames_at_indices,
+    get_json_metadata,
     get_next_frame,
+    scan_all_streams_to_update_metadata,
     seek_to_pts,
 )
 
 
@@ -143,6 +147,51 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
         return frames
 
 
+class TorchCodecDecoderNonCompiledBatch(AbstractDecoder):
+    """Benchmark decoder that fetches frames through the batch core API.
+
+    Every call opens ``video_file`` afresh and retrieves all requested frames
+    with a single ``get_frames_at_indices`` call.
+    """
+
+    def __init__(self, num_threads=None):
+        self._print_each_iteration_time = False
+        self._num_threads = num_threads
+
+    def _open_decoder(self, video_file):
+        """Return ``(decoder, metadata)`` for ``video_file``.
+
+        Calls ``scan_all_streams_to_update_metadata``, adds a video stream
+        with the configured thread count, and parses the decoder's JSON
+        metadata with ``json.loads``.
+        """
+        decoder = create_from_file(video_file)
+        scan_all_streams_to_update_metadata(decoder)
+        add_video_stream(decoder, num_threads=self._num_threads)
+        return decoder, json.loads(get_json_metadata(decoder))
+
+    def get_frames_from_video(self, video_file, pts_list):
+        """Return the frames for ``pts_list``, fetched in one batch call."""
+        decoder, metadata = self._open_decoder(video_file)
+        average_fps = metadata["averageFps"]
+        # The batch API is index-based: map each timestamp to an index using
+        # the average frame rate (int() truncates toward zero).
+        frame_indices = [int(pts * average_fps) for pts in pts_list]
+        return get_frames_at_indices(
+            decoder,
+            stream_index=metadata["bestVideoStreamIndex"],
+            frame_indices=frame_indices,
+        )
+
+    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
+        """Return the first ``numFramesToDecode`` frames from one batch call."""
+        decoder, metadata = self._open_decoder(video_file)
+        return get_frames_at_indices(
+            decoder,
+            stream_index=metadata["bestVideoStreamIndex"],
+            frame_indices=list(range(numFramesToDecode)),
+        )
+
+
 @torch.compile(fullgraph=True, backend="eager")
 def compiled_seek_and_next(decoder, pts):
     seek_to_pts(decoder, pts)
@@ -257,9 +306,9 @@ def main() -> None:
     )
     parser.add_argument(
         "--decoders",
-        help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec",
+        help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled, torchcodec_batch. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec. torchcodec_batch means torchcodec using batch methods.",
         type=str,
-        default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled",
+        default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled,torchcodec_batch",
     )
 
     args = parser.parse_args()
@@ -291,6 +340,10 @@ def main() -> None:
         )
     if "torchaudio" in decoders:
         decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
+    if "torchcodec_batch" in decoders:
+        decoder_dict["TorchCodecDecoderNonCompiledBatch"] = (
+            TorchCodecDecoderNonCompiledBatch()
+        )
 
     decoder_dict["TVNewAPIDecoderWithBackendVideoReader"]
 