diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index 545ddf9c8..a50581e3f 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -5,45 +5,111 @@ on:
     branches: [ main ]
   pull_request:
 
+permissions:  # declared at workflow level, outside any single job
+  id-token: write
+  contents: write
+
 defaults:
   run:
     shell: bash -l -eo pipefail {0}
 
 jobs:
+  generate-matrix:  # its outputs.matrix is consumed by the build job below
+    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
+    with:
+      package-type: wheel
+      os: linux
+      test-infra-repository: pytorch/test-infra
+      test-infra-ref: main
+      with-cpu: disable
+      with-xpu: disable
+      with-rocm: disable
+      with-cuda: enable  # CUDA is the only variant left enabled
+      build-python-only: "disable"
   build:
-    runs-on: ubuntu-latest
+    needs: generate-matrix  # required by the needs.generate-matrix.outputs reference below
+    strategy:
+      fail-fast: false  # NOTE(review): no matrix is declared on this job - confirm this is still needed
+    name: Build and Upload wheel
+    uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
+    with:
+      repository: pytorch/torchcodec
+      ref: ""  # NOTE(review): empty ref - presumably resolved by the called workflow; confirm
+      test-infra-repository: pytorch/test-infra
+      test-infra-ref: main
+      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
+      post-script: packaging/post_build_script.sh
+      smoke-test-script: packaging/fake_smoke_test.py  # NOTE(review): "fake" suggests a placeholder smoke test; confirm
+      package-name: torchcodec
+      trigger-event: ${{ github.event_name }}
+      build-platform: "python-build-package"
+      build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation"
+
+  build-docs:  # docs are built on a GPU runner from the wheel produced by the build job
+    runs-on: linux.4xlarge.nvidia.gpu
     strategy:
       fail-fast: false
+      matrix:
+        # 3.9 corresponds to the minimum python version for which we build
+        # the wheel unless the label cliflow/binaries/all is present in the
+        # PR.
+        python-version: ['3.9']
+        cuda-version: ['12.4']
+        ffmpeg-version-for-tests: ['7']
+    container:
+      image: "pytorch/manylinux-builder:cuda${{ matrix.cuda-version }}"
+      options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility"
+    needs: build
     steps:
-      - name: Check out repo
-        uses: actions/checkout@v3
-      - name: Setup conda env
-        uses: conda-incubator/setup-miniconda@v2
+      - name: Setup env vars  # exports the CUDA version with the dots removed, e.g. 12.4 -> 124
+        run: |
+          cuda_version_without_periods=$(echo "${{ matrix.cuda-version }}" | sed 's/\.//g')
+          echo "cuda_version_without_periods=${cuda_version_without_periods}" >> "${GITHUB_ENV}"
+      - uses: actions/download-artifact@v3
         with:
-          auto-update-conda: true
-          miniconda-version: "latest"
-          activate-environment: test
-          python-version: '3.12'
+          name: pytorch_torchcodec__${{ matrix.python-version }}_cu${{ env.cuda_version_without_periods }}_x86_64  # python part follows the matrix instead of a literal 3.9
+          path: pytorch/torchcodec/dist/
+      - name: Setup miniconda using test-infra
+        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
+        with:
+          python-version: ${{ matrix.python-version }}
+          #
+          # For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
+          # So we use the latter convention for libnpp.
+          # We install conda packages at the start because otherwise conda may have conflicts with dependencies.
+          default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
+      - name: Check env  # diagnostics only: prints the environment, conda info, GPU status and package list
+        run: |
+          ${CONDA_RUN} env
+          ${CONDA_RUN} conda info
+          ${CONDA_RUN} nvidia-smi
+          ${CONDA_RUN} conda list
+      - name: Assert ffmpeg exists
+        run: |
+          ${CONDA_RUN} ffmpeg -buildconf
       - name: Update pip
-        run: python -m pip install --upgrade pip
-      - name: Install dependencies and FFmpeg
+        run: ${CONDA_RUN} python -m pip install --upgrade pip
+      - name: Install PyTorch  # nightly wheels from the CUDA index matching the matrix cuda-version
         run: |
