Merged
77 commits
3ff6292
Added doc for nvdec
ahmadsharif1 Nov 5, 2024
1fd5a10
.
ahmadsharif1 Nov 5, 2024
fa3e3b9
.
ahmadsharif1 Nov 5, 2024
36a5420
.
ahmadsharif1 Nov 5, 2024
f49baca
.
ahmadsharif1 Nov 5, 2024
f087a91
.
ahmadsharif1 Nov 5, 2024
5092418
.
ahmadsharif1 Nov 5, 2024
243e2ca
.
ahmadsharif1 Nov 5, 2024
7c6c033
.
ahmadsharif1 Nov 5, 2024
e40ec7a
.
ahmadsharif1 Nov 5, 2024
bb4bff9
.
ahmadsharif1 Nov 5, 2024
e8a5b07
.
ahmadsharif1 Nov 5, 2024
c9d54a4
.
ahmadsharif1 Nov 5, 2024
fb633e4
.
ahmadsharif1 Nov 6, 2024
9e334cd
.
ahmadsharif1 Nov 6, 2024
c107e02
.
ahmadsharif1 Nov 6, 2024
885c43f
.
ahmadsharif1 Nov 6, 2024
dd937c6
.
ahmadsharif1 Nov 6, 2024
bab07db
.
ahmadsharif1 Nov 6, 2024
60b06e1
.
ahmadsharif1 Nov 6, 2024
904bfa3
.
ahmadsharif1 Nov 6, 2024
75e76ee
.
ahmadsharif1 Nov 6, 2024
16218ac
.
ahmadsharif1 Nov 6, 2024
e8f0128
.
ahmadsharif1 Nov 6, 2024
9c36f4e
.
ahmadsharif1 Nov 6, 2024
2406435
.
ahmadsharif1 Nov 6, 2024
7b78be3
.
ahmadsharif1 Nov 6, 2024
20c6fba
.
ahmadsharif1 Nov 6, 2024
7630fdd
.
ahmadsharif1 Nov 6, 2024
37bfa5c
.
ahmadsharif1 Nov 6, 2024
24f2843
.
ahmadsharif1 Nov 6, 2024
4cb95a2
.
ahmadsharif1 Nov 6, 2024
4055346
.
ahmadsharif1 Nov 6, 2024
63bbb9e
.
ahmadsharif1 Nov 6, 2024
51e2308
.
ahmadsharif1 Nov 6, 2024
a926934
.
ahmadsharif1 Nov 6, 2024
400001a
.
ahmadsharif1 Nov 6, 2024
ccf95da
.
ahmadsharif1 Nov 7, 2024
209e746
.
ahmadsharif1 Nov 7, 2024
8d66147
.
ahmadsharif1 Nov 7, 2024
0a8ae5f
.
ahmadsharif1 Nov 7, 2024
8864b30
.
ahmadsharif1 Nov 7, 2024
936cbd1
.
ahmadsharif1 Nov 7, 2024
49197b5
.
ahmadsharif1 Nov 7, 2024
8291aa6
.
ahmadsharif1 Nov 7, 2024
4e10d0b
.
ahmadsharif1 Nov 7, 2024
b90bc7f
.
ahmadsharif1 Nov 7, 2024
2ae49ac
.
ahmadsharif1 Nov 7, 2024
f0444d4
.
ahmadsharif1 Nov 7, 2024
3d95977
.
ahmadsharif1 Nov 7, 2024
5cbccd0
.
ahmadsharif1 Nov 8, 2024
bf81cbe
.
ahmadsharif1 Nov 8, 2024
0ca9469
.
ahmadsharif1 Nov 8, 2024
64a9ebd
.
ahmadsharif1 Nov 8, 2024
30d9be7
.
ahmadsharif1 Nov 8, 2024
c91e73c
.
ahmadsharif1 Nov 8, 2024
0f50210
.
ahmadsharif1 Nov 8, 2024
f8d5e69
.
ahmadsharif1 Nov 8, 2024
5a4291a
.
ahmadsharif1 Nov 8, 2024
af3f684
.
ahmadsharif1 Nov 8, 2024
891125b
.
ahmadsharif1 Nov 8, 2024
9809feb
.
ahmadsharif1 Nov 8, 2024
92e2aef
.
ahmadsharif1 Nov 8, 2024
8d206f4
Merge branch 'main' of https://github.com/pytorch/torchcodec into doc1
ahmadsharif1 Nov 8, 2024
893c490
.
ahmadsharif1 Nov 8, 2024
2a106ca
.
ahmadsharif1 Nov 9, 2024
39f4606
.
ahmadsharif1 Nov 9, 2024
a51dfbd
.
ahmadsharif1 Nov 10, 2024
f29b05c
.
ahmadsharif1 Nov 10, 2024
dfa9fcc
Merge branch 'main' of https://github.com/pytorch/torchcodec into doc1
ahmadsharif1 Nov 11, 2024
3003be2
.
ahmadsharif1 Nov 11, 2024
015f355
.
ahmadsharif1 Nov 11, 2024
bde6324
Fixed typo
ahmadsharif1 Nov 11, 2024
3304341
used cuda instead of cuda:0
ahmadsharif1 Nov 11, 2024
3d74528
addressed comments
ahmadsharif1 Nov 11, 2024
f79dfa4
addressed comments
ahmadsharif1 Nov 11, 2024
cc1c3a8
.
ahmadsharif1 Nov 11, 2024
106 changes: 86 additions & 20 deletions .github/workflows/docs.yaml
@@ -5,45 +5,111 @@ on:
branches: [ main ]
pull_request:

permissions:
id-token: write
contents: write

defaults:
run:
shell: bash -l -eo pipefail {0}

jobs:
generate-matrix:
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
with:
package-type: wheel
os: linux
test-infra-repository: pytorch/test-infra
test-infra-ref: main
with-cpu: disable
with-xpu: disable
with-rocm: disable
with-cuda: enable
build-python-only: "disable"
build:
runs-on: ubuntu-latest
needs: generate-matrix
strategy:
fail-fast: false
name: Build and Upload wheel
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
with:
repository: pytorch/torchcodec
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/fake_smoke_test.py
package-name: torchcodec
trigger-event: ${{ github.event_name }}
build-platform: "python-build-package"
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation"

build-docs:
runs-on: linux.4xlarge.nvidia.gpu
strategy:
fail-fast: false
matrix:
# 3.9 corresponds to the minimum python version for which we build
# the wheel unless the label cliflow/binaries/all is present in the
# PR.
python-version: ['3.9']
cuda-version: ['12.4']
ffmpeg-version-for-tests: ['7']
container:
image: "pytorch/manylinux-builder:cuda${{ matrix.cuda-version }}"
options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility"
needs: build
steps:
- name: Check out repo
uses: actions/checkout@v3
- name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
- name: Setup env vars
run: |
cuda_version_without_periods=$(echo "${{ matrix.cuda-version }}" | sed 's/\.//g')
echo cuda_version_without_periods=${cuda_version_without_periods} >> $GITHUB_ENV
- uses: actions/download-artifact@v3
with:
auto-update-conda: true
miniconda-version: "latest"
activate-environment: test
python-version: '3.12'
name: pytorch_torchcodec__3.9_cu${{ env.cuda_version_without_periods }}_x86_64
path: pytorch/torchcodec/dist/
- name: Setup miniconda using test-infra
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
with:
python-version: ${{ matrix.python-version }}
#
# For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
# So we use the latter convention for libnpp.
# We install conda packages at the start because otherwise conda may have conflicts with dependencies.
default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
- name: Check env
run: |
${CONDA_RUN} env
${CONDA_RUN} conda info
${CONDA_RUN} nvidia-smi
${CONDA_RUN} conda list
- name: Assert ffmpeg exists
run: |
${CONDA_RUN} ffmpeg -buildconf
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install dependencies and FFmpeg
run: ${CONDA_RUN} python -m pip install --upgrade pip
- name: Install PyTorch
run: |
# TODO: torchvision and torchaudio shouldn't be needed. They were only added
# to silence an error as seen in https://github.com/pytorch/torchcodec/issues/203
python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
conda install "ffmpeg=7.0.1" pkg-config -c conda-forge
ffmpeg -version
- name: Build and install torchcodec
${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
- name: Install torchcodec from the wheel
run: |
python -m pip install -e ".[dev]" --no-build-isolation -vvv
wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
echo Installing $wheel_path
${CONDA_RUN} python -m pip install $wheel_path -vvv

- name: Check out repo
uses: actions/checkout@v3

- name: Install doc dependencies
run: |
cd docs
python -m pip install -r requirements.txt
${CONDA_RUN} python -m pip install -r requirements.txt
Contributor

Do we understand why we need CONDA_RUN here?

Contributor Author

Everything is installed in a conda env. Without CONDA_RUN the pip that's used is the outside pip
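
To illustrate the point about the "outside pip", here is a minimal sketch (not part of the PR) that reports which interpreter is running and whether it belongs to the active conda environment. It assumes conda exports `CONDA_PREFIX` for the activated environment, which is the usual behaviour; `describe_python_env` is a hypothetical helper name. Running it once as `python check_env.py` and once as `${CONDA_RUN} python check_env.py` should show different `prefix` values when the two interpreters differ.

```python
import os
import sys


def describe_python_env() -> dict:
    """Report which interpreter is running and whether it lives in the active conda env."""
    # Assumption: conda sets CONDA_PREFIX to the root of the activated environment.
    conda_prefix = os.environ.get("CONDA_PREFIX")
    in_conda_env = bool(conda_prefix) and (
        os.path.realpath(sys.prefix) == os.path.realpath(conda_prefix)
    )
    return {
        "executable": sys.executable,    # interpreter that `python -m pip` would use
        "prefix": sys.prefix,            # where that interpreter installs packages
        "conda_prefix": conda_prefix,    # None when no conda env is active
        "in_conda_env": in_conda_env,
    }


if __name__ == "__main__":
    for key, value in describe_python_env().items():
        print(f"{key}: {value}")
```

If `in_conda_env` is false, packages installed with that interpreter's pip will not be visible to tools running inside the conda environment.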

- name: Build docs
run: |
cd docs
make html
${CONDA_RUN} make html
- uses: actions/upload-artifact@v3
with:
name: Built-Docs
8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -50,6 +50,14 @@ We achieve these capabilities through:

How to sample video clips

.. grid-item-card:: :octicon:`file-code;1em`
GPU decoding using TorchCodec
:img-top: _static/img/card-background.svg
:link: generated_examples/basic_cuda_example.html
:link-type: url

A simple example demonstrating CUDA GPU decoding

.. toctree::
:maxdepth: 1
:caption: TorchCodec documentation
176 changes: 176 additions & 0 deletions examples/basic_cuda_example.py
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Accelerated video decoding on GPUs with CUDA and NVDEC
================================================================

TorchCodec can use supported Nvidia hardware (see support matrix
`here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_) to speed up
video decoding. This is called "CUDA Decoding". It uses Nvidia's
`NVDEC hardware decoder <https://developer.nvidia.com/video-codec-sdk>`_
to decompress the video and CUDA kernels to convert the result to RGB.
CUDA Decoding can be faster than CPU Decoding for the actual decoding step and also for
subsequent transform steps like scaling, cropping or rotating. This is because the decode step leaves
the decoded tensor in GPU memory, so the GPU doesn't have to fetch it from main memory before
running the transform steps. Encoded packets are often much smaller than decoded frames, so
CUDA Decoding also uses less PCI-e bandwidth.

CUDA Decoding can offer a speed-up over CPU Decoding in a few scenarios:

#. You are decoding a large resolution video
#. You are decoding a large batch of videos that's saturating the CPU
#. You want to do whole-image transforms like scaling or convolutions on the decoded tensors
after decoding
#. Your CPU is saturated and you want to free it up for other work


Here are situations where CUDA Decoding may not make sense:

#. You want bit-exact results compared to CPU Decoding
#. You have small resolution videos and the PCI-e transfer latency is large
#. Your GPU is already busy and CPU is not

It's best to experiment with CUDA Decoding to see if it improves your use-case. With
TorchCodec you can simply pass in a device parameter to the
:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding.


To use CUDA Decoding, you will need the following installed in your environment:

#. An Nvidia GPU that supports decoding the video format you want to decode. See
the support matrix `here <https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new>`_
#. `CUDA-enabled PyTorch <https://pytorch.org/get-started/locally/>`_
#. FFmpeg binaries that support
`NVDEC-enabled <https://docs.nvidia.com/video-technologies/video-codec-sdk/12.0/ffmpeg-with-nvidia-gpu/index.html>`_
codecs
#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit)


FFmpeg versions 5, 6 and 7 from conda-forge are built with
`NVDEC support <https://docs.nvidia.com/video-technologies/video-codec-sdk/12.0/ffmpeg-with-nvidia-gpu/index.html>`_
and you can install them with conda. For example, to install FFmpeg version 7:


.. code-block:: bash

conda install ffmpeg=7 -c conda-forge
conda install libnpp cuda-nvrtc -c nvidia


"""

# %%
# Checking if PyTorch has CUDA enabled
# -------------------------------------
#
# .. note::
#
# This tutorial requires FFmpeg libraries compiled with CUDA support.
#
#
import torch

print(f"{torch.__version__=}")
print(f"{torch.cuda.is_available()=}")
print(f"{torch.cuda.get_device_properties(0)=}")
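
The check above assumes a GPU is present. As a minimal sketch (not part of the tutorial), a script that must also run on machines without CUDA could choose the device string first and fall back to the CPU. `pick_decode_device` is a hypothetical helper name; it only checks that PyTorch can see a CUDA device, not that FFmpeg was built with NVDEC support.

```python
import importlib.util


def pick_decode_device() -> str:
    """Return "cuda" when a CUDA-enabled PyTorch can see a GPU, otherwise "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed in this environment
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_decode_device()
print(f"Decoding device: {device}")
```

The returned string can then be passed as the `device` argument of `VideoDecoder`.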


# %%
# Downloading the video
# -------------------------------------
#
# We will use a video with the following properties:
#
# - Codec: H.264
# - Resolution: 960x540
# - FPS: 29.97
# - Pixel format: YUV420P
#
# .. raw:: html
#
# <video style="max-width: 100%" controls>
# <source src="https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4" type="video/mp4">
# </video>
import urllib.request

video_file = "video.mp4"
urllib.request.urlretrieve(
"https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
video_file,
)
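
The properties listed above make it easy to put a number on the introduction's claim that decoded frames are much larger than encoded packets. The sketch below is a back-of-the-envelope estimate, not a measurement: it assumes frames are stored as 8-bit RGB (3 bytes per pixel), and the encoded bitrate is a hypothetical placeholder rather than a property of this video.

```python
WIDTH, HEIGHT = 960, 540   # resolution of the tutorial video
FPS = 29.97                # frame rate of the tutorial video
BYTES_PER_PIXEL_RGB = 3    # assumption: one uint8 each for R, G and B


def decoded_bytes_per_frame(width: int, height: int) -> int:
    """Size in bytes of one decoded RGB frame."""
    return width * height * BYTES_PER_PIXEL_RGB


def decoded_megabytes_per_second(width: int, height: int, fps: float) -> float:
    """Data rate of decoded RGB frames at the given frame rate, in MB/s."""
    return decoded_bytes_per_frame(width, height) * fps / 1e6


frame_bytes = decoded_bytes_per_frame(WIDTH, HEIGHT)
decoded_rate = decoded_megabytes_per_second(WIDTH, HEIGHT, FPS)
print(f"{frame_bytes=}")          # 1555200 bytes per frame
print(f"{decoded_rate=:.1f} MB/s")

# Hypothetical encoded bitrate, for comparison only (not measured from the file).
ASSUMED_ENCODED_MBIT_PER_S = 1.0
encoded_rate = ASSUMED_ENCODED_MBIT_PER_S / 8  # MB/s
print(f"decoded/encoded ratio under that assumption: {decoded_rate / encoded_rate:.0f}x")
```

With CPU decoding the decoded frames are what must cross the PCI-e bus to reach the GPU; with CUDA decoding only the encoded packets do.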


# %%
# CUDA Decoding using VideoDecoder
# -------------------------------------
#
# To use the CUDA decoder, pass a CUDA device to the decoder.
#
from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder(video_file, device="cuda")
frame = decoder[0]

# %%
#
# A single decoded frame is returned as a tensor in CHW format (batches of frames are NCHW).

print(frame.shape, frame.dtype)

# %%
#
# The decoded frame stays in GPU memory.

print(frame.data.device)


# %%
# Visualizing Frames
# -------------------------------------
#
# Let's look at the frames decoded by the CUDA decoder and compare them
# against the equivalent frames from the CPU decoder.
timestamps = [12, 19, 45, 131, 180]
cpu_decoder = VideoDecoder(video_file, device="cpu")
cuda_decoder = VideoDecoder(video_file, device="cuda")
cpu_frames = cpu_decoder.get_frames_played_at(timestamps).data
cuda_frames = cuda_decoder.get_frames_played_at(timestamps).data


def plot_cpu_and_cuda_frames(cpu_frames: torch.Tensor, cuda_frames: torch.Tensor):
try:
import matplotlib.pyplot as plt
from torchvision.transforms.v2.functional import to_pil_image
except ImportError:
print("Cannot plot, please run `pip install torchvision matplotlib`")
return
n_rows = len(timestamps)
fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
for i in range(n_rows):
axes[i][0].imshow(to_pil_image(cpu_frames[i].to("cpu")))
axes[i][1].imshow(to_pil_image(cuda_frames[i].to("cpu")))

axes[0][0].set_title("CPU decoder", fontsize=24)
axes[0][1].set_title("CUDA decoder", fontsize=24)
plt.setp(axes, xticks=[], yticks=[])
plt.tight_layout()


plot_cpu_and_cuda_frames(cpu_frames, cuda_frames)

# %%
#
# The frames look alike to the human eye, but there may be subtle
# differences because CUDA math is not bit-exact with respect to CPU math.
#
frames_equal = torch.equal(cpu_frames.to("cuda"), cuda_frames)
mean_abs_diff = torch.mean(
torch.abs(cpu_frames.float().to("cuda") - cuda_frames.float())
)
max_abs_diff = torch.max(torch.abs(cpu_frames.to("cuda").float() - cuda_frames.float()))
print(f"{frames_equal=}")
Contributor

Nit: Instead of this binary indicator, we may want to print max abs diff and mean abs diff? We could even do it across all frames.

Contributor Author

Done


print(f"{mean_abs_diff=}")
print(f"{max_abs_diff=}")
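
The two metrics suggested in the review can be sketched in plain Python to show why the example converts the frames with `.float()` before subtracting. The pixel values below are made up for illustration and the helper names are hypothetical. The second helper mimics subtraction carried out in 8-bit unsigned arithmetic, where a small negative difference wraps around to a large positive one.

```python
def abs_diff_stats(a, b):
    """Mean and max absolute difference between two equal-length pixel sequences."""
    if len(a) != len(b) or not a:
        raise ValueError("inputs must be non-empty and have the same length")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    return sum(diffs) / len(diffs), max(diffs)


def uint8_wrapped_diff(x: int, y: int) -> int:
    """Result of x - y when computed modulo 256, as unsigned 8-bit arithmetic does."""
    return (x - y) % 256


# Made-up pixel values standing in for CPU-decoded and CUDA-decoded frames.
cpu_pixels = [10, 200, 128, 0]
cuda_pixels = [12, 199, 128, 1]

mean_diff, max_diff = abs_diff_stats(cpu_pixels, cuda_pixels)
print(f"{mean_diff=}, {max_diff=}")      # mean_diff=1.0, max_diff=2
print(f"{uint8_wrapped_diff(10, 12)=}")  # 254: wraparound instead of -2
```

Without the widening cast, a true difference of 2 could be reported as 254, which would badly inflate both the mean and the max.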