Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions benchmarks/decoders/benchmark_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import abc
import argparse
import importlib
import json
import os
import timeit

Expand All @@ -17,7 +18,10 @@
from torchcodec.decoders._core import (
add_video_stream,
create_from_file,
get_frames_at_indices,
get_json_metadata,
get_next_frame,
scan_all_streams_to_update_metadata,
seek_to_pts,
)

Expand Down Expand Up @@ -143,6 +147,39 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
return frames


class TorchCodecDecoderNonCompiledBatch(AbstractDecoder):
def __init__(self, num_threads=None):
self._print_each_iteration_time = False
self._num_threads = num_threads

def get_frames_from_video(self, video_file, pts_list):
decoder = create_from_file(video_file)
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder, num_threads=self._num_threads)
metadata = json.loads(get_json_metadata(decoder))
average_fps = metadata["averageFps"]
best_video_stream = metadata["bestVideoStreamIndex"]
indexes_list = [int(pts * average_fps) for pts in pts_list]
frames = []
frames = get_frames_at_indices(
decoder, stream_index=best_video_stream, frame_indices=indexes_list
)
return frames

def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
decoder = create_from_file(video_file)
scan_all_streams_to_update_metadata(decoder)
add_video_stream(decoder, num_threads=self._num_threads)
metadata = json.loads(get_json_metadata(decoder))
best_video_stream = metadata["bestVideoStreamIndex"]
frames = []
indices_list = list(range(numFramesToDecode))
frames = get_frames_at_indices(
decoder, stream_index=best_video_stream, frame_indices=indices_list
)
return frames


@torch.compile(fullgraph=True, backend="eager")
def compiled_seek_and_next(decoder, pts):
seek_to_pts(decoder, pts)
Expand Down Expand Up @@ -257,9 +294,9 @@ def main() -> None:
)
parser.add_argument(
"--decoders",
help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec",
help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec. torchcodec_batch means torchcodec using batch methods.",
type=str,
default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled",
default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled,torchcodec_batch",
)

args = parser.parse_args()
Expand Down Expand Up @@ -291,6 +328,10 @@ def main() -> None:
)
if "torchaudio" in decoders:
decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
if "torchcodec_batch" in decoders:
decoder_dict["TorchCodecDecoderNonCompiledBatch"] = (
TorchCodecDecoderNonCompiledBatch()
)

decoder_dict["TVNewAPIDecoderWithBackendVideoReader"]

Expand Down
Loading