-          # TODO: torchvision and torchaudio shouldn't be needed. They were only added
-          # to silence an error as seen in https://github.com/pytorch/torchcodec/issues/203
-          python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
-          conda install "ffmpeg=7.0.1" pkg-config -c conda-forge
-          ffmpeg -version
-      - name: Build and install torchcodec
+          ${CONDA_RUN} python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
+          ${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
+      - name: Install torchcodec from the wheel  # assumes dist/ holds exactly one wheel - TODO confirm
         run: |
-          python -m pip install -e ".[dev]" --no-build-isolation -vvv
+          wheel_path=$(find pytorch/torchcodec/dist -type f -name "*.whl")
+          echo "Installing ${wheel_path}"
+          ${CONDA_RUN} python -m pip install "${wheel_path}" -vvv
+
+      - name: Check out repo  # NOTE(review): runs after the wheel install - presumably so checkout cannot disturb dist/; confirm
+        uses: actions/checkout@v3
+
       - name: Install doc dependencies
         run: |
           cd docs
-          python -m pip install -r requirements.txt
+          ${CONDA_RUN} python -m pip install -r requirements.txt
       - name: Build docs
         run: |
           cd docs
-          make html
+          ${CONDA_RUN} make html
       - uses: actions/upload-artifact@v3
         with:
           name: 
Built-Docs
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 20d6db902..b4d623654 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -50,6 +50,14 @@ We achieve these capabilities through:
 
         How to sample video clips
 
+    .. grid-item-card:: :octicon:`file-code;1em`
+        GPU decoding using TorchCodec
+        :img-top: _static/img/card-background.svg
+        :link: generated_examples/basic_cuda_example.html
+        :link-type: url
+
+        A simple example demonstrating CUDA GPU decoding
+
 .. toctree::
    :maxdepth: 1
    :caption: TorchCodec documentation
diff --git a/examples/basic_cuda_example.py b/examples/basic_cuda_example.py
new file mode 100644
index 000000000..5ff85e8e0
--- /dev/null
+++ b/examples/basic_cuda_example.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# NOTE(review): the RST links in the docstring below (e.g. `here `_) appear to
+# have lost their target URLs - confirm and restore them before merging.
+"""
+Accelerated video decoding on GPUs with CUDA and NVDEC
+================================================================
+
+TorchCodec can use supported Nvidia hardware (see support matrix
+`here `_) to speed up
+video decoding. This is called "CUDA Decoding" and it uses Nvidia's
+`NVDEC hardware decoder `_
+and CUDA kernels to respectively decompress and convert to RGB.
+CUDA Decoding can be faster than CPU Decoding for the actual decoding step and also for
+subsequent transform steps like scaling, cropping or rotating. This is because the decode step leaves
+the decoded tensor in GPU memory so the GPU doesn't have to fetch from main memory before
+running the transform steps. Encoded packets are often much smaller than decoded frames so
+CUDA decoding also uses less PCI-e bandwidth.
+
+CUDA Decoding can offer a speed-up over CPU Decoding in a few scenarios:
+
+#. You are decoding a high-resolution video
+#. You are decoding a large batch of videos that's saturating the CPU
+#. 
You want to do whole-image transforms like scaling or convolutions on the decoded tensors
+   after decoding
+#. Your CPU is saturated and you want to free it up for other work
+
+
+Here are situations where CUDA Decoding may not make sense:
+
+#. You want bit-exact results compared to CPU Decoding
+#. You have low-resolution videos and the PCI-e transfer latency is large
+#. Your GPU is already busy and your CPU is not
+
+It's best to experiment with CUDA Decoding to see if it improves your use-case. With
+TorchCodec you can simply pass in a device parameter to the
+:class:`~torchcodec.decoders.VideoDecoder` class to use CUDA Decoding.
+
+
+In order to use CUDA Decoding, you will need the following installed in your environment:
+
+#. An Nvidia GPU that supports decoding the video format you want to decode. See
+   the support matrix `here `_
+#. `CUDA-enabled PyTorch `_
+#. FFmpeg binaries that support
+   `NVDEC-enabled `_
+   codecs
+#. libnpp and nvrtc (these are usually installed when you install the full cuda-toolkit)
+
+
+FFmpeg versions 5, 6 and 7 from conda-forge are built with
+`NVDEC support `_
+and you can install them with conda. For example, to install FFmpeg version 7:
+
+
+.. code-block:: bash
+
+   conda install ffmpeg=7 -c conda-forge
+   conda install libnpp cuda-nvrtc -c nvidia
+
+
+"""
+
+# %%
+# Checking if PyTorch has CUDA enabled
+# -------------------------------------
+#
+# .. note::
+#
+#    This tutorial requires FFmpeg libraries compiled with CUDA support.
+#
+#
+import torch
+
+print(f"{torch.__version__=}")
+print(f"{torch.cuda.is_available()=}")
+print(f"{torch.cuda.get_device_properties(0)=}")
+
+
+# %%
+# Downloading the video
+# -------------------------------------
+#
+# We will use a video with the following properties:
+#
+# - Codec: H.264
+# - Resolution: 960x540
+# - FPS: 29.97
+# - Pixel format: YUV420P
+#
+# .. 
raw:: html
+#
+#
+import urllib.request
+
+video_file = "video.mp4"
+urllib.request.urlretrieve(
+    "https://download.pytorch.org/torchaudio/tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4",
+    video_file,
+)
+
+
+# %%
+# CUDA Decoding using VideoDecoder
+# -------------------------------------
+#
+# To use CUDA decoder, you need to pass in a cuda device to the decoder.
+#
+from torchcodec.decoders import VideoDecoder
+
+decoder = VideoDecoder(video_file, device="cuda")
+frame = decoder[0]
+
+# %%
+#
+# The video frames are decoded and returned as tensor of NCHW format.
+
+print(frame.shape, frame.dtype)
+
+# %%
+#
+# The video frames are left on the GPU memory.
+
+print(frame.data.device)
+
+
+# %%
+# Visualizing Frames
+# -------------------------------------
+#
+# Let's look at the frames decoded by CUDA decoder and compare them
+# against equivalent results from the CPU decoders.
+timestamps = [12, 19, 45, 131, 180]
+cpu_decoder = VideoDecoder(video_file, device="cpu")
+cuda_decoder = VideoDecoder(video_file, device="cuda")
+cpu_frames = cpu_decoder.get_frames_played_at(timestamps).data
+cuda_frames = cuda_decoder.get_frames_played_at(timestamps).data
+
+
+def plot_cpu_and_cuda_frames(cpu_frames: torch.Tensor, cuda_frames: torch.Tensor):
+    """Plot CPU-decoded and CUDA-decoded frames side by side, one pair per row.
+
+    Args:
+        cpu_frames: Frames decoded on the CPU, indexed along the first dimension.
+        cuda_frames: Frames decoded with CUDA, indexed along the first dimension.
+
+    One row is drawn per frame pair, up to the length of the shorter input.
+    Prints a hint and returns without plotting when matplotlib or torchvision
+    cannot be imported, or when there are no frames to plot.
+    """
+    try:
+        import matplotlib.pyplot as plt
+        from torchvision.transforms.v2.functional import to_pil_image
+    except ImportError:
+        print("Cannot plot, please run `pip install torchvision matplotlib`")
+        return
+    # Size the grid from the arguments rather than from the module-level
+    # ``timestamps`` list, so the function works for any batch it is given.
+    n_rows = min(len(cpu_frames), len(cuda_frames))
+    if n_rows == 0:
+        print("Cannot plot, no frames were provided")
+        return
+    # squeeze=False keeps ``axes`` two-dimensional even for a single row, so
+    # the ``axes[i][j]`` indexing below is always valid.
+    _, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0], squeeze=False)
+    for i in range(n_rows):
+        # Each frame is moved to the CPU before being converted to a PIL image.
+        axes[i][0].imshow(to_pil_image(cpu_frames[i].to("cpu")))
+        axes[i][1].imshow(to_pil_image(cuda_frames[i].to("cpu")))
+
+    # Only the top row carries the column titles.
+    axes[0][0].set_title("CPU decoder", fontsize=24)
+    axes[0][1].set_title("CUDA decoder", fontsize=24)
+    plt.setp(axes, xticks=[], yticks=[])
+    plt.tight_layout()
+
+
+plot_cpu_and_cuda_frames(cpu_frames, cuda_frames)
+
+# %%
+#
+# 
They look visually similar to the human eye but there may be subtle
+# differences because CUDA math is not bit-exact with respect to CPU math.
+#
+cpu_frames_on_gpu = cpu_frames.to("cuda")  # single transfer, reused by every comparison below
+frames_equal = torch.equal(cpu_frames_on_gpu, cuda_frames)
+abs_diff = torch.abs(cpu_frames_on_gpu.float() - cuda_frames.float())
+mean_abs_diff = abs_diff.mean()
+max_abs_diff = abs_diff.max()
+print(f"{frames_equal=}")
+print(f"{mean_abs_diff=}")
+print(f"{max_abs_diff=}")