Added doc for nvdec #335

Merged
Changes from all commits (77 commits)
3ff6292
Added doc for nvdec
ahmadsharif1 1fd5a10
.
ahmadsharif1 fa3e3b9
.
ahmadsharif1 36a5420
.
ahmadsharif1 f49baca
.
ahmadsharif1 f087a91
.
ahmadsharif1 5092418
.
ahmadsharif1 243e2ca
.
ahmadsharif1 7c6c033
.
ahmadsharif1 e40ec7a
.
ahmadsharif1 bb4bff9
.
ahmadsharif1 e8a5b07
.
ahmadsharif1 c9d54a4
.
ahmadsharif1 fb633e4
.
ahmadsharif1 9e334cd
.
ahmadsharif1 c107e02
.
ahmadsharif1 885c43f
.
ahmadsharif1 dd937c6
.
ahmadsharif1 bab07db
.
ahmadsharif1 60b06e1
.
ahmadsharif1 904bfa3
.
ahmadsharif1 75e76ee
.
ahmadsharif1 16218ac
.
ahmadsharif1 e8f0128
.
ahmadsharif1 9c36f4e
.
ahmadsharif1 2406435
.
ahmadsharif1 7b78be3
.
ahmadsharif1 20c6fba
.
ahmadsharif1 7630fdd
.
ahmadsharif1 37bfa5c
.
ahmadsharif1 24f2843
.
ahmadsharif1 4cb95a2
.
ahmadsharif1 4055346
.
ahmadsharif1 63bbb9e
.
ahmadsharif1 51e2308
.
ahmadsharif1 a926934
.
ahmadsharif1 400001a
.
ahmadsharif1 ccf95da
.
ahmadsharif1 209e746
.
ahmadsharif1 8d66147
.
ahmadsharif1 0a8ae5f
.
ahmadsharif1 8864b30
.
ahmadsharif1 936cbd1
.
ahmadsharif1 49197b5
.
ahmadsharif1 8291aa6
.
ahmadsharif1 4e10d0b
.
ahmadsharif1 b90bc7f
.
ahmadsharif1 2ae49ac
.
ahmadsharif1 f0444d4
.
ahmadsharif1 3d95977
.
ahmadsharif1 5cbccd0
.
ahmadsharif1 bf81cbe
.
ahmadsharif1 0ca9469
.
ahmadsharif1 64a9ebd
.
ahmadsharif1 30d9be7
.
ahmadsharif1 c91e73c
.
ahmadsharif1 0f50210
.
ahmadsharif1 f8d5e69
.
ahmadsharif1 5a4291a
.
ahmadsharif1 af3f684
.
ahmadsharif1 891125b
.
ahmadsharif1 9809feb
.
ahmadsharif1 92e2aef
.
ahmadsharif1 8d206f4
Merge branch 'main' of https://github.com/pytorch/torchcodec into doc1
ahmadsharif1 893c490
.
ahmadsharif1 2a106ca
.
ahmadsharif1 39f4606
.
ahmadsharif1 a51dfbd
.
ahmadsharif1 f29b05c
.
ahmadsharif1 dfa9fcc
Merge branch 'main' of https://github.com/pytorch/torchcodec into doc1
ahmadsharif1 3003be2
.
ahmadsharif1 015f355
.
ahmadsharif1 bde6324
Fixed typo
ahmadsharif1 3304341
used cuda instead of cuda:0
ahmadsharif1 3d74528
addressed comments
ahmadsharif1 f79dfa4
addressed comments
ahmadsharif1 cc1c3a8
.
ahmadsharif1
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Accelerated video decoding on GPUs with CUDA and NVDEC
================================================================

TorchCodec can use supported Nvidia hardware (see the support matrix
`here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_) to speed up
video decoding. This is called "CUDA Decoding", and it uses Nvidia's
`NVDEC hardware decoder <https://developer.nvidia.com/video-codec-sdk>`_
and CUDA kernels to decompress the video and convert it to RGB, respectively.
CUDA Decoding can be faster than CPU Decoding for the actual decoding step and also for
subsequent transform steps like scaling, cropping or rotating. This is because the decode step leaves
the decoded tensor in GPU memory, so the GPU doesn't have to fetch it from main memory before
running the transform steps. Encoded packets are often much smaller than decoded frames, so
CUDA Decoding also uses less PCI-e bandwidth.
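
For instance (this is only a sketch, not part of the tutorial code below: it assumes
the ``video.mp4`` file downloaded later in this tutorial, and it uses plain
``torch.nn.functional.interpolate``, which is not a TorchCodec API), a batch of decoded
frames can be resized without ever leaving the GPU:

.. code-block:: python

    import torch
    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder("video.mp4", device="cuda")
    # NCHW uint8 batch, decoded by NVDEC and left in GPU memory
    frames = decoder.get_frames_played_at([12.0]).data
    # The follow-up transform runs directly on the GPU-resident tensor
    resized = torch.nn.functional.interpolate(
        frames.float(), size=(270, 480), mode="bilinear"
    )
    print(resized.device)  # still a CUDA device; no copy back to main memory
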
CUDA Decoding can offer a speed-up over CPU Decoding in a few scenarios:

#. You are decoding a large-resolution video
#. You are decoding a large batch of videos that is saturating the CPU
#. You want to do whole-image transforms like scaling or convolutions on the decoded tensors
   after decoding
#. Your CPU is saturated and you want to free it up for other work

Here are situations where CUDA Decoding may not make sense:

#. You want bit-exact results compared to CPU Decoding
#. You have small-resolution videos and the PCI-e transfer latency is large
#. Your GPU is already busy and your CPU is not

It's best to experiment with CUDA Decoding to see if it improves your use case. With
TorchCodec you can simply pass in a device parameter to the
:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding.

In order to use CUDA Decoding, you will need the following installed in your environment:

#. An Nvidia GPU that supports decoding the video format you want to decode. See
   the support matrix `here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_
#. `CUDA-enabled PyTorch <https://pytorch.org/get-started/locally/>`_
#. FFmpeg binaries that support
   `NVDEC-enabled <https://docs.nvidia.com/video-technologies/video-codec-sdk/12.0/ffmpeg-with-nvidia-gpu/index.html>`_
   codecs
#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit)

FFmpeg versions 5, 6 and 7 from conda-forge are built with
`NVDEC support <https://docs.nvidia.com/video-technologies/video-codec-sdk/12.0/ffmpeg-with-nvidia-gpu/index.html>`_
and you can install them with conda. For example, to install FFmpeg version 7:

.. code-block:: bash

    conda install ffmpeg=7 -c conda-forge
    conda install libnpp cuda-nvrtc -c nvidia

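As a rough sanity check of the FFmpeg side (an illustration only, not a required step:
it shells out to whatever ``ffmpeg`` binary is on your PATH, which is not necessarily
the library build that TorchCodec loads at runtime), you can list the hardware
acceleration methods your ffmpeg was built with and look for ``cuda``:

.. code-block:: python

    import subprocess

    # "-hwaccels" prints the hardware acceleration methods compiled into ffmpeg;
    # NVDEC-capable builds list "cuda" among them.
    result = subprocess.run(
        ["ffmpeg", "-hide_banner", "-hwaccels"], capture_output=True, text=True
    )
    print("cuda" in result.stdout.split())
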
""" | ||
|
||
# %%
# Checking if PyTorch has CUDA enabled
# -------------------------------------
#
# .. note::
#
#    This tutorial requires FFmpeg libraries compiled with CUDA support.
#
#
import torch

print(f"{torch.__version__=}")
print(f"{torch.cuda.is_available()=}")
print(f"{torch.cuda.get_device_properties(0)=}")


# %%
# Downloading the video
# -------------------------------------
#
# We will use the following video, which has these properties:
#
# - Codec: H.264
# - Resolution: 960x540
# - FPS: 29.97
# - Pixel format: YUV420P
#
# .. raw:: html
#
#    <video style="max-width: 100%" controls>
#      <source src="https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4" type="video/mp4">
#    </video>
import urllib.request

video_file = "video.mp4"
urllib.request.urlretrieve(
    "https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
    video_file,
)


# %%
# CUDA Decoding using VideoDecoder
# -------------------------------------
#
# To use the CUDA decoder, you need to pass a CUDA device to the decoder.
#
from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder(video_file, device="cuda")
frame = decoder[0]

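# %%
# .. note::
#
#    Based on this PR's commit history (the example was changed from ``cuda:0``
#    to plain ``cuda``), an indexed device string can also be passed to select a
#    specific GPU; treat the exact string support as an assumption, not a
#    documented guarantee. A sketch:
#
#    .. code-block:: python
#
#        decoder = VideoDecoder(video_file, device="cuda:0")
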
# %%
#
# The video frames are decoded and returned as a tensor in NCHW format.

print(frame.shape, frame.dtype)

# %%
#
# The video frames are left in GPU memory.

print(frame.data.device)


# %%
# Visualizing Frames
# -------------------------------------
#
# Let's look at the frames decoded by the CUDA decoder and compare them
# against equivalent results from the CPU decoder.
timestamps = [12, 19, 45, 131, 180]
cpu_decoder = VideoDecoder(video_file, device="cpu")
cuda_decoder = VideoDecoder(video_file, device="cuda")
cpu_frames = cpu_decoder.get_frames_played_at(timestamps).data
cuda_frames = cuda_decoder.get_frames_played_at(timestamps).data


def plot_cpu_and_cuda_frames(cpu_frames: torch.Tensor, cuda_frames: torch.Tensor):
    try:
        import matplotlib.pyplot as plt
        from torchvision.transforms.v2.functional import to_pil_image
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return
    n_rows = len(timestamps)
    fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
    for i in range(n_rows):
        axes[i][0].imshow(to_pil_image(cpu_frames[i].to("cpu")))
        axes[i][1].imshow(to_pil_image(cuda_frames[i].to("cpu")))

    axes[0][0].set_title("CPU decoder", fontsize=24)
    axes[0][1].set_title("CUDA decoder", fontsize=24)
    plt.setp(axes, xticks=[], yticks=[])
    plt.tight_layout()


plot_cpu_and_cuda_frames(cpu_frames, cuda_frames)

# %%
#
# They look visually similar to the human eye, but there may be subtle
# differences because CUDA math is not bit-exact with respect to CPU math.
#
frames_equal = torch.equal(cpu_frames.to("cuda"), cuda_frames)
mean_abs_diff = torch.mean(
    torch.abs(cpu_frames.float().to("cuda") - cuda_frames.float())
)
max_abs_diff = torch.max(torch.abs(cpu_frames.to("cuda").float() - cuda_frames.float()))
print(f"{frames_equal=}")
print(f"{mean_abs_diff=}")
print(f"{max_abs_diff=}")

Review comment on the `frames_equal` print: Nit: Instead of this binary indicator, we may want to print max abs diff and mean abs diff? We could even do it across all frames.
Review comment: Do we understand why we need CONDA_RUN here?

Reply: Everything is installed in a conda env. Without CONDA_RUN, the `pip` that's used is the outside pip.