diff --git a/README.md b/README.md index 40603254e..8643e0e3a 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ ffmpeg -f lavfi -i \ ``` ## Installing TorchCodec +### Installing CPU-only TorchCodec 1. Install the latest stable version of PyTorch following the [official instructions](https://pytorch.org/get-started/locally/). For other @@ -127,9 +128,65 @@ The following table indicates the compatibility between versions of | not yet supported | `2.5` | `>=3.9`, `<=3.12` | | `0.0.3` | `2.4` | `>=3.8`, `<=3.12` | +### Installing CUDA-enabled TorchCodec + +First, make sure you have a GPU that has NVDEC hardware that can decode the +format you want. Refer to Nvidia's GPU support matrix for more details +[here](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new). + +1. Install CUDA Toolkit. PyTorch and TorchCodec support CUDA Toolkit + versions 11.8, 12.1 or 12.4. In particular TorchCodec depends on + CUDA libraries libnpp and libnvrtc (which are part of CUDA Toolkit). + +2. Install PyTorch that corresponds to your CUDA Toolkit version using the + [official instructions](https://pytorch.org/get-started/locally/). + +3. Install or compile FFmpeg with NVDEC support. + TorchCodec with CUDA should work with FFmpeg versions in [4, 7]. + + If FFmpeg is not already installed, or you need a more recent version, an + easy way to install it is to use `conda`: + + ```bash + conda install ffmpeg + # or + conda install ffmpeg -c conda-forge + ``` + + If you are building FFmpeg from source you can follow Nvidia's guide to + configuring and installing FFmpeg with NVDEC support + [here](https://docs.nvidia.com/video-technologies/video-codec-sdk/12.0/ffmpeg-with-nvidia-gpu/index.html). + + After installing FFmpeg make sure it has NVDEC support when you list the supported + decoders: + + ```bash + ffmpeg -decoders | grep -i nvidia + # This should show a line like this: + # V..... 
h264_cuvid Nvidia CUVID H264 decoder (codec h264) + ``` + + To check that FFmpeg libraries work with NVDEC correctly you can decode a sample video: + + ```bash + ffmpeg -hwaccel cuda -hwaccel_output_format cuda -i test/resources/nasa_13013.mp4 -f null - + ``` + +4. Install TorchCodec by passing in an `--index-url` parameter that corresponds to your CUDA + Toolkit version, for example: + + ```bash + # This corresponds to CUDA Toolkit version 12.4 and nightly PyTorch. + pip install torchcodec --index-url=https://download.pytorch.org/whl/nightly/cu124 + ``` + + Note that without passing in the `--index-url` parameter, `pip` installs TorchCodec + binaries from PyPI, which are CPU-only and do not have CUDA support. + ## Benchmark Results -The following was generated by running [our benchmark script](./benchmarks/decoders/generate_readme_data.py) on a lightly loaded 56-core machine. +The following was generated by running [our benchmark script](./benchmarks/decoders/generate_readme_data.py) on a lightly loaded 22-core machine with an Nvidia A100 with +5 [NVDEC decoders](https://docs.nvidia.com/video-technologies/video-codec-sdk/12.1/nvdec-application-note/index.html#). ![benchmark_results](./benchmarks/decoders/benchmark_readme_chart.png) @@ -156,3 +213,10 @@ guide](CONTRIBUTING.md) for more details. ## License TorchCodec is released under the [BSD 3 license](./LICENSE). + +However, TorchCodec may be used with code not written by Meta which may be +distributed under different licenses. + +For example, if you build TorchCodec with ENABLE_CUDA=1 or use the CUDA-enabled +release of torchcodec, please review CUDA's license here: +[Nvidia licenses](https://docs.nvidia.com/cuda/eula/index.html). 
diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index 33d25c637..a5102fd32 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -38,6 +38,14 @@ def __init__(self): def get_frames_from_video(self, video_file, pts_list): pass + @abc.abstractmethod + def get_consecutive_frames_from_video(self, video_file, numFramesToDecode): + pass + + @abc.abstractmethod + def decode_and_transform(self, video_file, pts_list, height, width, device): + pass + class DecordAccurate(AbstractDecoder): def __init__(self): @@ -89,8 +97,10 @@ def __init__(self, backend): self._backend = backend self._print_each_iteration_time = False import torchvision # noqa: F401 + from torchvision.transforms import v2 as transforms_v2 self.torchvision = torchvision + self.transforms_v2 = transforms_v2 def get_frames_from_video(self, video_file, pts_list): self.torchvision.set_video_backend(self._backend) @@ -111,6 +121,20 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode): frames.append(frame["data"].permute(1, 2, 0)) return frames + def decode_and_transform(self, video_file, pts_list, height, width, device): + self.torchvision.set_video_backend(self._backend) + reader = self.torchvision.io.VideoReader(video_file, "video") + frames = [] + for pts in pts_list: + reader.seek(pts) + frame = next(reader) + frames.append(frame["data"].permute(1, 2, 0)) + frames = [ + self.transforms_v2.functional.resize(frame.to(device), (height, width)) + for frame in frames + ] + return frames + class TorchCodecCore(AbstractDecoder): def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"): @@ -239,6 +263,10 @@ def __init__(self, num_ffmpeg_threads=None, device="cpu"): ) self._device = device + from torchvision.transforms import v2 as transforms_v2 + + self.transforms_v2 = transforms_v2 + def get_frames_from_video(self, video_file, 
pts_list): decoder = VideoDecoder( video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device @@ -258,6 +286,14 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode): break return frames + def decode_and_transform(self, video_file, pts_list, height, width, device): + decoder = VideoDecoder( + video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device + ) + frames = decoder.get_frames_played_at(pts_list) + frames = self.transforms_v2.functional.resize(frames.data, (height, width)) + return frames + @torch.compile(fullgraph=True, backend="eager") def compiled_seek_and_next(decoder, pts): @@ -299,7 +335,9 @@ def __init__(self): self.torchaudio = torchaudio - pass + from torchvision.transforms import v2 as transforms_v2 + + self.transforms_v2 = transforms_v2 def get_frames_from_video(self, video_file, pts_list): stream_reader = self.torchaudio.io.StreamReader(src=video_file) @@ -325,6 +363,21 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode): return frames + def decode_and_transform(self, video_file, pts_list, height, width, device): + stream_reader = self.torchaudio.io.StreamReader(src=video_file) + stream_reader.add_basic_video_stream(frames_per_chunk=1) + frames = [] + for pts in pts_list: + stream_reader.seek(pts) + stream_reader.fill_buffer() + clip = stream_reader.pop_chunks() + frames.append(clip[0][0]) + frames = [ + self.transforms_v2.functional.resize(frame.to(device), (height, width)) + for frame in frames + ] + return frames + def create_torchcodec_decoder_from_file(video_file): video_decoder = create_from_file(video_file) @@ -443,7 +496,7 @@ def plot_data(df_data, plot_path): # Set the title for the subplot base_video = Path(video).name.removesuffix(".mp4") - ax.set_title(f"{base_video}\n{vcount} x {vtype}", fontsize=11) + ax.set_title(f"{base_video}\n{vtype}", fontsize=11) # Plot bars with error bars ax.barh( @@ -486,6 +539,14 @@ class BatchParameters: batch_size: int 
+@dataclass +class DataLoaderInspiredWorkloadParameters: + batch_parameters: BatchParameters + resize_height: int + resize_width: int + resize_device: str + + def run_batch_using_threads( function, *args, @@ -525,6 +586,7 @@ def run_benchmarks( num_sequential_frames_from_start: list[int], min_runtime_seconds: float, benchmark_video_creation: bool, + dataloader_parameters: DataLoaderInspiredWorkloadParameters = None, batch_parameters: BatchParameters = None, ) -> list[dict[str, str | float | int]]: # Ensure that we have the same seed across benchmark runs. @@ -550,6 +612,39 @@ def run_benchmarks( for decoder_name, decoder in decoder_dict.items(): print(f"video={video_file_path}, decoder={decoder_name}") + if dataloader_parameters: + bp = dataloader_parameters.batch_parameters + dataloader_result = benchmark.Timer( + stmt="run_batch_using_threads(decoder.decode_and_transform, video_file, pts_list, height, width, device, batch_parameters=batch_parameters)", + globals={ + "video_file": str(video_file_path), + "pts_list": uniform_pts_list, + "decoder": decoder, + "run_batch_using_threads": run_batch_using_threads, + "batch_parameters": dataloader_parameters.batch_parameters, + "height": dataloader_parameters.resize_height, + "width": dataloader_parameters.resize_width, + "device": dataloader_parameters.resize_device, + }, + label=f"video={video_file_path} {metadata_label}", + sub_label=decoder_name, + description=f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} decode_and_transform()", + ) + results.append( + dataloader_result.blocked_autorange( + min_run_time=min_runtime_seconds + ) + ) + df_data.append( + convert_result_to_df_item( + results[-1], + decoder_name, + video_file_path, + num_samples * dataloader_parameters.batch_parameters.batch_size, + f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} x decode_and_transform()", + ) + ) + for kind, pts_list in [ ("uniform", uniform_pts_list), ("random", 
random_pts_list), diff --git a/benchmarks/decoders/benchmark_readme_chart.png b/benchmarks/decoders/benchmark_readme_chart.png index 5dec9348a..f433f6199 100644 Binary files a/benchmarks/decoders/benchmark_readme_chart.png and b/benchmarks/decoders/benchmark_readme_chart.png differ diff --git a/benchmarks/decoders/benchmark_readme_data.json b/benchmarks/decoders/benchmark_readme_data.json index 28c79303d..7e4354d49 100644 --- a/benchmarks/decoders/benchmark_readme_data.json +++ b/benchmarks/decoders/benchmark_readme_data.json @@ -1,297 +1,394 @@ [ { - "decoder": "TorchCodec", + "decoder": "torchcodec", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 3.722043790887239, + "fps_p25": 3.722043790887239, + "fps_p75": 3.722043790887239, + "frame_count": 640, + "iqr": 0.0, + "median": 171.94854116626084, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" + }, + { + "decoder": "torchcodec", "description": "uniform 10 seek()+next()", - "fps_median": 2.874245162330032, - "fps_p25": 2.890669218048133, - "fps_p75": 2.862767020928218, + "fps_median": 1.4397872541120416, + "fps_p25": 1.451428122051996, + "fps_p75": 1.4351890231275515, "frame_count": 10, - "iqr": 0.03371739387512207, - "median": 3.4791743345558643, - "type": "uniform:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.07795738987624645, + "median": 6.945470569655299, + "type": "uniform seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchCodec", + "decoder": "torchcodec", "description": "random 10 seek()+next()", - "fps_median": 3.658458368739899, - "fps_p25": 3.679569429876782, - "fps_p75": 3.606357184854245, + "fps_median": 1.736215701229184, + "fps_p25": 
1.7631990988377526, + "fps_p75": 1.725260270510397, "frame_count": 10, - "iqr": 0.05517190555110574, - "median": 2.733391771093011, - "type": "random:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.12471765535883605, + "median": 5.7596530159935355, + "type": "random seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchCodec", + "decoder": "torchcodec", "description": "100 next()", - "fps_median": 248.32328677696285, - "fps_p25": 250.6501813668334, - "fps_p75": 244.31468271091225, + "fps_median": 125.18633474918164, + "fps_p25": 131.35032274770882, + "fps_p75": 121.19682688127735, "frame_count": 100, - "iqr": 0.010345779359340668, - "median": 0.402700855396688, - "type": "next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.0637812758795917, + "median": 0.7988092326559126, + "type": "100 next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" + }, + { + "decoder": "torchcodec[cuda]", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 6.71338393971077, + "fps_p25": 6.71338393971077, + "fps_p75": 6.71338393971077, + "frame_count": 640, + "iqr": 0.0, + "median": 95.33195267058909, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchVision[video_reader]", + "decoder": "torchcodec[cuda]", "description": "uniform 10 seek()+next()", - "fps_median": 0.4195638935454161, - "fps_p25": 0.42014650010734295, - "fps_p75": 0.4189829005177118, + "fps_median": 1.243389954704256, + "fps_p25": 1.2474136405315202, + "fps_p75": 1.231124865014404, "frame_count": 10, - "iqr": 
0.06610076874494553, - "median": 23.834272095002234, - "type": "uniform:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.1060659158974886, + "median": 8.042529185768217, + "type": "uniform seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchVision[video_reader]", + "decoder": "torchcodec[cuda]", "description": "random 10 seek()+next()", - "fps_median": 0.32245125406966435, - "fps_p25": 0.32245125406966435, - "fps_p75": 0.32245125406966435, + "fps_median": 1.5460842780696675, + "fps_p25": 1.5614610546171193, + "fps_p75": 1.5057133085494405, "frame_count": 10, - "iqr": 0.0, - "median": 31.01243947353214, - "type": "random:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.2371121821925044, + "median": 6.467952712439001, + "type": "random seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchVision[video_reader]", + "decoder": "torchcodec[cuda]", "description": "100 next()", - "fps_median": 176.82997455955987, - "fps_p25": 177.54379221046926, - "fps_p75": 175.14625035317215, + "fps_median": 178.9611538836387, + "fps_p25": 182.15441067650494, + "fps_p75": 124.20054668713973, "frame_count": 100, - "iqr": 0.00771009735763073, - "median": 0.5655149826779962, - "type": "next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.2561646499671042, + "median": 0.5587804829701781, + "type": "100 next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" + }, + { + "decoder": "torchvision[video_reader]", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 
1.5697915959359987, + "fps_p25": 1.5697915959359987, + "fps_p75": 1.5697915959359987, + "frame_count": 640, + "iqr": 0.0, + "median": 407.69743044674397, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchAudio", + "decoder": "torchvision[video_reader]", "description": "uniform 10 seek()+next()", - "fps_median": 0.5316661236830815, - "fps_p25": 0.5318804166321828, - "fps_p75": 0.5314520033403498, + "fps_median": 0.1978081509193199, + "fps_p25": 0.1978081509193199, + "fps_p75": 0.1978081509193199, "frame_count": 10, - "iqr": 0.015156010165810585, - "median": 18.80879663862288, - "type": "uniform:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.0, + "median": 50.5540340654552, + "type": "uniform seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchAudio", + "decoder": "torchvision[video_reader]", "description": "random 10 seek()+next()", - "fps_median": 0.417209378798961, - "fps_p25": 0.41758998612516984, - "fps_p75": 0.4168294646408316, + "fps_median": 0.14806268866290584, + "fps_p25": 0.14806268866290584, + "fps_p75": 0.14806268866290584, "frame_count": 10, - "iqr": 0.04369210824370384, - "median": 23.968780444934964, - "type": "random:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.0, + "median": 67.53895995207131, + "type": "random seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "TorchAudio", + "decoder": "torchvision[video_reader]", "description": "100 next()", - "fps_median": 179.4697720392447, - "fps_p25": 181.05508626841646, - "fps_p75": 
173.49148405860208, + "fps_median": 90.71105513124185, + "fps_p25": 93.42498301579607, + "fps_p75": 66.99561256477986, "frame_count": 100, - "iqr": 0.024079074384644628, - "median": 0.5571968965232372, - "type": "next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.4222575547173619, + "median": 1.1024014642462134, + "type": "100 next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "Decord", + "decoder": "torchaudio", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 0.22995229905603712, + "fps_p25": 0.22995229905603712, + "fps_p75": 0.22995229905603712, + "frame_count": 640, + "iqr": 0.0, + "median": 2783.1859156321734, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" + }, + { + "decoder": "torchaudio", "description": "uniform 10 seek()+next()", - "fps_median": 2.9254250604823127, - "fps_p25": 2.928776037979067, - "fps_p75": 2.9179279307467434, + "fps_median": 0.2240825403534482, + "fps_p25": 0.2240825403534482, + "fps_p75": 0.2240825403534482, "frame_count": 10, - "iqr": 0.012693846598267555, - "median": 3.4183066710829735, - "type": "uniform:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.0, + "median": 44.62641303613782, + "type": "uniform seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "Decord", + "decoder": "torchaudio", "description": "random 10 seek()+next()", - "fps_median": 2.3913952159683447, - "fps_p25": 2.409423905905687, - "fps_p75": 2.379609551240287, + "fps_median": 0.17111174619117042, + "fps_p25": 0.17111174619117042, + "fps_p75": 
0.17111174619117042, "frame_count": 10, - "iqr": 0.05200037732720375, - "median": 4.181659281253815, - "type": "random:seek()+next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.0, + "median": 58.44134153611958, + "type": "random seek()+next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" }, { - "decoder": "Decord", + "decoder": "torchaudio", "description": "100 next()", - "fps_median": 301.64058598625485, - "fps_p25": 304.4257754819803, - "fps_p75": 297.8336145342091, + "fps_median": 71.79386812605445, + "fps_p25": 72.51511528517597, + "fps_p75": 71.07345912080734, "frame_count": 100, - "iqr": 0.0072706404607743025, - "median": 0.33152037439867854, - "type": "next()", - "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1280x720_120s_60fps_600gop_libx264_yuv420p.mp4" + "iqr": 0.027972140349447727, + "median": 1.3928766148164868, + "type": "100 next()", + "video": "/tmp/torchcodec_benchmarking_videos/mandelbrot_1920x1080_120s_60fps_600gop_libx264_yuv420p.mp4" + }, + { + "decoder": "torchcodec", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 67.79397429639815, + "fps_p25": 68.59028564862771, + "fps_p75": 65.71375822403033, + "frame_count": 640, + "iqr": 0.40844123042188585, + "median": 9.440367033239454, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", + "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchCodec", + "decoder": "torchcodec", "description": "uniform 10 seek()+next()", - "fps_median": 32.797487611763884, - "fps_p25": 33.16049598127707, - "fps_p75": 32.34604228526151, + "fps_median": 28.46557984604316, + "fps_p25": 29.922295359471935, + "fps_p75": 26.73851645668455, "frame_count": 10, - "iqr": 0.007593189366161823, - "median": 0.30490140337496996, - "type": 
"uniform:seek()+next()", + "iqr": 0.03979336703196168, + "median": 0.3513014684431255, + "type": "uniform seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchCodec", + "decoder": "torchcodec", "description": "random 10 seek()+next()", - "fps_median": 31.86583108909739, - "fps_p25": 32.422775044534646, - "fps_p75": 31.51582155415764, + "fps_median": 29.945412645922318, + "fps_p25": 30.85213034698643, + "fps_p75": 28.594925068542125, "frame_count": 10, - "iqr": 0.008875773288309574, - "median": 0.313815759960562, - "type": "random:seek()+next()", + "iqr": 0.025585678406059742, + "median": 0.3339409651234746, + "type": "random seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchCodec", + "decoder": "torchcodec", "description": "100 next()", - "fps_median": 478.59019817346393, - "fps_p25": 483.5137156807738, - "fps_p75": 469.99482336428133, + "fps_median": 614.0952680137761, + "fps_p25": 628.2472414155694, + "fps_p75": 593.4204378926277, "frame_count": 100, - "iqr": 0.00594893516972661, - "median": 0.20894702896475792, - "type": "next()", + "iqr": 0.00934158405289054, + "median": 0.16284118313342333, + "type": "100 next()", + "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" + }, + { + "decoder": "torchcodec[cuda]", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 14.269868091545673, + "fps_p25": 14.269868091545673, + "fps_p75": 14.269868091545673, + "frame_count": 640, + "iqr": 0.0, + "median": 44.84974884800613, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchVision[video_reader]", + "decoder": "torchcodec[cuda]", "description": "uniform 10 seek()+next()", - "fps_median": 5.863026312201653, - "fps_p25": 
5.883880399939407, - "fps_p75": 5.821213575132222, + "fps_median": 3.3887557630838336, + "fps_p25": 3.3996850012955604, + "fps_p75": 3.378710590173761, "frame_count": 10, - "iqr": 0.01829617563635111, - "median": 1.705603807233274, - "type": "uniform:seek()+next()", + "iqr": 0.018259971868246794, + "median": 2.950935593806207, + "type": "uniform seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchVision[video_reader]", + "decoder": "torchcodec[cuda]", "description": "random 10 seek()+next()", - "fps_median": 3.5804448256145283, - "fps_p25": 3.610017214846564, - "fps_p75": 3.573725565752598, + "fps_median": 4.206317288181642, + "fps_p25": 4.215251838912759, + "fps_p75": 4.200058005638581, "frame_count": 10, - "iqr": 0.028130420949310064, - "median": 2.7929490571841598, - "type": "random:seek()+next()", + "iqr": 0.008582000620663166, + "median": 2.3773765303194523, + "type": "random seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchVision[video_reader]", + "decoder": "torchcodec[cuda]", "description": "100 next()", - "fps_median": 220.60506211225706, - "fps_p25": 221.92824221952606, - "fps_p75": 219.83775682122163, + "fps_median": 108.15849644090362, + "fps_p25": 108.87629308883722, + "fps_p75": 107.1563304726156, "frame_count": 100, - "iqr": 0.004284817026928067, - "median": 0.45329875499010086, - "type": "next()", + "iqr": 0.014742388390004635, + "median": 0.92456906568259, + "type": "100 next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchAudio", + "decoder": "torchvision[video_reader]", + "description": "dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 34.07782542607838, + "fps_p25": 34.40001338466056, + "fps_p75": 33.76161664373731, + "frame_count": 640, + "iqr": 0.35179429268464446, + "median": 18.780541070271283, + "type": 
"dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", + "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" + }, + { + "decoder": "torchvision[video_reader]", "description": "uniform 10 seek()+next()", - "fps_median": 10.562701139154996, - "fps_p25": 10.594120999307123, - "fps_p75": 10.475383401305544, + "fps_median": 4.782367451226253, + "fps_p25": 4.8422619348709155, + "fps_p75": 4.588920179004961, "frame_count": 10, - "iqr": 0.010699251666665077, - "median": 0.9467275338247418, - "type": "uniform:seek()+next()", + "iqr": 0.11401132540777326, + "median": 2.091014565899968, + "type": "uniform seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchAudio", + "decoder": "torchvision[video_reader]", "description": "random 10 seek()+next()", - "fps_median": 7.143372898069971, - "fps_p25": 7.190431420792876, - "fps_p75": 6.984323268168379, + "fps_median": 2.9594605005777126, + "fps_p25": 3.0142969757047573, + "fps_p75": 2.939646297328533, "frame_count": 10, - "iqr": 0.041040806798264384, - "median": 1.3998989192768931, - "type": "random:seek()+next()", + "iqr": 0.08424665033817291, + "median": 3.37899424508214, + "type": "random seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "TorchAudio", + "decoder": "torchvision[video_reader]", "description": "100 next()", - "fps_median": 234.31507730203276, - "fps_p25": 235.5241203289182, - "fps_p75": 233.2609776710573, + "fps_median": 195.10398914661542, + "fps_p25": 198.12837004507918, + "fps_p75": 179.9937080258182, "frame_count": 100, - "iqr": 0.004119404591619968, - "median": 0.42677578050643206, - "type": "next()", + "iqr": 0.05085169989615679, + "median": 0.5125471828505397, + "type": "100 next()", + "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" + }, + { + "decoder": "torchaudio", + "description": 
"dataloader[threads=8 batch_size=64] 10 decode_and_transform()", + "fps_median": 9.635713291548546, + "fps_p25": 9.635713291548546, + "fps_p75": 9.635713291548546, + "frame_count": 640, + "iqr": 0.0, + "median": 66.41957690473646, + "type": "dataloader[threads=8 batch_size=64] 10 x decode_and_transform()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "Decord", + "decoder": "torchaudio", "description": "uniform 10 seek()+next()", - "fps_median": 28.951774908715404, - "fps_p25": 29.60339324217526, - "fps_p75": 28.130013392739134, + "fps_median": 9.561191528064604, + "fps_p25": 9.803393073168538, + "fps_p75": 9.349710586702399, "frame_count": 10, - "iqr": 0.017693073954433203, - "median": 0.3454019669443369, - "type": "uniform:seek()+next()", + "iqr": 0.049496835097670555, + "median": 1.0458947475999594, + "type": "uniform seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "Decord", + "decoder": "torchaudio", "description": "random 10 seek()+next()", - "fps_median": 23.55224702573544, - "fps_p25": 23.829998507803513, - "fps_p75": 23.327513783118945, + "fps_median": 6.485374499539733, + "fps_p25": 6.651050187497567, + "fps_p75": 6.3834662410329805, "frame_count": 10, - "iqr": 0.00903920829296112, - "median": 0.424587938003242, - "type": "random:seek()+next()", + "iqr": 0.06302505941130221, + "median": 1.5419310019351542, + "type": "random seek()+next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "decoder": "Decord", + "decoder": "torchaudio", "description": "100 next()", - "fps_median": 526.5873947695661, - "fps_p25": 534.5632246100912, - "fps_p75": 514.0865116923063, + "fps_median": 258.3824937093069, + "fps_p25": 265.91286414117894, + "fps_p75": 250.26730738957093, "frame_count": 100, - "iqr": 0.007451178273186088, - "median": 0.18990200106054544, - "type": "next()", + "iqr": 0.02350972522981465, + 
"median": 0.38702312437817454, + "type": "100 next()", "video": "/tmp/torchcodec_benchmarking_videos/nasa_960x540_206s_30fps_yuv420p.mp4" }, { - "cpu_count": 56, + "cpu_count": 22, + "is_cuda_available": "True", "machine": "x86_64", "processor": "x86_64", - "python_version": "3.11.10", + "python_version": "3.12.5", "system": "Linux" } ] diff --git a/benchmarks/decoders/generate_readme_data.py b/benchmarks/decoders/generate_readme_data.py index 277cd3752..3dee39130 100644 --- a/benchmarks/decoders/generate_readme_data.py +++ b/benchmarks/decoders/generate_readme_data.py @@ -10,8 +10,11 @@ import shutil from pathlib import Path +import torch + from benchmark_decoders_library import ( - DecordAccurateBatch, + BatchParameters, + DataLoaderInspiredWorkloadParameters, generate_videos, retrieve_videos, run_benchmarks, @@ -27,43 +30,45 @@ def main() -> None: """Benchmarks the performance of a few video decoders on synthetic videos""" videos_dir_path = "/tmp/torchcodec_benchmarking_videos" - shutil.rmtree(videos_dir_path, ignore_errors=True) - os.makedirs(videos_dir_path) + if not os.path.exists(videos_dir_path): + shutil.rmtree(videos_dir_path, ignore_errors=True) + os.makedirs(videos_dir_path) - resolutions = ["1280x720"] - encodings = ["libx264"] - patterns = ["mandelbrot"] - fpses = [60] - gop_sizes = [600] - durations = [120] - pix_fmts = ["yuv420p"] - ffmpeg_path = "ffmpeg" - generate_videos( - resolutions, - encodings, - patterns, - fpses, - gop_sizes, - durations, - pix_fmts, - ffmpeg_path, - videos_dir_path, - ) + resolutions = ["1920x1080"] + encodings = ["libx264"] + patterns = ["mandelbrot"] + fpses = [60] + gop_sizes = [600] + durations = [120] + pix_fmts = ["yuv420p"] + ffmpeg_path = "ffmpeg" + generate_videos( + resolutions, + encodings, + patterns, + fpses, + gop_sizes, + durations, + pix_fmts, + ffmpeg_path, + videos_dir_path, + ) - urls_and_dest_paths = [ - (NASA_URL, f"{videos_dir_path}/nasa_960x540_206s_30fps_yuv420p.mp4") - ] - 
retrieve_videos(urls_and_dest_paths) + urls_and_dest_paths = [ + (NASA_URL, f"{videos_dir_path}/nasa_960x540_206s_30fps_yuv420p.mp4") + ] + retrieve_videos(urls_and_dest_paths) decoder_dict = {} - decoder_dict["TorchCodec"] = TorchCodecPublic() - decoder_dict["TorchVision[video_reader]"] = TorchVision("video_reader") - decoder_dict["TorchAudio"] = TorchAudioDecoder() - decoder_dict["Decord"] = DecordAccurateBatch() + decoder_dict["torchcodec"] = TorchCodecPublic() + decoder_dict["torchcodec[cuda]"] = TorchCodecPublic(device="cuda") + decoder_dict["torchvision[video_reader]"] = TorchVision("video_reader") + decoder_dict["torchaudio"] = TorchAudioDecoder() # These are the number of uniform seeks we do in the seek+decode benchmark. num_samples = 10 video_files_paths = list(Path(videos_dir_path).glob("*.mp4")) + assert len(video_files_paths) == 2, "Expected exactly 2 videos" df_data = run_benchmarks( decoder_dict, video_files_paths, @@ -71,6 +76,12 @@ def main() -> None: num_sequential_frames_from_start=[100], min_runtime_seconds=30, benchmark_video_creation=False, + dataloader_parameters=DataLoaderInspiredWorkloadParameters( + batch_parameters=BatchParameters(batch_size=64, num_threads=8), + resize_height=256, + resize_width=256, + resize_device="cuda", + ), ) df_data.append( { @@ -79,6 +90,7 @@ def main() -> None: "machine": platform.machine(), "processor": platform.processor(), "python_version": str(platform.python_version()), + "is_cuda_available": str(torch.cuda.is_available()), } ) diff --git a/examples/basic_cuda_example.py b/examples/basic_cuda_example.py index 5ff85e8e0..7c29e4475 100644 --- a/examples/basic_cuda_example.py +++ b/examples/basic_cuda_example.py @@ -18,6 +18,9 @@ running the transform steps. Encoded packets are often much smaller than decoded frames so CUDA decoding also uses less PCI-e bandwidth. 
+When to and when not to use CUDA Decoding +----------------------------------------- + CUDA Decoding can offer speed-up over CPU Decoding in a few scenarios: #. You are decoding a large resolution video @@ -37,28 +40,10 @@ TorchCodec you can simply pass in a device parameter to the :class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding. +Installing TorchCodec with CUDA Enabled +--------------------------------------- -In order to use CUDA Decoding will need the following installed in your environment: - -#. An Nvidia GPU that supports decoding the video format you want to decode. See - the support matrix `here `_ -#. `CUDA-enabled pytorch `_ -#. FFmpeg binaries that support - `NVDEC-enabled `_ - codecs -#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit) - - -FFmpeg versions 5, 6 and 7 from conda-forge are built with -`NVDEC support `_ -and you can install them with conda. For example, to install FFmpeg version 7: - - -.. code-block:: bash - - conda install ffmpeg=7 -c conda-forge - conda install libnpp cuda-nvrtc -c nvidia - +Refer to the installation guide in the `README `_. """ diff --git a/setup.py b/setup.py index 7588c000f..9120c7fe0 100644 --- a/setup.py +++ b/setup.py @@ -196,6 +196,7 @@ def _write_version_files(): print("INFO: Didn't find sha. Is this a git repo?") with open(_ROOT_DIR / "src/torchcodec/version.py", "w") as f: + f.write("# Note that this file is generated during install.\n") f.write(f"__version__ = '{version}'\n") diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index 2b83eb035..e96aa5b33 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -10,6 +10,7 @@ from . import decoders, samplers # noqa try: + # Note that version.py is generated during install. 
from .version import __version__ # noqa: F401 except Exception: pass diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 8aa464e4e..75dea7d8f 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -223,13 +223,24 @@ void convertAVFrameToDecodedOutputOnCuda( Npp8u* input[2] = {src->data[0], src->data[1]}; auto start = std::chrono::high_resolution_clock::now(); - NppStatus status = nppiNV12ToRGB_8u_P2C3R( - input, - src->linesize[0], - static_cast<Npp8u*>(dst.data_ptr()), - dst.stride(0), - oSizeROI); + NppStatus status; + if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) { + status = nppiNV12ToRGB_709HDTV_8u_P2C3R( + input, + src->linesize[0], + static_cast<Npp8u*>(dst.data_ptr()), + dst.stride(0), + oSizeROI); + } else { + status = nppiNV12ToRGB_8u_P2C3R( + input, + src->linesize[0], + static_cast<Npp8u*>(dst.data_ptr()), + dst.stride(0), + oSizeROI); + } TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + // Make the pytorch stream wait for the npp kernel to finish before using the // output. at::cuda::CUDAEvent nppDoneEvent; diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index eb3b55db6..ec56d9fe9 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -118,8 +118,10 @@ def test_get_frame_at_pts(self, device): # return the next frame since the right boundary of the interval is # open. next_frame, _, _ = get_frame_at_pts(decoder, 6.039367) - with pytest.raises(AssertionError): - frame_compare_function(next_frame, reference_frame6.to(device)) + if device == "cpu": + # We can only compare exact equality on CPU.
+ with pytest.raises(AssertionError): + frame_compare_function(next_frame, reference_frame6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_index(self, device): diff --git a/test/utils.py b/test/utils.py index 89957da7a..0fa843851 100644 --- a/test/utils.py +++ b/test/utils.py @@ -44,7 +44,7 @@ def assert_tensor_equal(*args, **kwargs): # Asserts that at least `percentage`% of the values are within the absolute tolerance. def assert_tensor_close_on_at_least( - actual_tensor, ref_tensor, percentage=90, abs_tolerance=20 + actual_tensor, ref_tensor, percentage=90, abs_tolerance=19 ): assert ( actual_tensor.device == ref_tensor.device