diff --git a/.github/workflows/linux_wheel.yaml b/.github/workflows/linux_wheel.yaml
index 099a905c4..25a5a564e 100644
--- a/.github/workflows/linux_wheel.yaml
+++ b/.github/workflows/linux_wheel.yaml
@@ -72,10 +72,14 @@ jobs:
name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64
path: pytorch/torchcodec/dist/
- name: Setup conda env
- uses: conda-incubator/setup-miniconda@v2
+ uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
- miniconda-version: "latest"
+ # Using miniforge instead of miniconda ensures that the default
+ # conda channel is conda-forge instead of main/default, which helps
+ # keep ABIs consistent across dependencies:
+ # https://conda-forge.org/docs/user/transitioning_from_defaults/
+ miniforge-version: latest
activate-environment: test
python-version: ${{ matrix.python-version }}
- name: Update pip
diff --git a/.github/workflows/reference_resources.yaml b/.github/workflows/reference_resources.yaml
index 25353d70c..8f97378f1 100644
--- a/.github/workflows/reference_resources.yaml
+++ b/.github/workflows/reference_resources.yaml
@@ -14,7 +14,40 @@ defaults:
shell: bash -l -eo pipefail {0}
jobs:
+ generate-matrix:
+ uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
+ with:
+ package-type: wheel
+ os: linux
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ with-xpu: disable
+ with-rocm: disable
+ with-cuda: disable
+ build-python-only: "disable"
+
+ build:
+ needs: generate-matrix
+ strategy:
+ fail-fast: false
+ name: Build and Upload Linux wheel
+ uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
+ with:
+ repository: meta-pytorch/torchcodec
+ ref: ""
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
+ build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
+ pre-script: packaging/pre_build_script.sh
+ post-script: packaging/post_build_script.sh
+ smoke-test-script: packaging/fake_smoke_test.py
+ package-name: torchcodec
+ trigger-event: ${{ github.event_name }}
+ build-platform: "python-build-package"
+ build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation"
+
test-reference-resource-generation:
+ needs: build
runs-on: ubuntu-latest
strategy:
fail-fast: false
@@ -22,6 +55,10 @@ jobs:
python-version: ['3.10']
ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1']
steps:
+ - uses: actions/download-artifact@v4
+ with:
+ name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64
+ path: pytorch/torchcodec/dist/
- name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with:
@@ -43,11 +80,16 @@ jobs:
# Note that we're installing stable - this is for running a script where we're a normal PyTorch
# user, not for building TorchCodec.
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
- python -m pip install numpy pillow
+ python -m pip install numpy pillow pytest
+ - name: Install torchcodec from the wheel
+ run: |
+ wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
+ echo Installing $wheel_path
+ python -m pip install $wheel_path -vvv
- name: Check out repo
uses: actions/checkout@v3
- name: Run generation reference resources
run: |
- python test/generate_reference_resources.py
+ python -m test.generate_reference_resources
diff --git a/.github/workflows/windows_wheel.yaml b/.github/workflows/windows_wheel.yaml
index 39247f770..2fd773fec 100644
--- a/.github/workflows/windows_wheel.yaml
+++ b/.github/workflows/windows_wheel.yaml
@@ -71,8 +71,7 @@ jobs:
# TODO: FFmpeg 5 on Windows segfaults in avcodec_open2() when passing
# bad parameters.
# See https://github.com/pytorch/torchcodec/pull/806
- # TODO: Support FFmpeg 8 on Windows
- ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1']
+ ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1', '8.0']
needs: build
steps:
- uses: actions/download-artifact@v4
@@ -83,7 +82,11 @@ jobs:
uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
- miniconda-version: "latest"
+ # Using miniforge instead of miniconda ensures that the default
+ # conda channel is conda-forge instead of main/default, which helps
+ # keep ABIs consistent across dependencies:
+ # https://conda-forge.org/docs/user/transitioning_from_defaults/
+ miniforge-version: latest
activate-environment: test
python-version: ${{ matrix.python-version }}
- name: Update pip
diff --git a/README.md b/README.md
index 8050cf2a3..6c1721036 100644
--- a/README.md
+++ b/README.md
@@ -107,8 +107,8 @@ ffmpeg -f lavfi -i \
`torch` and `torchcodec`.
2. Install FFmpeg, if it's not already installed. Linux distributions usually
- come with FFmpeg pre-installed. TorchCodec supports all major FFmpeg versions
- in [4, 7].
+ come with FFmpeg pre-installed. TorchCodec supports all major FFmpeg versions
+ in [4, 7] on all platforms; FFmpeg version 8 is also supported on macOS and Linux.
If FFmpeg is not already installed, or you need a more recent version, an
easy way to install it is to use `conda`:
@@ -131,6 +131,7 @@ The following table indicates the compatibility between versions of
| `torchcodec` | `torch` | Python |
| ------------------ | ------------------ | ------------------- |
| `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.13` |
+| `0.8` | `2.9` | `>=3.10`, `<=3.13` |
| `0.7` | `2.8` | `>=3.9`, `<=3.13` |
| `0.6` | `2.8` | `>=3.9`, `<=3.13` |
| `0.5` | `2.7` | `>=3.9`, `<=3.13` |
@@ -147,7 +148,8 @@ format you want. Refer to Nvidia's GPU support matrix for more details
[here](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new).
1. Install FFmpeg with NVDEC support.
- TorchCodec with CUDA should work with FFmpeg versions in [4, 7].
+ TorchCodec with CUDA should work with FFmpeg versions in [4, 7] on all
+ platforms; FFmpeg version 8 is also supported on Linux.
If FFmpeg is not already installed, or you need a more recent version, an
easy way to install it is to use `conda`:
diff --git a/benchmarks/decoders/benchmark_transforms.py b/benchmarks/decoders/benchmark_transforms.py
new file mode 100644
index 000000000..75a49d63b
--- /dev/null
+++ b/benchmarks/decoders/benchmark_transforms.py
@@ -0,0 +1,164 @@
+import math
+from argparse import ArgumentParser
+from pathlib import Path
+from time import perf_counter_ns
+
+import torch
+from torch import Tensor
+from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
+from torchcodec.decoders import VideoDecoder
+from torchvision.transforms import v2
+
+DEFAULT_NUM_EXP = 20
+
+
+def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor:
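+ # Run f(*args) `warmup` times to warm caches, then time `num_exp` runs;
+ # returns a float tensor of per-run wall-clock times in nanoseconds.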
+
+ for _ in range(warmup):
+ f(*args)
+
+ times = []
+ for _ in range(num_exp):
+ start = perf_counter_ns()
+ f(*args)
+ end = perf_counter_ns()
+ times.append(end - start)
+ return torch.tensor(times).float()
+
+
+def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> None:
+ mul = {
+ "ns": 1,
+ "µs": 1e-3,
+ "ms": 1e-6,
+ "s": 1e-9,
+ }[unit]
+ times = times * mul
+ std = times.std().item()
+ med = times.median().item()
+ mean = times.mean().item()
+ min = times.min().item()
+ max = times.max().item()
+ print(
+ f"{prefix:<45} {med = :.2f}, {mean = :.2f} +- {std:.2f}, {min = :.2f}, {max = :.2f} - in {unit}"
+ )
+
+
+def torchvision_resize(
+ path: Path, pts_seconds: list[float], dims: tuple[int, int]
+) -> Tensor:
+ decoder = create_from_file(str(path), seek_mode="approximate")
+ add_video_stream(decoder)
+ raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
+ return v2.functional.resize(raw_frames, size=dims)
+
+
+def torchvision_crop(
+ path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
+) -> Tensor:
+ decoder = create_from_file(str(path), seek_mode="approximate")
+ add_video_stream(decoder)
+ raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
+ return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1])
+
+
+def decoder_native_resize(
+ path: Path, pts_seconds: list[float], dims: tuple[int, int]
+) -> Tensor:
+ decoder = create_from_file(str(path), seek_mode="approximate")
+ add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}")
+ return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
+
+
+def decoder_native_crop(
+ path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
+) -> Tensor:
+ decoder = create_from_file(str(path), seek_mode="approximate")
+ add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}")
+ return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument("--path", type=str, help="path to file", required=True)
+ parser.add_argument(
+ "--num-exp",
+ type=int,
+ default=DEFAULT_NUM_EXP,
+ help="number of runs to average over",
+ )
+
+ args = parser.parse_args()
+ path = Path(args.path)
+
+ metadata = VideoDecoder(path).metadata
+ duration = metadata.duration_seconds
+
+ print(
+ f"Benchmarking {path.name}, duration: {duration}, codec: {metadata.codec}, averaging over {args.num_exp} runs:"
+ )
+
+ input_height = metadata.height
+ input_width = metadata.width
+ fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1]
+ fraction_of_input_dimensions = [0.5, 0.25, 0.125]
+
+ for num_fraction in fraction_of_total_frames_to_sample:
+ num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
+ print(
+ f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames"
+ )
+ uniform_timestamps = [
+ i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
+ ]
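+ # E.g. duration=60s and 6 sampled frames gives pts 0, 10, 20, 30, 40, 50:
+ # uniformly spaced timestamps that exclude the very end of the video.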
+
+ for dims_fraction in fraction_of_input_dimensions:
+ dims = (int(input_height * dims_fraction), int(input_width * dims_fraction))
+
+ times = bench(
+ torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp
+ )
+ report_stats(times, prefix=f"torchvision_resize({dims})")
+
+ times = bench(
+ decoder_native_resize,
+ path,
+ uniform_timestamps,
+ dims,
+ num_exp=args.num_exp,
+ )
+ report_stats(times, prefix=f"decoder_native_resize({dims})")
+ print()
+
+ center_x = (input_width - dims[1]) // 2
+ center_y = (input_height - dims[0]) // 2
+ times = bench(
+ torchvision_crop,
+ path,
+ uniform_timestamps,
+ dims,
+ center_x,
+ center_y,
+ num_exp=args.num_exp,
+ )
+ report_stats(
+ times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})"
+ )
+
+ times = bench(
+ decoder_native_crop,
+ path,
+ uniform_timestamps,
+ dims,
+ center_x,
+ center_y,
+ num_exp=args.num_exp,
+ )
+ report_stats(
+ times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})"
+ )
+ print()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8dea1dc8b..85f9a067c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -11,7 +11,7 @@ We achieve these capabilities through:
* Pythonic APIs that mirror Python and PyTorch conventions.
* Relying on `FFmpeg `_ to do the decoding / encoding.
- TorchCodec uses the version of FFmpeg you already have installed. FMPEG is a
+ TorchCodec uses the version of FFmpeg you already have installed. FFmpeg is a
mature library with broad coverage available on most systems. It is, however,
not easy to use. TorchCodec abstracts FFmpeg's complexity to ensure it is
used correctly and efficiently.
diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp b/src/torchcodec/_core/AVIOTensorContext.cpp
index 3f45f5be5..238475761 100644
--- a/src/torchcodec/_core/AVIOTensorContext.cpp
+++ b/src/torchcodec/_core/AVIOTensorContext.cpp
@@ -18,15 +18,15 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
int read(void* opaque, uint8_t* buf, int buf_size) {
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
TORCH_CHECK(
- tensorContext->current <= tensorContext->data.numel(),
- "Tried to read outside of the buffer: current=",
- tensorContext->current,
+ tensorContext->current_pos <= tensorContext->data.numel(),
+ "Tried to read outside of the buffer: current_pos=",
+ tensorContext->current_pos,
", size=",
tensorContext->data.numel());
int64_t numBytesRead = std::min(
static_cast<int64_t>(buf_size),
- tensorContext->data.numel() - tensorContext->current);
+ tensorContext->data.numel() - tensorContext->current_pos);
TORCH_CHECK(
numBytesRead >= 0,
@@ -34,8 +34,8 @@ int read(void* opaque, uint8_t* buf, int buf_size) {
numBytesRead,
", size=",
tensorContext->data.numel(),
- ", current=",
- tensorContext->current);
+ ", current_pos=",
+ tensorContext->current_pos);
if (numBytesRead == 0) {
return AVERROR_EOF;
@@ -43,9 +43,9 @@ int read(void* opaque, uint8_t* buf, int buf_size) {
std::memcpy(
buf,
- tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
+ tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
numBytesRead);
- tensorContext->current += numBytesRead;
+ tensorContext->current_pos += numBytesRead;
return numBytesRead;
}
@@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
int64_t bufSize = static_cast<int64_t>(buf_size);
- if (tensorContext->current + bufSize > tensorContext->data.numel()) {
+ if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
TORCH_CHECK(
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
"We tried to allocate an output encoded tensor larger than ",
@@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
}
TORCH_CHECK(
- tensorContext->current + bufSize <= tensorContext->data.numel(),
+ tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
"Re-allocation of the output tensor didn't work. ",
"This should not happen, please report on TorchCodec bug tracker");
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
- std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
- tensorContext->current += bufSize;
+ std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
+ tensorContext->current_pos += bufSize;
+ // Track the maximum position written so that getOutputTensor's narrow() does
+ // not truncate the file if the final seek was backwards.
+ tensorContext->max_pos =
+ std::max(tensorContext->current_pos, tensorContext->max_pos);
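+ // Example: after writing a 100-byte stream, the muxer may seek back to
+ // offset 0 and rewrite, say, a 10-byte header. current_pos is then 10, but
+ // the encoded output is still 100 bytes long, which is what max_pos holds.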
return buf_size;
}
@@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
ret = tensorContext->data.numel();
break;
case SEEK_SET:
- tensorContext->current = offset;
+ tensorContext->current_pos = offset;
ret = offset;
break;
default:
@@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
} // namespace
AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
- : tensorContext_{data, 0} {
+ : tensorContext_{data, 0, 0} {
TORCH_CHECK(data.numel() > 0, "data must not be empty");
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
@@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
}
AVIOToTensorContext::AVIOToTensorContext()
- : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
+ : tensorContext_{
+ torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}),
+ 0,
+ 0} {
createAVIOContext(
nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
}
torch::Tensor AVIOToTensorContext::getOutputTensor() {
return tensorContext_.data.narrow(
- /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
+ /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
}
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/AVIOTensorContext.h b/src/torchcodec/_core/AVIOTensorContext.h
index 15f97da55..bcd97052b 100644
--- a/src/torchcodec/_core/AVIOTensorContext.h
+++ b/src/torchcodec/_core/AVIOTensorContext.h
@@ -15,7 +15,8 @@ namespace detail {
struct TensorContext {
torch::Tensor data;
- int64_t current;
+ int64_t current_pos; // current read/write offset into data
+ int64_t max_pos; // highest offset written so far, i.e. the output size
};
} // namespace detail
diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
index 78fa8d635..b0caa9705 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp
@@ -15,7 +15,7 @@
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/NVDECCache.h"
-// #include <cuda_runtime.h> // For cudaStreamSynchronize
+#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h"
#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
@@ -53,74 +53,6 @@ pfnDisplayPictureCallback(void* pUserData, CUVIDPARSERDISPINFO* dispInfo) {
}
static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
- // Check decoder capabilities - same checks as DALI
- auto caps = CUVIDDECODECAPS{};
- caps.eCodecType = videoFormat->codec;
- caps.eChromaFormat = videoFormat->chroma_format;
- caps.nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8;
- CUresult result = cuvidGetDecoderCaps(&caps);
- TORCH_CHECK(result == CUDA_SUCCESS, "Failed to get decoder caps: ", result);
-
- TORCH_CHECK(
- caps.bIsSupported,
- "Codec configuration not supported on this GPU. "
- "Codec: ",
- static_cast<int>(videoFormat->codec),
- ", chroma format: ",
- static_cast<int>(videoFormat->chroma_format),
- ", bit depth: ",
- videoFormat->bit_depth_luma_minus8 + 8);
-
- TORCH_CHECK(
- videoFormat->coded_width >= caps.nMinWidth &&
- videoFormat->coded_height >= caps.nMinHeight,
- "Video is too small in at least one dimension. Provided: ",
- videoFormat->coded_width,
- "x",
- videoFormat->coded_height,
- " vs supported:",
- caps.nMinWidth,
- "x",
- caps.nMinHeight);
-
- TORCH_CHECK(
- videoFormat->coded_width <= caps.nMaxWidth &&
- videoFormat->coded_height <= caps.nMaxHeight,
- "Video is too large in at least one dimension. Provided: ",
- videoFormat->coded_width,
- "x",
- videoFormat->coded_height,
- " vs supported:",
- caps.nMaxWidth,
- "x",
- caps.nMaxHeight);
-
- // See nMaxMBCount in cuviddec.h
- constexpr unsigned int macroblockConstant = 256;
- TORCH_CHECK(
- videoFormat->coded_width * videoFormat->coded_height /
- macroblockConstant <=
- caps.nMaxMBCount,
- "Video is too large (too many macroblocks). "
- "Provided (width * height / ",
- macroblockConstant,
- "): ",
- videoFormat->coded_width * videoFormat->coded_height / macroblockConstant,
- " vs supported:",
- caps.nMaxMBCount);
-
- // Below we'll set the decoderParams.OutputFormat to NV12, so we need to make
- // sure it's actually supported.
- TORCH_CHECK(
- (caps.nOutputFormatMask >> cudaVideoSurfaceFormat_NV12) & 1,
- "NV12 output format is not supported for this configuration. ",
- "Codec: ",
- static_cast<int>(videoFormat->codec),
- ", chroma format: ",
- static_cast<int>(videoFormat->chroma_format),
- ", bit depth: ",
- videoFormat->bit_depth_luma_minus8 + 8);
-
// Decoder creation parameters, most are taken from DALI
CUVIDDECODECREATEINFO decoderParams = {};
decoderParams.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8;
@@ -129,7 +61,7 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
// automatically converted to 8bits by NVDEC itself. That is, the raw frames
// we get back from cuvidMapVideoFrame will already be in 8bit format. We
// won't need to do the conversion ourselves, so that's a lot easier.
- // In the default interface, we have to do the 10 -> 8bits conversion
+ // In the FFmpeg CUDA interface, we have to do the 10 -> 8bits conversion
// ourselves later in convertAVFrameToFrameOutput(), because FFmpeg explicitly
// requests 10 or 16bits output formats for >8-bit videos!
// https://github.com/FFmpeg/FFmpeg/blob/e05f8acabff468c1382277c1f31fa8e9d90c3202/libavcodec/nvdec.c#L376-L403
@@ -157,13 +89,39 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
decoderParams.display_area.bottom = videoFormat->display_area.bottom;
CUvideodecoder* decoder = new CUvideodecoder();
- result = cuvidCreateDecoder(decoder, &decoderParams);
+ CUresult result = cuvidCreateDecoder(decoder, &decoderParams);
TORCH_CHECK(
result == CUDA_SUCCESS, "Failed to create NVDEC decoder: ", result);
return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{});
}
-cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
+std::optional<cudaVideoChromaFormat> validateChromaSupport(
+ const AVPixFmtDescriptor* desc) {
+ // Return the corresponding cudaVideoChromaFormat if supported, std::nullopt
+ // otherwise.
+ TORCH_CHECK(desc != nullptr, "desc can't be null");
+
+ if (desc->nb_components == 1) {
+ return cudaVideoChromaFormat_Monochrome;
+ } else if (desc->nb_components >= 3 && !(desc->flags & AV_PIX_FMT_FLAG_RGB)) {
+ // Make sure it's YUV: has chroma planes and isn't RGB
+ if (desc->log2_chroma_w == 0 && desc->log2_chroma_h == 0) {
+ return cudaVideoChromaFormat_444; // 1x1 subsampling = 4:4:4
+ } else if (desc->log2_chroma_w == 1 && desc->log2_chroma_h == 1) {
+ return cudaVideoChromaFormat_420; // 2x2 subsampling = 4:2:0
+ } else if (desc->log2_chroma_w == 1 && desc->log2_chroma_h == 0) {
+ return cudaVideoChromaFormat_422; // 2x1 subsampling = 4:2:2
+ }
+ }
+
+ return std::nullopt;
+}
+
+std::optional<cudaVideoCodec> validateCodecSupport(AVCodecID codecId) {
+ // Return the corresponding cudaVideoCodec if supported, std::nullopt
+ // otherwise.
+ // Note that we currently return nullopt (and thus fall back to CPU) for some
+ // codecs that are technically supported by NVDEC, see comment below.
switch (codecId) {
case AV_CODEC_ID_H264:
return cudaVideoCodec_H264;
@@ -189,12 +147,72 @@ cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
// return cudaVideoCodec_JPEG;
// case AV_CODEC_ID_VC1:
// return cudaVideoCodec_VC1;
- default: {
- TORCH_CHECK(false, "Unsupported codec type: ", avcodec_get_name(codecId));
- }
+ default:
+ return std::nullopt;
}
}
+bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) {
+ // Return true iff the input video stream is supported by our NVDEC
+ // implementation.
+
+ auto codecType = validateCodecSupport(codecContext->codec_id);
+ if (!codecType.has_value()) {
+ return false;
+ }
+
+ const AVPixFmtDescriptor* desc = av_pix_fmt_desc_get(codecContext->pix_fmt);
+ if (!desc) {
+ return false;
+ }
+
+ auto chromaFormat = validateChromaSupport(desc);
+ if (!chromaFormat.has_value()) {
+ return false;
+ }
+
+ auto caps = CUVIDDECODECAPS{};
+ caps.eCodecType = codecType.value();
+ caps.eChromaFormat = chromaFormat.value();
+ caps.nBitDepthMinus8 = desc->comp[0].depth - 8;
+
+ CUresult result = cuvidGetDecoderCaps(&caps);
+ if (result != CUDA_SUCCESS) {
+ return false;
+ }
+
+ if (!caps.bIsSupported) {
+ return false;
+ }
+
+ auto coded_width = static_cast<unsigned int>(codecContext->coded_width);
+ auto coded_height = static_cast<unsigned int>(codecContext->coded_height);
+ if (coded_width < static_cast<unsigned int>(caps.nMinWidth) ||
+ coded_height < static_cast<unsigned int>(caps.nMinHeight) ||
+ coded_width > caps.nMaxWidth || coded_height > caps.nMaxHeight) {
+ return false;
+ }
+
+ // See nMaxMBCount in cuviddec.h
+ constexpr unsigned int macroblockConstant = 256;
+ if (coded_width * coded_height / macroblockConstant > caps.nMaxMBCount) {
+ return false;
+ }
+
+ // We'll set the decoderParams.OutputFormat to NV12, so we need to make
+ // sure it's actually supported.
+ // TODO: If this fails, we could consider decoding to something other than
+ // (like cudaVideoSurfaceFormat_P016) instead of falling back to CPU. This is
+ // what FFmpeg does.
+ bool supportsNV12Output =
+ (caps.nOutputFormatMask >> cudaVideoSurfaceFormat_NV12) & 1;
+ if (!supportsNV12Output) {
+ return false;
+ }
+
+ return true;
+}
+
} // namespace
BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -205,6 +223,8 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
initializeCudaContextWithPytorch(device_);
nppCtx_ = getNppStreamContext(device_);
+
+ nvcuvidAvailable_ = loadNVCUVIDLibrary();
}
BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
@@ -216,12 +236,11 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
// unclear.
flush();
unmapPreviousFrame();
- NVDECCache::getCache(device_.index())
- .returnDecoder(&videoFormat_, std::move(decoder_));
+ NVDECCache::getCache(device_).returnDecoder(
+ &videoFormat_, std::move(decoder_));
}
if (videoParser_) {
- // TODONVDEC P2: consider caching this? Does DALI do that?
cuvidDestroyVideoParser(videoParser_);
videoParser_ = nullptr;
}
@@ -231,7 +250,21 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
void BetaCudaDeviceInterface::initialize(
const AVStream* avStream,
- const UniqueDecodingAVFormatContext& avFormatCtx) {
+ const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) {
+ if (!nvcuvidAvailable_ || !nativeNVDECSupport(codecContext)) {
+ cpuFallback_ = createDeviceInterface(torch::kCPU);
+ TORCH_CHECK(
+ cpuFallback_ != nullptr, "Failed to create CPU device interface");
+ cpuFallback_->initialize(avStream, avFormatCtx, codecContext);
+ cpuFallback_->initializeVideo(
+ VideoStreamOptions(),
+ {},
+ /*resizedOutputDims=*/std::nullopt);
+ // We'll always use the CPU fallback from now on, so we can return early.
+ return;
+ }
+
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
timeBase_ = avStream->time_base;
frameRateAvgFromFFmpeg_ = avStream->r_frame_rate;
@@ -243,7 +276,11 @@ void BetaCudaDeviceInterface::initialize(
// Create parser. Default values that aren't obvious are taken from DALI.
CUVIDPARSERPARAMS parserParams = {};
- parserParams.CodecType = validateCodecSupport(codecPar->codec_id);
+ auto codecType = validateCodecSupport(codecPar->codec_id);
+ TORCH_CHECK(
+ codecType.has_value(),
+ "This should never happen, we should be using the CPU fallback by now. Please report a bug.");
+ parserParams.CodecType = codecType.value();
parserParams.ulMaxNumDecodeSurfaces = 8;
parserParams.ulMaxDisplayDelay = 0;
// Callback setup, all are triggered by the parser within a call
@@ -362,11 +399,12 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) {
}
if (!decoder_) {
- decoder_ = NVDECCache::getCache(device_.index()).getDecoder(videoFormat);
+ decoder_ = NVDECCache::getCache(device_).getDecoder(videoFormat);
if (!decoder_) {
// TODONVDEC P2: consider re-configuring an existing decoder instead of
- // re-creating one. See docs, see DALI.
+ // re-creating one. See docs, see DALI. Re-configuration doesn't seem to
+ // be enabled in DALI by default.
decoder_ = createDecoder(videoFormat);
}
@@ -382,6 +420,10 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) {
// Moral equivalent of avcodec_send_packet(). Here, we pass the AVPacket down to
// the NVCUVID parser.
int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
+ if (cpuFallback_) {
+ return cpuFallback_->sendPacket(packet);
+ }
+
TORCH_CHECK(
packet.get() && packet->data && packet->size > 0,
"sendPacket received an empty packet, this is unexpected, please report.");
@@ -405,6 +447,10 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) {
}
int BetaCudaDeviceInterface::sendEOFPacket() {
+ if (cpuFallback_) {
+ return cpuFallback_->sendEOFPacket();
+ }
+
CUVIDSOURCEDATAPACKET cuvidPacket = {};
cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM;
eofSent_ = true;
@@ -466,6 +512,10 @@ int BetaCudaDeviceInterface::frameReadyInDisplayOrder(
// Moral equivalent of avcodec_receive_frame().
int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
+ if (cpuFallback_) {
+ return cpuFallback_->receiveFrame(avFrame);
+ }
+
if (readyFrames_.empty()) {
// No frame found, instruct caller to try again later after sending more
// packets, or to stop if EOF was already sent.
@@ -480,8 +530,7 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) {
procParams.top_field_first = dispInfo.top_field_first;
procParams.unpaired_field = dispInfo.repeat_first_field < 0;
// We set the NVDEC stream to the current stream. It will be waited upon by
- // the NPP stream before any color conversion. Currently, that syncing logic
- // is in the default interface.
+ // the NPP stream before any color conversion.
// Re types: we get a cudaStream_t from PyTorch but it's interchangeable with
// CUstream
procParams.output_stream = reinterpret_cast(
@@ -601,6 +650,11 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
}
void BetaCudaDeviceInterface::flush() {
+ if (cpuFallback_) {
+ cpuFallback_->flush();
+ return;
+ }
+
// The NVCUVID docs mention that after seeking, i.e. when flush() is called,
// we should send a packet with the CUVID_PKT_DISCONTINUITY flag. The docs
// don't say whether this should be an empty packet, or whether it should be a
@@ -618,8 +672,23 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional preAllocatedOutputTensor) {
- // TODONVDEC P2: we may need to handle 10bit videos the same way the default
- // interface does it with maybeConvertAVFrameToNV12OrRGB24().
+ if (cpuFallback_) {
+ // CPU decoded frame - need to do CPU color conversion then transfer to GPU
+ FrameOutput cpuFrameOutput;
+ cpuFallback_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
+
+ // Transfer CPU frame to GPU
+ if (preAllocatedOutputTensor.has_value()) {
+ preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data);
+ frameOutput.data = preAllocatedOutputTensor.value();
+ } else {
+ frameOutput.data = cpuFrameOutput.data.to(device_);
+ }
+ return;
+ }
+
+ // TODONVDEC P2: we may need to handle 10bit videos the same way the FFmpeg
+ // CUDA interface does it with maybeConvertAVFrameToNV12OrRGB24().
TORCH_CHECK(
avFrame->format == AV_PIX_FMT_CUDA,
"Expected CUDA format frame from BETA CUDA interface");
@@ -633,4 +702,17 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor);
}
+std::string BetaCudaDeviceInterface::getDetails() {
+ std::string details = "Beta CUDA Device Interface.";
+ if (cpuFallback_) {
+ details += " Using CPU fallback.";
+ if (!nvcuvidAvailable_) {
+ details += " NVCUVID not available!";
+ }
+ } else {
+ details += " Using NVDEC.";
+ }
+ return details;
+}
+
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h
index 0bf9951d6..29511df50 100644
--- a/src/torchcodec/_core/BetaCudaDeviceInterface.h
+++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h
@@ -40,7 +40,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
void initialize(
const AVStream* avStream,
- const UniqueDecodingAVFormatContext& avFormatCtx) override;
+ const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) override;
void convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
@@ -48,10 +49,6 @@ class BetaCudaDeviceInterface : public DeviceInterface {
std::optional preAllocatedOutputTensor =
std::nullopt) override;
- bool canDecodePacketDirectly() const override {
- return true;
- }
-
int sendPacket(ReferenceAVPacket& packet) override;
int sendEOFPacket() override;
int receiveFrame(UniqueAVFrame& avFrame) override;
@@ -62,6 +59,8 @@ class BetaCudaDeviceInterface : public DeviceInterface {
int frameReadyForDecoding(CUVIDPICPARAMS* picParams);
int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo);
+ std::string getDetails() override;
+
private:
int sendCuvidPacket(CUVIDSOURCEDATAPACKET& cuvidPacket);
@@ -97,6 +96,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
// NPP context for color conversion
UniqueNppContext nppCtx_;
+
+ std::unique_ptr<DeviceInterface> cpuFallback_;
+ bool nvcuvidAvailable_ = false;
};
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 75d1b036c..6b4ccb5d4 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -99,7 +99,7 @@ function(make_torchcodec_libraries
)
if(ENABLE_CUDA)
- list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp)
+ list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp)
endif()
set(core_library_dependencies
@@ -108,27 +108,9 @@ function(make_torchcodec_libraries
)
if(ENABLE_CUDA)
- # Try to find NVCUVID. Try the normal way first. This should work locally.
- find_library(NVCUVID_LIBRARY NAMES nvcuvid)
- # If not found, try with version suffix, or hardcoded path. Appears
- # to be necessary on the CI.
- if(NOT NVCUVID_LIBRARY)
- find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib)
- endif()
- if(NOT NVCUVID_LIBRARY)
- set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1")
- endif()
-
- if(NVCUVID_LIBRARY)
- message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}")
- else()
- message(FATAL_ERROR "Could not find NVCUVID library")
- endif()
-
list(APPEND core_library_dependencies
${CUDA_nppi_LIBRARY}
${CUDA_nppicc_LIBRARY}
- ${NVCUVID_LIBRARY}
)
endif()
diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp
index 4f3664031..4532e3c76 100644
--- a/src/torchcodec/_core/CUDACommon.cpp
+++ b/src/torchcodec/_core/CUDACommon.cpp
@@ -5,14 +5,12 @@
// LICENSE file in the root directory of this source tree.
#include "src/torchcodec/_core/CUDACommon.h"
+#include "src/torchcodec/_core/Cache.h" // for PerGpuCache
namespace facebook::torchcodec {
namespace {
-// Pytorch can only handle up to 128 GPUs.
-// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
-const int MAX_CUDA_GPUS = 128;
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
// Set to a positive number to have a cache of that size.
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
@@ -249,7 +247,7 @@ torch::Tensor convertNV12FrameToRGB(
}
UniqueNppContext getNppStreamContext(const torch::Device& device) {
- torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+ int deviceIndex = getDeviceIndex(device);
UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device);
if (nppCtx) {
@@ -266,13 +264,13 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) {
nppCtx = std::make_unique<NppStreamContext>();
cudaDeviceProp prop{};
- cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex);
+ cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex);
TORCH_CHECK(
err == cudaSuccess,
"cudaGetDeviceProperties failed: ",
cudaGetErrorString(err));
- nppCtx->nCudaDeviceId = nonNegativeDeviceIndex;
+ nppCtx->nCudaDeviceId = deviceIndex;
nppCtx->nMultiProcessorCount = prop.multiProcessorCount;
nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor;
nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock;
@@ -312,4 +310,21 @@ void validatePreAllocatedTensorShape(
}
}
+int getDeviceIndex(const torch::Device& device) {
+ // PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA
+ // libraries use int. So we use int, too.
int deviceIndex = static_cast<int>(device.index());
+ TORCH_CHECK(
+ deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS,
+ "Invalid device index = ",
+ deviceIndex);
+
+ if (deviceIndex == -1) {
+ TORCH_CHECK(
+ cudaGetDevice(&deviceIndex) == cudaSuccess,
+ "Failed to get current CUDA device.");
+ }
+ return deviceIndex;
+}
+
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h
index b935cd4bf..588f60e49 100644
--- a/src/torchcodec/_core/CUDACommon.h
+++ b/src/torchcodec/_core/CUDACommon.h
@@ -11,7 +11,6 @@
#include
#include
-#include "src/torchcodec/_core/Cache.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/Frame.h"
@@ -22,6 +21,10 @@ extern "C" {
namespace facebook::torchcodec {
+// Pytorch can only handle up to 128 GPUs.
+// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
+constexpr int MAX_CUDA_GPUS = 128;
+
void initializeCudaContextWithPytorch(const torch::Device& device);
// Unique pointer type for NPP stream context
@@ -43,4 +46,6 @@ void validatePreAllocatedTensorShape(
const std::optional& preAllocatedOutputTensor,
const UniqueAVFrame& avFrame);
+int getDeviceIndex(const torch::Device& device);
+
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/Cache.h b/src/torchcodec/_core/Cache.h
index 7b088a145..b2c93e8ea 100644
--- a/src/torchcodec/_core/Cache.h
+++ b/src/torchcodec/_core/Cache.h
@@ -95,30 +95,16 @@ class PerGpuCache {
std::vector>> cache_;
};
-// Note: this function is inline for convenience, not performance. Because the
-// rest of this file is template functions, they must all be defined in this
-// header. This function is not a template function, and should, in principle,
-// be defined in a .cpp file to preserve the One Definition Rule. That's
-// annoying for such a small amount of code, so we just inline it. If this file
-// grows, and there are more such functions, we should break them out into a
-// .cpp file.
-inline torch::DeviceIndex getNonNegativeDeviceIndex(
- const torch::Device& device) {
- torch::DeviceIndex deviceIndex = device.index();
- // For single GPU machines libtorch returns -1 for the device index. So for
- // that case we set the device index to 0. That's used in per-gpu cache
- // implementation and during initialization of CUDA and FFmpeg contexts
- // which require non negative indices.
- deviceIndex = std::max(deviceIndex, 0);
- TORCH_CHECK(deviceIndex >= 0, "Device index out of range");
- return deviceIndex;
-}
+// Forward declaration of getDeviceIndex, which is declared in CUDACommon.h and
+// defined in CUDACommon.cpp. Declaring it here avoids a circular dependency:
+// CUDACommon.cpp also needs to include Cache.h.
+int getDeviceIndex(const torch::Device& device);
template
bool PerGpuCache::addIfCacheHasCapacity(
const torch::Device& device,
element_type&& obj) {
- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
+ int deviceIndex = getDeviceIndex(device);
TORCH_CHECK(
static_cast(deviceIndex) < cache_.size(),
"Device index out of range");
@@ -128,7 +114,7 @@ bool PerGpuCache::addIfCacheHasCapacity(
template
typename PerGpuCache::element_type PerGpuCache::get(
const torch::Device& device) {
- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device);
+ int deviceIndex = getDeviceIndex(device);
TORCH_CHECK(
static_cast(deviceIndex) < cache_.size(),
"Device index out of range");
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index e6b96e3e4..5aa20b09e 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -48,8 +48,10 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
void CpuDeviceInterface::initialize(
const AVStream* avStream,
- [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {
+ [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) {
TORCH_CHECK(avStream != nullptr, "avStream is null");
+ codecContext_ = codecContext;
timeBase_ = avStream->time_base;
}
@@ -344,4 +346,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
}
+std::string CpuDeviceInterface::getDetails() {
+ return std::string("CPU Device Interface.");
+}
+
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 399b0c6be..3f6f7c962 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -25,7 +25,8 @@ class CpuDeviceInterface : public DeviceInterface {
virtual void initialize(
const AVStream* avStream,
- const UniqueDecodingAVFormatContext& avFormatCtx) override;
+ const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) override;
virtual void initializeVideo(
const VideoStreamOptions& videoStreamOptions,
@@ -38,6 +39,8 @@ class CpuDeviceInterface : public DeviceInterface {
std::optional preAllocatedOutputTensor =
std::nullopt) override;
+ std::string getDetails() override;
+
private:
int convertAVFrameToTensorUsingSwScale(
const UniqueAVFrame& avFrame,
diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index aea2b2d9a..be45050e6 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -32,9 +32,6 @@ static bool g_cuda = registerDeviceInterface(
// from
// the cache. If the cache is empty we create a new cuda context.
-// Pytorch can only handle up to 128 GPUs.
-// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44
-const int MAX_CUDA_GPUS = 128;
// Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching.
// Set to a positive number to have a cache of that size.
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1;
@@ -54,7 +51,7 @@ int getFlagsAVHardwareDeviceContextCreate() {
UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
- torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device);
+ int deviceIndex = getDeviceIndex(device);
UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device);
if (hardwareDeviceCtx) {
@@ -63,14 +60,12 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) {
// Create hardware device context
c10::cuda::CUDAGuard deviceGuard(device);
- // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
- // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
- // So we ensure the deviceIndex is not negative.
// We set the device because we may be called from a different thread than
// the one that initialized the cuda context.
- cudaSetDevice(nonNegativeDeviceIndex);
+ TORCH_CHECK(
+ cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device");
AVBufferRef* hardwareDeviceCtxRaw = nullptr;
- std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
+ std::string deviceOrdinal = std::to_string(deviceIndex);
int err = av_hwdevice_ctx_create(
&hardwareDeviceCtxRaw,
@@ -117,15 +112,17 @@ CudaDeviceInterface::~CudaDeviceInterface() {
void CudaDeviceInterface::initialize(
const AVStream* avStream,
- const UniqueDecodingAVFormatContext& avFormatCtx) {
+ const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) {
TORCH_CHECK(avStream != nullptr, "avStream is null");
+ codecContext_ = codecContext;
timeBase_ = avStream->time_base;
// TODO: Ideally, we should keep all interface implementations independent.
cpuInterface_ = createDeviceInterface(torch::kCPU);
TORCH_CHECK(
cpuInterface_ != nullptr, "Failed to create CPU device interface");
- cpuInterface_->initialize(avStream, avFormatCtx);
+ cpuInterface_->initialize(avStream, avFormatCtx, codecContext);
cpuInterface_->initializeVideo(
VideoStreamOptions(),
{},
@@ -287,9 +284,12 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
frameOutput.data = cpuFrameOutput.data.to(device_);
}
+ usingCPUFallback_ = true;
return;
}
+ usingCPUFallback_ = false;
+
// Above we checked that the AVFrame was on GPU, but that's not enough, we
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
// because this is what the NPP color conversion routines expect. This SHOULD
@@ -354,4 +354,12 @@ std::optional CudaDeviceInterface::findCodec(
return std::nullopt;
}
+std::string CudaDeviceInterface::getDetails() {
+ // Note: for this interface specifically, the fallback is only known after a
+ // frame has been decoded, not before: that's when FFmpeg decides to fall
+ // back, so we can't know earlier.
+ return std::string("FFmpeg CUDA Device Interface. Using ") +
+ (usingCPUFallback_ ? "CPU fallback." : "NVDEC.");
+}
+
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h
index 1a8f184ec..9f171ee3c 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.h
+++ b/src/torchcodec/_core/CudaDeviceInterface.h
@@ -22,7 +22,8 @@ class CudaDeviceInterface : public DeviceInterface {
void initialize(
const AVStream* avStream,
- const UniqueDecodingAVFormatContext& avFormatCtx) override;
+ const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) override;
void initializeVideo(
const VideoStreamOptions& videoStreamOptions,
@@ -39,6 +40,8 @@ class CudaDeviceInterface : public DeviceInterface {
std::optional preAllocatedOutputTensor =
std::nullopt) override;
+ std::string getDetails() override;
+
private:
// Our CUDA decoding code assumes NV12 format. In order to handle other
// kinds of input, we need to convert them to NV12. Our current implementation
@@ -59,6 +62,8 @@ class CudaDeviceInterface : public DeviceInterface {
// maybeConvertAVFrameToNV12().
std::unique_ptr nv12ConversionContext_;
std::unique_ptr nv12Conversion_;
+
+ bool usingCPUFallback_ = false;
};
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h
index cac29e838..773317e83 100644
--- a/src/torchcodec/_core/DeviceInterface.h
+++ b/src/torchcodec/_core/DeviceInterface.h
@@ -21,7 +21,7 @@ namespace facebook::torchcodec {
// Key for device interface registration with device type + variant support
struct DeviceInterfaceKey {
torch::DeviceType deviceType;
- std::string_view variant = "default"; // e.g., "default", "beta", etc.
+ std::string_view variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc.
bool operator<(const DeviceInterfaceKey& other) const {
if (deviceType != other.deviceType) {
@@ -54,7 +54,8 @@ class DeviceInterface {
// Initialize the device with parameters generic to all kinds of decoding.
virtual void initialize(
const AVStream* avStream,
- const UniqueDecodingAVFormatContext& avFormatCtx) = 0;
+ const UniqueDecodingAVFormatContext& avFormatCtx,
+ const SharedAVCodecContext& codecContext) = 0;
// Initialize the device with parameters specific to video decoding. There is
// a default empty implementation.
@@ -80,52 +81,51 @@ class DeviceInterface {
// Extension points for custom decoding paths
// ------------------------------------------
- // Override to return true if this device interface can decode packets
- // directly. This means that the following two member functions can both
- // be called:
- //
- // 1. sendPacket()
- // 2. receiveFrame()
- virtual bool canDecodePacketDirectly() const {
- return false;
- }
-
- // Moral equivalent of avcodec_send_packet()
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or
// other AVERROR on failure
- virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) {
+ // Default implementation uses FFmpeg directly
+ virtual int sendPacket(ReferenceAVPacket& avPacket) {
TORCH_CHECK(
- false,
- "Send/receive packet decoding not implemented for this device interface");
- return AVERROR(ENOSYS);
+ codecContext_ != nullptr,
+ "Codec context not available for default packet sending");
+ return avcodec_send_packet(codecContext_.get(), avPacket.get());
}
// Send an EOF packet to flush the decoder
// Returns AVSUCCESS on success, or other AVERROR on failure
+ // Default implementation uses FFmpeg directly
virtual int sendEOFPacket() {
TORCH_CHECK(
- false, "Send EOF packet not implemented for this device interface");
- return AVERROR(ENOSYS);
+ codecContext_ != nullptr,
+ "Codec context not available for default EOF packet sending");
+ return avcodec_send_packet(codecContext_.get(), nullptr);
}
- // Moral equivalent of avcodec_receive_frame()
// Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready,
// AVERROR_EOF if end of stream, or other AVERROR on failure
- virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) {
+ // Default implementation uses FFmpeg directly
+ virtual int receiveFrame(UniqueAVFrame& avFrame) {
TORCH_CHECK(
- false,
- "Send/receive packet decoding not implemented for this device interface");
- return AVERROR(ENOSYS);
+ codecContext_ != nullptr,
+ "Codec context not available for default frame receiving");
+ return avcodec_receive_frame(codecContext_.get(), avFrame.get());
}
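+ // As a sketch, these defaults are meant to be driven by FFmpeg's usual
+ // send/receive pattern (simplified, hypothetical caller code):
+ //
+ // interface->sendPacket(packet); // or sendEOFPacket() when draining
+ // int status = interface->receiveFrame(avFrame);
+ // // AVERROR(EAGAIN): send more packets, then call receiveFrame() again.
+ // // AVERROR_EOF: the decoder is fully drained.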
// Flush remaining frames from decoder
virtual void flush() {
- // Default implementation is no-op for standard decoders
- // Custom decoders can override this method
+ TORCH_CHECK(
+ codecContext_ != nullptr,
+ "Codec context not available for default flushing");
+ avcodec_flush_buffers(codecContext_.get());
+ }
+
+ virtual std::string getDetails() {
+ return "";
}
protected:
torch::Device device_;
+ SharedAVCodecContext codecContext_;
};
using CreateDeviceInterfaceFn =
@@ -141,7 +141,7 @@ void validateDeviceInterface(
std::unique_ptr<DeviceInterface> createDeviceInterface(
const torch::Device& device,
- const std::string_view variant = "default");
+ const std::string_view variant = "ffmpeg");
torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame);
diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 14ef1cb94..1d9c2c089 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -4,10 +4,6 @@
#include "src/torchcodec/_core/Encoder.h"
#include "torch/types.h"
-extern "C" {
-#include
-}
-
namespace facebook::torchcodec {
namespace {
@@ -542,10 +538,17 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
} // namespace
VideoEncoder::~VideoEncoder() {
+ // TODO-VideoEncoder: Unify destructor with ~AudioEncoder()
if (avFormatContext_ && avFormatContext_->pb) {
- avio_flush(avFormatContext_->pb);
- avio_close(avFormatContext_->pb);
- avFormatContext_->pb = nullptr;
+ if (avFormatContext_->pb->error == 0) {
+ avio_flush(avFormatContext_->pb);
+ }
+ if (!avioContextHolder_) {
+ if (avFormatContext_->pb->error == 0) {
+ avio_close(avFormatContext_->pb);
+ }
+ avFormatContext_->pb = nullptr;
+ }
}
}
@@ -581,6 +584,36 @@ VideoEncoder::VideoEncoder(
initializeEncoder(videoStreamOptions);
}
+VideoEncoder::VideoEncoder(
+ const torch::Tensor& frames,
+ int frameRate,
+ std::string_view formatName,
+ std::unique_ptr<AVIOContextHolder> avioContextHolder,
+ const VideoStreamOptions& videoStreamOptions)
+ : frames_(validateFrames(frames)),
+ inFrameRate_(frameRate),
+ avioContextHolder_(std::move(avioContextHolder)) {
+ setFFmpegLogLevel();
+ // Map mkv -> matroska when used as format name
+ formatName = (formatName == "mkv") ? "matroska" : formatName;
+ AVFormatContext* avFormatContext = nullptr;
+ int status = avformat_alloc_output_context2(
+ &avFormatContext, nullptr, formatName.data(), nullptr);
+
+ TORCH_CHECK(
+ avFormatContext != nullptr,
+ "Couldn't allocate AVFormatContext. ",
+ "Check the desired format? Got format=",
+ formatName,
+ ". ",
+ getFFMPEGErrorStringFromErrorCode(status));
+ avFormatContext_.reset(avFormatContext);
+
+ avFormatContext_->pb = avioContextHolder_->getAVIOContext();
+
+ initializeEncoder(videoStreamOptions);
+}
+
void VideoEncoder::initializeEncoder(
const VideoStreamOptions& videoStreamOptions) {
const AVCodec* avCodec =
@@ -751,6 +784,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
return avFrame;
}
+torch::Tensor VideoEncoder::encodeToTensor() {
+ TORCH_CHECK(
+ avioContextHolder_ != nullptr,
+ "Cannot encode to tensor, avio tensor context doesn't exist.");
+ encode();
+ auto avioToTensorContext =
+ dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
+ TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
+ return avioToTensorContext->getOutputTensor();
+}
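+// Hypothetical usage sketch (names are illustrative, not code from this file):
+// VideoEncoder encoder(
+// frames, /*frameRate=*/30, "mp4",
+// std::make_unique<AVIOToTensorContext>(), videoStreamOptions);
+// torch::Tensor encoded = encoder.encodeToTensor();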
+
void VideoEncoder::encodeFrame(
AutoAVPacket& autoAVPacket,
const UniqueAVFrame& avFrame) {
diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h
index 62d30a624..168591616 100644
--- a/src/torchcodec/_core/Encoder.h
+++ b/src/torchcodec/_core/Encoder.h
@@ -141,8 +141,17 @@ class VideoEncoder {
std::string_view fileName,
const VideoStreamOptions& videoStreamOptions);
+ VideoEncoder(
+ const torch::Tensor& frames,
+ int frameRate,
+ std::string_view formatName,
+ std::unique_ptr<AVIOContextHolder> avioContextHolder,
+ const VideoStreamOptions& videoStreamOptions);
+
void encode();
+ torch::Tensor encodeToTensor();
+
private:
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
UniqueAVFrame convertTensorToAVFrame(
@@ -153,7 +162,7 @@ class VideoEncoder {
UniqueEncodingAVFormatContext avFormatContext_;
UniqueAVCodecContext avCodecContext_;
- AVStream* avStream_;
+ AVStream* avStream_ = nullptr;
UniqueSwsContext swsContext_;
const torch::Tensor frames_;
@@ -167,6 +176,8 @@ class VideoEncoder {
int outHeight_ = -1;
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
+ std::unique_ptr<AVIOContextHolder> avioContextHolder_;
+
bool encodeWasCalled_ = false;
};
diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp
index 0570f06cf..97ff082e1 100644
--- a/src/torchcodec/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/_core/FFMPEGCommon.cpp
@@ -149,7 +149,7 @@ int getNumChannels(const UniqueAVFrame& avFrame) {
#endif
}
-int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
+int getNumChannels(const SharedAVCodecContext& avCodecContext) {
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
return avCodecContext->ch_layout.nb_channels;
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index 19cddcc37..337616ddc 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -71,6 +71,14 @@ using UniqueEncodingAVFormatContext = std::unique_ptr<
using UniqueAVCodecContext = std::unique_ptr<
AVCodecContext,
Deleterp<AVCodecContext, void, avcodec_free_context>>;
+using SharedAVCodecContext = std::shared_ptr<AVCodecContext>;
+
+// create SharedAVCodecContext with custom deleter
+inline SharedAVCodecContext makeSharedAVCodecContext(AVCodecContext* ctx) {
+ return SharedAVCodecContext(
+ ctx, Deleterp{});
+}
+
using UniqueAVFrame =
std::unique_ptr<AVFrame, Deleterp<AVFrame, void, av_frame_free>>;
using UniqueAVFilterGraph = std::unique_ptr<
@@ -171,7 +179,7 @@ const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
int getNumChannels(const UniqueAVFrame& avFrame);
-int getNumChannels(const UniqueAVCodecContext& avCodecContext);
+int getNumChannels(const SharedAVCodecContext& avCodecContext);
void setDefaultChannelLayout(
UniqueAVCodecContext& avCodecContext,
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
index afc44d96d..605b814a8 100644
--- a/src/torchcodec/_core/FilterGraph.cpp
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -130,7 +130,8 @@ FilterGraph::FilterGraph(
TORCH_CHECK(
status >= 0,
"Failed to configure filter graph: ",
- getFFMPEGErrorStringFromErrorCode(status));
+ getFFMPEGErrorStringFromErrorCode(status),
+ ", provided filters: " + filtersContext.filtergraphStr);
}
UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
diff --git a/src/torchcodec/_core/Frame.cpp b/src/torchcodec/_core/Frame.cpp
index 9fa87a1cb..62fb46c65 100644
--- a/src/torchcodec/_core/Frame.cpp
+++ b/src/torchcodec/_core/Frame.cpp
@@ -8,6 +8,11 @@
namespace facebook::torchcodec {
+FrameDims::FrameDims(int height, int width) : height(height), width(width) {
+ TORCH_CHECK(height > 0, "FrameDims.height must be > 0, got: ", height);
+ TORCH_CHECK(width > 0, "FrameDims.width must be > 0, got: ", width);
+}
+
FrameBatchOutput::FrameBatchOutput(
int64_t numFrames,
const FrameDims& outputDims,
diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h
index 4b27d5bdd..67e4d2b79 100644
--- a/src/torchcodec/_core/Frame.h
+++ b/src/torchcodec/_core/Frame.h
@@ -19,7 +19,7 @@ struct FrameDims {
FrameDims() = default;
- FrameDims(int h, int w) : height(h), width(w) {}
+ FrameDims(int h, int w);
};
// All public video decoding entry points return either a FrameOutput or a
diff --git a/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp b/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp
new file mode 100644
index 000000000..2bb501fc2
--- /dev/null
+++ b/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp
@@ -0,0 +1,320 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#ifdef FBCODE_CAFFE2
+// No need to do anything on fbcode. NVCUVID is available there, we can take a
+// hard dependency on it.
+// The FBCODE_CAFFE2 macro is defined in the upstream fbcode build of torch, so
+// we can rely on it, that's what torch does too.
+
+namespace facebook::torchcodec {
+bool loadNVCUVIDLibrary() {
+ return true;
+}
+} // namespace facebook::torchcodec
+#else
+
+#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h"
+
+#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
+#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
+
+#include <torch/types.h>
+#include <mutex>
+#include <string>
+
+#if defined(WIN64) || defined(_WIN64)
+#include <windows.h>
+typedef HMODULE tHandle;
+#else
+#include <dlfcn.h>
+typedef void* tHandle;
+#endif
+
+namespace facebook::torchcodec {
+
+/* clang-format off */
+// This file defines the logic to load the NVCUVID library **at runtime**,
+// along with the corresponding NVCUVID functions that we'll need.
+//
+// We do this because we *do not want* to link (statically or dynamically)
+// against libnvcuvid.so: it is not always available on the user's machine! If we
+// were to link against libnvcuvid.so, that would mean that our
+// libtorchcodec_coreN.so would try to look for it when loaded at import time.
+// And if it's not on the user's machine, that causes `import torchcodec` to
+// fail. Source: that's what we did, and we got user reports.
+//
+// So, we don't link against libnvcuvid.so. But we still want to call its
+// functions. So here's how it's done, we'll use cuvidCreateVideoParser as an
+// example, but it works the same for all. We are largely following the
+// instructions from the NVCUVID docs:
+// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#dynamic-loading-nvidia-components
+//
+// This:
+// typedef CUresult CUDAAPI tcuvidCreateVideoParser(CUvideoparser*, CUVIDPARSERPARAMS*);
+// defines tcuvidCreateVideoParser, which is the *type* of a *function*.
+// We then define a pointer to a function of that type just below with:
+// static tcuvidCreateVideoParser* dl_cuvidCreateVideoParser = nullptr;
+// "dl" is for "dynamically loaded". For now dl_cuvidCreateVideoParser is
+// nullptr, but later it will be a proper function [pointer] that can be called
+// with dl_cuvidCreateVideoParser(...);
+//
+// For that to happen we need to call loadNVCUVIDLibrary(): in there, we first
+// dlopen(libnvcuvid.so) which loads the .so somewhere in memory. Then we call
+// dlsym(...), which binds dl_cuvidCreateVideoParser to its actual address: it
+// literally sets the value of the dl_cuvidCreateVideoParser pointer to the
+// address of the actual code section. If all went well, by now, we can safely
+// call dl_cuvidCreateVideoParser(...);
+// All of that happens at runtime *after* import time, when the first instance
+// of the Beta CUDA interface is created, i.e. only when the user explicitly
+// requests it.
+//
+// At the bottom of this file we have an `extern "C"` section with function
+// definitions like:
+//
+// CUresult CUDAAPI cuvidCreateVideoParser(
+// CUvideoparser* videoParser,
+// CUVIDPARSERPARAMS* parserParams) {...}
+//
+// These are the actual functions that are compiled against and called by the
+// Beta CUDA interface code. Crucially, these functions' signatures exactly match
+// the NVCUVID functions (as defined in cuviddec.h). Inside of
+// cuvidCreateVideoParser(...) we simply call the dl_cuvidCreateVideoParser
+// function [pointer] that we dynamically loaded earlier.
+//
+// At runtime, within the Beta CUDA interface code we have a fallback mechanism
+// to switch back to the CPU backend if any of the NVCUVID functions are not
+// available, or if libnvcuvid.so itself couldn't be found. This is what FFmpeg
+// does too.
+
+
+// Function pointer types
+typedef CUresult CUDAAPI tcuvidCreateVideoParser(CUvideoparser*, CUVIDPARSERPARAMS*);
+typedef CUresult CUDAAPI tcuvidParseVideoData(CUvideoparser, CUVIDSOURCEDATAPACKET*);
+typedef CUresult CUDAAPI tcuvidDestroyVideoParser(CUvideoparser);
+typedef CUresult CUDAAPI tcuvidGetDecoderCaps(CUVIDDECODECAPS*);
+typedef CUresult CUDAAPI tcuvidCreateDecoder(CUvideodecoder*, CUVIDDECODECREATEINFO*);
+typedef CUresult CUDAAPI tcuvidDestroyDecoder(CUvideodecoder);
+typedef CUresult CUDAAPI tcuvidDecodePicture(CUvideodecoder, CUVIDPICPARAMS*);
+typedef CUresult CUDAAPI tcuvidMapVideoFrame(CUvideodecoder, int, unsigned int*, unsigned int*, CUVIDPROCPARAMS*);
+typedef CUresult CUDAAPI tcuvidUnmapVideoFrame(CUvideodecoder, unsigned int);
+typedef CUresult CUDAAPI tcuvidMapVideoFrame64(CUvideodecoder, int, unsigned long long*, unsigned int*, CUVIDPROCPARAMS*);
+typedef CUresult CUDAAPI tcuvidUnmapVideoFrame64(CUvideodecoder, unsigned long long);
+/* clang-format on */
+
+// Global function pointers - will be dynamically loaded
+static tcuvidCreateVideoParser* dl_cuvidCreateVideoParser = nullptr;
+static tcuvidParseVideoData* dl_cuvidParseVideoData = nullptr;
+static tcuvidDestroyVideoParser* dl_cuvidDestroyVideoParser = nullptr;
+static tcuvidGetDecoderCaps* dl_cuvidGetDecoderCaps = nullptr;
+static tcuvidCreateDecoder* dl_cuvidCreateDecoder = nullptr;
+static tcuvidDestroyDecoder* dl_cuvidDestroyDecoder = nullptr;
+static tcuvidDecodePicture* dl_cuvidDecodePicture = nullptr;
+static tcuvidMapVideoFrame* dl_cuvidMapVideoFrame = nullptr;
+static tcuvidUnmapVideoFrame* dl_cuvidUnmapVideoFrame = nullptr;
+static tcuvidMapVideoFrame64* dl_cuvidMapVideoFrame64 = nullptr;
+static tcuvidUnmapVideoFrame64* dl_cuvidUnmapVideoFrame64 = nullptr;
+
+static tHandle g_nvcuvid_handle = nullptr;
+static std::mutex g_nvcuvid_mutex;
+
+bool isLoaded() {
+ return (
+ g_nvcuvid_handle && dl_cuvidCreateVideoParser && dl_cuvidParseVideoData &&
+ dl_cuvidDestroyVideoParser && dl_cuvidGetDecoderCaps &&
+ dl_cuvidCreateDecoder && dl_cuvidDestroyDecoder &&
+ dl_cuvidDecodePicture && dl_cuvidMapVideoFrame &&
+ dl_cuvidUnmapVideoFrame && dl_cuvidMapVideoFrame64 &&
+ dl_cuvidUnmapVideoFrame64);
+}
+
+template <typename T>
+T* bindFunction(const char* functionName) {
+#if defined(WIN64) || defined(_WIN64)
+  return reinterpret_cast<T*>(GetProcAddress(g_nvcuvid_handle, functionName));
+#else
+  return reinterpret_cast<T*>(dlsym(g_nvcuvid_handle, functionName));
+#endif
+}
+
+bool _loadLibrary() {
+ // Helper that just calls dlopen or equivalent on Windows. In a separate
+  // function because of the #ifdef ugliness.
+#if defined(WIN64) || defined(_WIN64)
+#ifdef UNICODE
+ static LPCWSTR nvcuvidDll = L"nvcuvid.dll";
+#else
+ static LPCSTR nvcuvidDll = "nvcuvid.dll";
+#endif
+ g_nvcuvid_handle = LoadLibrary(nvcuvidDll);
+ if (g_nvcuvid_handle == nullptr) {
+ return false;
+ }
+#else
+ g_nvcuvid_handle = dlopen("libnvcuvid.so", RTLD_NOW);
+ if (g_nvcuvid_handle == nullptr) {
+ g_nvcuvid_handle = dlopen("libnvcuvid.so.1", RTLD_NOW);
+ }
+ if (g_nvcuvid_handle == nullptr) {
+ return false;
+ }
+#endif
+
+ return true;
+}
+
+bool loadNVCUVIDLibrary() {
+ // Loads NVCUVID library and all required function pointers.
+ // Returns true on success, false on failure.
+  std::lock_guard<std::mutex> lock(g_nvcuvid_mutex);
+
+ if (isLoaded()) {
+ return true;
+ }
+
+ if (!_loadLibrary()) {
+ return false;
+ }
+
+ // Load all function pointers. They'll be set to nullptr if not found.
+  dl_cuvidCreateVideoParser =
+      bindFunction<tcuvidCreateVideoParser>("cuvidCreateVideoParser");
+  dl_cuvidParseVideoData =
+      bindFunction<tcuvidParseVideoData>("cuvidParseVideoData");
+  dl_cuvidDestroyVideoParser =
+      bindFunction<tcuvidDestroyVideoParser>("cuvidDestroyVideoParser");
+  dl_cuvidGetDecoderCaps =
+      bindFunction<tcuvidGetDecoderCaps>("cuvidGetDecoderCaps");
+  dl_cuvidCreateDecoder =
+      bindFunction<tcuvidCreateDecoder>("cuvidCreateDecoder");
+  dl_cuvidDestroyDecoder =
+      bindFunction<tcuvidDestroyDecoder>("cuvidDestroyDecoder");
+  dl_cuvidDecodePicture =
+      bindFunction<tcuvidDecodePicture>("cuvidDecodePicture");
+  dl_cuvidMapVideoFrame =
+      bindFunction<tcuvidMapVideoFrame>("cuvidMapVideoFrame");
+  dl_cuvidUnmapVideoFrame =
+      bindFunction<tcuvidUnmapVideoFrame>("cuvidUnmapVideoFrame");
+  dl_cuvidMapVideoFrame64 =
+      bindFunction<tcuvidMapVideoFrame64>("cuvidMapVideoFrame64");
+  dl_cuvidUnmapVideoFrame64 =
+      bindFunction<tcuvidUnmapVideoFrame64>("cuvidUnmapVideoFrame64");
+
+ return isLoaded();
+}
+
+} // namespace facebook::torchcodec
+
+extern "C" {
+
+CUresult CUDAAPI cuvidCreateVideoParser(
+ CUvideoparser* videoParser,
+ CUVIDPARSERPARAMS* parserParams) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidCreateVideoParser,
+ "cuvidCreateVideoParser called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidCreateVideoParser(
+ videoParser, parserParams);
+}
+
+CUresult CUDAAPI cuvidParseVideoData(
+ CUvideoparser videoParser,
+ CUVIDSOURCEDATAPACKET* cuvidPacket) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidParseVideoData,
+ "cuvidParseVideoData called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidParseVideoData(videoParser, cuvidPacket);
+}
+
+CUresult CUDAAPI cuvidDestroyVideoParser(CUvideoparser videoParser) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidDestroyVideoParser,
+ "cuvidDestroyVideoParser called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidDestroyVideoParser(videoParser);
+}
+
+CUresult CUDAAPI cuvidGetDecoderCaps(CUVIDDECODECAPS* caps) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidGetDecoderCaps,
+ "cuvidGetDecoderCaps called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidGetDecoderCaps(caps);
+}
+
+CUresult CUDAAPI cuvidCreateDecoder(
+ CUvideodecoder* decoder,
+ CUVIDDECODECREATEINFO* decoderParams) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidCreateDecoder,
+ "cuvidCreateDecoder called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidCreateDecoder(decoder, decoderParams);
+}
+
+CUresult CUDAAPI cuvidDestroyDecoder(CUvideodecoder decoder) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidDestroyDecoder,
+ "cuvidDestroyDecoder called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidDestroyDecoder(decoder);
+}
+
+CUresult CUDAAPI
+cuvidDecodePicture(CUvideodecoder decoder, CUVIDPICPARAMS* picParams) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidDecodePicture,
+ "cuvidDecodePicture called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidDecodePicture(decoder, picParams);
+}
+
+#if !defined(__CUVID_DEVPTR64) || defined(__CUVID_INTERNAL)
+// We need to protect the definition of the 32bit versions under the above
+// conditions (see cuviddec.h). Defining them unconditionally would cause
+// compilation errors from conflicting definitions when cuviddec.h redefines
+// them as the 64bit versions.
+CUresult CUDAAPI cuvidMapVideoFrame(
+ CUvideodecoder decoder,
+ int pixIndex,
+ unsigned int* framePtr,
+ unsigned int* pitch,
+ CUVIDPROCPARAMS* procParams) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidMapVideoFrame,
+ "cuvidMapVideoFrame called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidMapVideoFrame(
+ decoder, pixIndex, framePtr, pitch, procParams);
+}
+
+CUresult CUDAAPI
+cuvidUnmapVideoFrame(CUvideodecoder decoder, unsigned int framePtr) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidUnmapVideoFrame,
+ "cuvidUnmapVideoFrame called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidUnmapVideoFrame(decoder, framePtr);
+}
+#endif
+
+CUresult CUDAAPI cuvidMapVideoFrame64(
+ CUvideodecoder decoder,
+ int pixIndex,
+ unsigned long long* framePtr,
+ unsigned int* pitch,
+ CUVIDPROCPARAMS* procParams) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidMapVideoFrame64,
+ "cuvidMapVideoFrame64 called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidMapVideoFrame64(
+ decoder, pixIndex, framePtr, pitch, procParams);
+}
+
+CUresult CUDAAPI
+cuvidUnmapVideoFrame64(CUvideodecoder decoder, unsigned long long framePtr) {
+ TORCH_CHECK(
+ facebook::torchcodec::dl_cuvidUnmapVideoFrame64,
+ "cuvidUnmapVideoFrame64 called but NVCUVID not loaded!");
+ return facebook::torchcodec::dl_cuvidUnmapVideoFrame64(decoder, framePtr);
+}
+
+} // extern "C"
+
+#endif // FBCODE_CAFFE2
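
For readers unfamiliar with the trampoline pattern the comment block above describes, here is a minimal self-contained sketch of the same idea, using libm's cos() as a stand-in for an NVCUVID symbol. It is POSIX-only for brevity; the Windows branch would use LoadLibrary/GetProcAddress as in _loadLibrary() above.

```cpp
#include <dlfcn.h>
#include <cstdio>

typedef double tcos(double); // the *type* of the function we want to bind
static tcos* dl_cos = nullptr; // "dl" for dynamically loaded; bound later

bool loadMathLibrary() {
  // Load the library at runtime instead of linking against it. If it is
  // missing, we return false and the caller can fall back gracefully.
  void* handle = dlopen("libm.so.6", RTLD_NOW);
  if (handle == nullptr) {
    return false;
  }
  dl_cos = reinterpret_cast<tcos*>(dlsym(handle, "cos"));
  return dl_cos != nullptr;
}

// Trampoline with the same signature as the real symbol: callers compile
// against this fixed entry point, and it dispatches to the bound pointer.
double my_cos(double x) {
  return dl_cos(x); // precondition: loadMathLibrary() returned true
}

int main() {
  if (loadMathLibrary()) {
    std::printf("cos(0) = %f\n", my_cos(0.0));
  }
  return 0;
}
```

As in NVCUVIDRuntimeLoader.cpp, nothing here links against the target library, so a machine without it can still load the binary; the cost is a runtime check before the first call.
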
diff --git a/src/torchcodec/_core/NVCUVIDRuntimeLoader.h b/src/torchcodec/_core/NVCUVIDRuntimeLoader.h
new file mode 100644
index 000000000..e6ee40a05
--- /dev/null
+++ b/src/torchcodec/_core/NVCUVIDRuntimeLoader.h
@@ -0,0 +1,14 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+namespace facebook::torchcodec {
+
+// See note in corresponding cpp file
+bool loadNVCUVIDLibrary();
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp
index 87ab5b0dc..302433cd4 100644
--- a/src/torchcodec/_core/NVDECCache.cpp
+++ b/src/torchcodec/_core/NVDECCache.cpp
@@ -7,6 +7,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
+#include "src/torchcodec/_core/CUDACommon.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/NVDECCache.h"
@@ -19,20 +20,9 @@ extern "C" {
namespace facebook::torchcodec {
-NVDECCache& NVDECCache::getCache(int deviceIndex) {
- const int MAX_CUDA_GPUS = 128;
- TORCH_CHECK(
- deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS,
- "Invalid device index = ",
- deviceIndex);
+NVDECCache& NVDECCache::getCache(const torch::Device& device) {
static NVDECCache cacheInstances[MAX_CUDA_GPUS];
- if (deviceIndex == -1) {
- // TODO NVDEC P3: Unify with existing getNonNegativeDeviceIndex()
- TORCH_CHECK(
- cudaGetDevice(&deviceIndex) == cudaSuccess,
- "Failed to get current CUDA device.");
- }
- return cacheInstances[deviceIndex];
+ return cacheInstances[getDeviceIndex(device)];
}
UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) {
diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h
index 17fc99902..a0f2fb862 100644
--- a/src/torchcodec/_core/NVDECCache.h
+++ b/src/torchcodec/_core/NVDECCache.h
@@ -11,6 +11,9 @@
#include <map>
#include <mutex>
+#include <torch/types.h>
+
+#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h"
#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
@@ -36,7 +39,7 @@ using UniqueCUvideodecoder =
// per GPU device, and it is accessed through the static getCache() method.
class NVDECCache {
public:
- static NVDECCache& getCache(int deviceIndex);
+ static NVDECCache& getCache(const torch::Device& device);
// Get decoder from cache - returns nullptr if none available
UniqueCUvideodecoder getDecoder(CUVIDEOFORMAT* videoFormat);
@@ -68,11 +71,6 @@ class NVDECCache {
CacheKey(const CacheKey&) = default;
CacheKey& operator=(const CacheKey&) = default;
- // TODONVDEC P2: we only implement operator< which is enough for std::map,
- // but:
- // - we should consider using std::unordered_map
- // - we should consider a more sophisticated and potentially less strict
- // cache key comparison logic
bool operator<(const CacheKey& other) const {
return std::tie(
codecType,
diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index d06c47922..8d9e9f651 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -12,6 +12,7 @@
#include
#include
#include
+#include "Metadata.h"
#include "torch/types.h"
namespace facebook::torchcodec {
@@ -429,7 +430,6 @@ void SingleStreamDecoder::addStream(
TORCH_CHECK(
deviceInterface_ != nullptr,
"Failed to create device interface. This should never happen, please report.");
- deviceInterface_->initialize(streamInfo.stream, formatContext_);
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
// addStream() which is supposed to be generic
@@ -441,7 +441,7 @@ void SingleStreamDecoder::addStream(
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
TORCH_CHECK(codecContext != nullptr);
- streamInfo.codecContext.reset(codecContext);
+ streamInfo.codecContext = makeSharedAVCodecContext(codecContext);
int retVal = avcodec_parameters_to_context(
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
@@ -453,14 +453,19 @@ void SingleStreamDecoder::addStream(
// Note that we must make sure to register the hardware device context
// with the codec context before calling avcodec_open2(). Otherwise, decoding
// will happen on the CPU and not the hardware device.
- deviceInterface_->registerHardwareDeviceWithCodec(codecContext);
+ deviceInterface_->registerHardwareDeviceWithCodec(
+ streamInfo.codecContext.get());
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
- codecContext->time_base = streamInfo.stream->time_base;
+ streamInfo.codecContext->time_base = streamInfo.stream->time_base;
+
+ // Initialize the device interface with the codec context
+ deviceInterface_->initialize(
+ streamInfo.stream, formatContext_, streamInfo.codecContext);
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
- std::string(avcodec_get_name(codecContext->codec_id));
+ std::string(avcodec_get_name(streamInfo.codecContext->codec_id));
// We will only need packets from the active stream, so we tell FFmpeg to
// discard packets from the other streams. Note that av_read_frame() may still
@@ -523,6 +528,7 @@ void SingleStreamDecoder::addVideoStream(
if (transform->getOutputFrameDims().has_value()) {
resizedOutputDims_ = transform->getOutputFrameDims().value();
}
+ transform->validate(streamMetadata);
// Note that we are claiming ownership of the transform objects passed in to
// us.
@@ -1149,8 +1155,6 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
getFFMPEGErrorStringFromErrorCode(status));
decodeStats_.numFlushes++;
- avcodec_flush_buffers(streamInfo.codecContext.get());
-
deviceInterface_->flush();
}
@@ -1169,24 +1173,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
cursorWasJustSet_ = false;
}
- StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
UniqueAVFrame avFrame(av_frame_alloc());
AutoAVPacket autoAVPacket;
int status = AVSUCCESS;
bool reachedEOF = false;
- // TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on
- // if/else blocks to dispatch to the interface or to FFmpeg, consider *always*
- // dispatching to the interface. The default implementation of the interface's
- // receiveFrame and sendPacket could just be calling avcodec_receive_frame and
- // avcodec_send_packet. This would make the decoding loop even more generic.
+ // The default implementation uses avcodec_receive_frame and
+  // avcodec_send_packet, while specialized interfaces can override them for
+ // hardware-specific optimizations.
while (true) {
- if (deviceInterface_->canDecodePacketDirectly()) {
- status = deviceInterface_->receiveFrame(avFrame);
- } else {
- status =
- avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
- }
+ status = deviceInterface_->receiveFrame(avFrame);
if (status != AVSUCCESS && status != AVERROR(EAGAIN)) {
// Non-retriable error
@@ -1222,13 +1218,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
if (status == AVERROR_EOF) {
// End of file reached. We must drain the decoder
- if (deviceInterface_->canDecodePacketDirectly()) {
- status = deviceInterface_->sendEOFPacket();
- } else {
- status = avcodec_send_packet(
- streamInfo.codecContext.get(),
- /*avpkt=*/nullptr);
- }
+ status = deviceInterface_->sendEOFPacket();
TORCH_CHECK(
status >= AVSUCCESS,
"Could not flush decoder: ",
@@ -1253,11 +1243,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
// We got a valid packet. Send it to the decoder, and we'll receive it in
// the next iteration.
- if (deviceInterface_->canDecodePacketDirectly()) {
- status = deviceInterface_->sendPacket(packet);
- } else {
- status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
- }
+ status = deviceInterface_->sendPacket(packet);
TORCH_CHECK(
status >= AVSUCCESS,
"Could not push packet to decoder: ",
@@ -1716,4 +1702,9 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) {
streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase);
}
+std::string SingleStreamDecoder::getDeviceInterfaceDetails() const {
+ TORCH_CHECK(deviceInterface_ != nullptr, "Device interface doesn't exist.");
+ return deviceInterface_->getDetails();
+}
+
} // namespace facebook::torchcodec
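
With canDecodePacketDirectly() gone, every device interface now implements receiveFrame/sendPacket/sendEOFPacket. Judging by the FFmpeg calls in the removed if/else branches, the default (software) implementations presumably reduce to thin wrappers like this hypothetical sketch; the free-function names are illustrative, not torchcodec's actual API.

```cpp
extern "C" {
#include <libavcodec/avcodec.h>
}

// Default behavior: plain FFmpeg software decoding, exactly what the
// removed else-branches used to call inline.
int defaultReceiveFrame(AVCodecContext* ctx, AVFrame* frame) {
  return avcodec_receive_frame(ctx, frame);
}

int defaultSendPacket(AVCodecContext* ctx, const AVPacket* packet) {
  return avcodec_send_packet(ctx, packet);
}

int defaultSendEOFPacket(AVCodecContext* ctx) {
  // A null packet puts the decoder in drain mode, matching the removed
  // EOF branch above.
  return avcodec_send_packet(ctx, /*avpkt=*/nullptr);
}
```
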
diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h
index 48821ff09..4d4c11aa2 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.h
+++ b/src/torchcodec/_core/SingleStreamDecoder.h
@@ -186,6 +186,8 @@ class SingleStreamDecoder {
DecodeStats getDecodeStats() const;
void resetDecodeStats();
+ std::string getDeviceInterfaceDetails() const;
+
private:
// --------------------------------------------------------------------------
// STREAMINFO AND ASSOCIATED STRUCTS
@@ -221,7 +223,7 @@ class SingleStreamDecoder {
AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN;
AVRational timeBase = {};
- UniqueAVCodecContext codecContext;
+ SharedAVCodecContext codecContext;
// The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
// called.
@@ -311,7 +313,7 @@ class SingleStreamDecoder {
int streamIndex,
AVMediaType mediaType,
const torch::Device& device = torch::kCPU,
- const std::string_view deviceVariant = "default",
+ const std::string_view deviceVariant = "ffmpeg",
std::optional<int> ffmpegThreadCount = std::nullopt);
// Returns the "best" stream index for a given media type. The "best" is
diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h
index 7728a676e..e5ab256e1 100644
--- a/src/torchcodec/_core/StreamOptions.h
+++ b/src/torchcodec/_core/StreamOptions.h
@@ -41,8 +41,8 @@ struct VideoStreamOptions {
// By default we use CPU for decoding for both C++ and python users.
torch::Device device = torch::kCPU;
- // Device variant (e.g., "default", "beta", etc.)
- std::string_view deviceVariant = "default";
+ // Device variant (e.g., "ffmpeg", "beta", etc.)
+ std::string_view deviceVariant = "ffmpeg";
// Encoding options
// TODO-VideoEncoder: Consider adding other optional fields here
diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp
index d0a5104f3..6083986e1 100644
--- a/src/torchcodec/_core/Transform.cpp
+++ b/src/torchcodec/_core/Transform.cpp
@@ -57,4 +57,31 @@ int ResizeTransform::getSwsFlags() const {
return toSwsInterpolation(interpolationMode_);
}
+CropTransform::CropTransform(const FrameDims& dims, int x, int y)
+ : outputDims_(dims), x_(x), y_(y) {
+ TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
+ TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_);
+}
+
+std::string CropTransform::getFilterGraphCpu() const {
+ return "crop=" + std::to_string(outputDims_.width) + ":" +
+ std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" +
+ std::to_string(y_) + ":exact=1";
+}
+
+std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
+ return outputDims_;
+}
+
+void CropTransform::validate(const StreamMetadata& streamMetadata) const {
+ TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
+  TORCH_CHECK(
+      x_ + outputDims_.width <= streamMetadata.width,
+      "Crop x position out of bounds");
+ TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
+ TORCH_CHECK(
+ y_ + outputDims_.height <= streamMetadata.height,
+ "Crop y position out of bounds");
+}
+
} // namespace facebook::torchcodec
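
A short usage sketch of the new CropTransform, assuming torchcodec's headers are available. It shows the PyTorch (height, width) constructor convention flipping into FFmpeg's (width, height) filter-argument order; this is also the filter string behind the crop_300_200_50_35_exact_1 reference resources added below.

```cpp
#include <string>
#include "src/torchcodec/_core/Transform.h"

using namespace facebook::torchcodec;

std::string exampleCropFilter() {
  // A 300-wide, 200-tall crop whose top-left corner sits at (x=50, y=35).
  CropTransform crop(
      FrameDims(/*height=*/200, /*width=*/300), /*x=*/50, /*y=*/35);
  // validate() would throw via TORCH_CHECK if this crop fell outside the
  // bounds of the stream it is applied to.
  return crop.getFilterGraphCpu(); // "crop=300:200:50:35:exact=1"
}
```
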
diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h
index 6aea255ab..28d8c28a2 100644
--- a/src/torchcodec/_core/Transform.h
+++ b/src/torchcodec/_core/Transform.h
@@ -9,6 +9,7 @@
#include <optional>
#include <string>
#include "src/torchcodec/_core/Frame.h"
+#include "src/torchcodec/_core/Metadata.h"
namespace facebook::torchcodec {
@@ -33,6 +34,16 @@ class Transform {
virtual bool isResize() const {
return false;
}
+
+ // The validity of some transforms depends on the characteristics of the
+ // AVStream they're being applied to. For example, some transforms will
+ // specify coordinates inside a frame, we need to validate that those are
+ // within the frame's bounds.
+ //
+ // Note that the validation function does not return anything. We expect
+ // invalid configurations to throw an exception.
+ virtual void validate(
+ [[maybe_unused]] const StreamMetadata& streamMetadata) const {}
};
class ResizeTransform : public Transform {
@@ -56,4 +67,18 @@ class ResizeTransform : public Transform {
InterpolationMode interpolationMode_;
};
+class CropTransform : public Transform {
+ public:
+ CropTransform(const FrameDims& dims, int x, int y);
+
+ std::string getFilterGraphCpu() const override;
+  std::optional<FrameDims> getOutputFrameDims() const override;
+ void validate(const StreamMetadata& streamMetadata) const override;
+
+ private:
+ FrameDims outputDims_;
+ int x_;
+ int y_;
+};
+
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py
index 24e54af0e..55ff697b3 100644
--- a/src/torchcodec/_core/__init__.py
+++ b/src/torchcodec/_core/__init__.py
@@ -14,6 +14,7 @@
)
from .ops import (
_add_video_stream,
+ _get_backend_details,
_get_key_frame_indices,
_test_frame_pts_equality,
add_audio_stream,
@@ -26,6 +27,8 @@
encode_audio_to_file_like,
encode_audio_to_tensor,
encode_video_to_file,
+ encode_video_to_file_like,
+ encode_video_to_tensor,
get_ffmpeg_library_versions,
get_frame_at_index,
get_frame_at_pts,
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index 5ba98e2c1..c6204de8c 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -32,20 +32,24 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
m.def(
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
- m.def(
- "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
+ m.def(
+ "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
+ m.def(
+ "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
+ m.def(
+ "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def(
"_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
m.def(
- "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
+ "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
m.def(
- "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
+ "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()");
m.def(
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()");
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
@@ -70,6 +74,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
m.def("_get_json_ffmpeg_library_versions() -> str");
+ m.def("_get_backend_details(Tensor(a!) decoder) -> str");
m.def(
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
@@ -212,6 +217,26 @@ Transform* makeResizeTransform(
return new ResizeTransform(FrameDims(height, width));
}
+// Crop transform specs take the form:
+//
+//   "crop, <height>, <width>, <x>, <y>"
+//
+// Where "crop" is the string literal and <height>, <width>, <x> and <y> are
+// positive integers. <x> and <y> are the x and y coordinates of the top left
+// corner of the crop. Note that we follow the PyTorch convention of (height,
+// width) for specifying image dimensions; FFmpeg uses (width, height).
+Transform* makeCropTransform(
+    const std::vector<std::string>& cropTransformSpec) {
+ TORCH_CHECK(
+ cropTransformSpec.size() == 5,
+ "cropTransformSpec must have 5 elements including its name");
+ int height = checkedToPositiveInt(cropTransformSpec[1]);
+ int width = checkedToPositiveInt(cropTransformSpec[2]);
+ int x = checkedToPositiveInt(cropTransformSpec[3]);
+ int y = checkedToPositiveInt(cropTransformSpec[4]);
+ return new CropTransform(FrameDims(height, width), x, y);
+}
+
std::vector<std::string> split(const std::string& str, char delimiter) {
  std::vector<std::string> tokens;
std::string token;
@@ -239,6 +264,8 @@ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) {
auto name = transformSpec[0];
if (name == "resize") {
transforms.push_back(makeResizeTransform(transformSpec));
+ } else if (name == "crop") {
+ transforms.push_back(makeCropTransform(transformSpec));
} else {
TORCH_CHECK(false, "Invalid transform name: " + name);
}
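
A standalone sketch of how a crop spec flows through split() and makeCropTransform(). The toy split() mirrors the helper above (whose body is truncated by the hunk), and std::stoi stands in for checkedToPositiveInt, so validation is elided.

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for the split() helper above.
std::vector<std::string> split(const std::string& str, char delimiter) {
  std::vector<std::string> tokens;
  std::string token;
  std::istringstream stream(str);
  while (std::getline(stream, token, delimiter)) {
    tokens.push_back(token);
  }
  return tokens;
}

int main() {
  // "crop,<height>,<width>,<x>,<y>" -> CropTransform(FrameDims(h, w), x, y)
  auto spec = split("crop,200,300,50,35", ',');
  int height = std::stoi(spec[1]); // 200
  int width = std::stoi(spec[2]); // 300
  int x = std::stoi(spec[3]); // 50
  int y = std::stoi(spec[4]); // 35
  // Same string CropTransform::getFilterGraphCpu() would emit:
  std::cout << "crop=" << width << ":" << height << ":" << x << ":" << y
            << ":exact=1\n";
  return 0;
}
```
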
@@ -319,7 +346,7 @@ void _add_video_stream(
std::optional<std::string_view> dimension_order = std::nullopt,
std::optional<int64_t> stream_index = std::nullopt,
std::string_view device = "cpu",
- std::string_view device_variant = "default",
+ std::string_view device_variant = "ffmpeg",
std::string_view transform_specs = "",
std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
custom_frame_mappings = std::nullopt,
@@ -376,7 +403,7 @@ void add_video_stream(
std::optional<std::string_view> dimension_order = std::nullopt,
std::optional<int64_t> stream_index = std::nullopt,
std::string_view device = "cpu",
- std::string_view device_variant = "default",
+ std::string_view device_variant = "ffmpeg",
std::string_view transform_specs = "",
const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>&
custom_frame_mappings = std::nullopt) {
@@ -498,21 +525,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
return makeOpsAudioFramesOutput(result);
}
-void encode_video_to_file(
- const at::Tensor& frames,
- int64_t frame_rate,
- std::string_view file_name,
-    std::optional<int64_t> crf = std::nullopt) {
- VideoStreamOptions videoStreamOptions;
- videoStreamOptions.crf = crf;
- VideoEncoder(
- frames,
- validateInt64ToInt(frame_rate, "frame_rate"),
- file_name,
- videoStreamOptions)
- .encode();
-}
-
void encode_audio_to_file(
const at::Tensor& samples,
int64_t sample_rate,
@@ -587,6 +599,62 @@ void _encode_audio_to_file_like(
encoder.encode();
}
+void encode_video_to_file(
+ const at::Tensor& frames,
+ int64_t frame_rate,
+ std::string_view file_name,
+    std::optional<int64_t> crf = std::nullopt) {
+ VideoStreamOptions videoStreamOptions;
+ videoStreamOptions.crf = crf;
+ VideoEncoder(
+ frames,
+ validateInt64ToInt(frame_rate, "frame_rate"),
+ file_name,
+ videoStreamOptions)
+ .encode();
+}
+
+at::Tensor encode_video_to_tensor(
+ const at::Tensor& frames,
+ int64_t frame_rate,
+ std::string_view format,
+    std::optional<int64_t> crf = std::nullopt) {
+  auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
+ VideoStreamOptions videoStreamOptions;
+ videoStreamOptions.crf = crf;
+ return VideoEncoder(
+ frames,
+ validateInt64ToInt(frame_rate, "frame_rate"),
+ format,
+ std::move(avioContextHolder),
+ videoStreamOptions)
+ .encodeToTensor();
+}
+
+void _encode_video_to_file_like(
+ const at::Tensor& frames,
+ int64_t frame_rate,
+ std::string_view format,
+ int64_t file_like_context,
+    std::optional<int64_t> crf = std::nullopt) {
+  auto fileLikeContext =
+      reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
+  TORCH_CHECK(
+      fileLikeContext != nullptr, "file_like_context must be a valid pointer");
+  std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
+
+ VideoStreamOptions videoStreamOptions;
+ videoStreamOptions.crf = crf;
+
+ VideoEncoder encoder(
+ frames,
+ validateInt64ToInt(frame_rate, "frame_rate"),
+ format,
+ std::move(avioContextHolder),
+ videoStreamOptions);
+ encoder.encode();
+}
+
// For testing only. We need to implement this operation as a core library
// function because what we're testing is round-tripping pts values as
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -828,6 +896,11 @@ std::string _get_json_ffmpeg_library_versions() {
return ss.str();
}
+std::string get_backend_details(at::Tensor& decoder) {
+ auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+ return videoDecoder->getDeviceInterfaceDetails();
+}
+
// Scans video packets to get more accurate metadata like frame count, exact
// keyframe positions, etc. Exact keyframe positions are useful for efficient
// accurate seeking. Note that this function reads the entire video but it does
@@ -847,9 +920,11 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("encode_audio_to_file", &encode_audio_to_file);
- m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
+ m.impl("encode_video_to_file", &encode_video_to_file);
+ m.impl("encode_video_to_tensor", &encode_video_to_tensor);
+ m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
m.impl("seek_to_pts", &seek_to_pts);
m.impl("add_video_stream", &add_video_stream);
m.impl("_add_video_stream", &_add_video_stream);
@@ -870,6 +945,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl(
"scan_all_streams_to_update_metadata",
&scan_all_streams_to_update_metadata);
+
+ m.impl("_get_backend_details", &get_backend_details);
}
} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py
index 44dc89e2b..32995c964 100644
--- a/src/torchcodec/_core/ops.py
+++ b/src/torchcodec/_core/ops.py
@@ -69,7 +69,7 @@ def load_torchcodec_shared_libraries():
raise RuntimeError(
f"""Could not load libtorchcodec. Likely causes:
1. FFmpeg is not properly installed in your environment. We support
- versions 4, 5, 6 and 7.
+ versions 4, 5, 6, and 7 on all platforms, and 8 on Mac and Linux.
2. The PyTorch version ({torch.__version__}) is not compatible with
this version of TorchCodec. Refer to the version compatibility
table:
@@ -92,15 +92,21 @@ def load_torchcodec_shared_libraries():
encode_audio_to_file = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_audio_to_file.default
)
-encode_video_to_file = torch._dynamo.disallow_in_graph(
- torch.ops.torchcodec_ns.encode_video_to_file.default
-)
encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_audio_to_tensor.default
)
_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns._encode_audio_to_file_like.default
)
+encode_video_to_file = torch._dynamo.disallow_in_graph(
+ torch.ops.torchcodec_ns.encode_video_to_file.default
+)
+encode_video_to_tensor = torch._dynamo.disallow_in_graph(
+ torch.ops.torchcodec_ns.encode_video_to_tensor.default
+)
+_encode_video_to_file_like = torch._dynamo.disallow_in_graph(
+ torch.ops.torchcodec_ns._encode_video_to_file_like.default
+)
create_from_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.create_from_tensor.default
)
@@ -136,6 +142,7 @@ def load_torchcodec_shared_libraries():
_get_json_ffmpeg_library_versions = (
torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default
)
+_get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default
# =============================
@@ -200,6 +207,33 @@ def encode_audio_to_file_like(
)
+def encode_video_to_file_like(
+ frames: torch.Tensor,
+ frame_rate: int,
+ format: str,
+ file_like: Union[io.RawIOBase, io.BufferedIOBase],
+ crf: Optional[int] = None,
+) -> None:
+ """Encode video frames to a file-like object.
+
+ Args:
+ frames: Video frames tensor
+ frame_rate: Frame rate in frames per second
+ format: Video format (e.g., "mp4", "mov", "mkv")
+ file_like: File-like object that supports write() and seek() methods
+ crf: Optional constant rate factor for encoding quality
+ """
+ assert _pybind_ops is not None
+
+ _encode_video_to_file_like(
+ frames,
+ frame_rate,
+ format,
+ _pybind_ops.create_file_like_context(file_like, True), # True means for writing
+ crf,
+ )
+
+
def get_frames_at_indices(
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -254,16 +288,6 @@ def encode_audio_to_file_abstract(
return
-@register_fake("torchcodec_ns::encode_video_to_file")
-def encode_video_to_file_abstract(
- frames: torch.Tensor,
- frame_rate: int,
- filename: str,
- crf: Optional[int] = None,
-) -> None:
- return
-
-
@register_fake("torchcodec_ns::encode_audio_to_tensor")
def encode_audio_to_tensor_abstract(
samples: torch.Tensor,
@@ -289,6 +313,37 @@ def _encode_audio_to_file_like_abstract(
return
+@register_fake("torchcodec_ns::encode_video_to_file")
+def encode_video_to_file_abstract(
+ frames: torch.Tensor,
+ frame_rate: int,
+ filename: str,
+ crf: Optional[int],
+) -> None:
+ return
+
+
+@register_fake("torchcodec_ns::encode_video_to_tensor")
+def encode_video_to_tensor_abstract(
+ frames: torch.Tensor,
+ frame_rate: int,
+ format: str,
+ crf: Optional[int],
+) -> torch.Tensor:
+ return torch.empty([], dtype=torch.long)
+
+
+@register_fake("torchcodec_ns::_encode_video_to_file_like")
+def _encode_video_to_file_like_abstract(
+ frames: torch.Tensor,
+ frame_rate: int,
+ format: str,
+ file_like_context: int,
+ crf: Optional[int] = None,
+) -> None:
+ return
+
+
@register_fake("torchcodec_ns::create_from_tensor")
def create_from_tensor_abstract(
video_tensor: torch.Tensor, seek_mode: Optional[str]
@@ -304,7 +359,7 @@ def _add_video_stream_abstract(
dimension_order: Optional[str] = None,
stream_index: Optional[int] = None,
device: str = "cpu",
- device_variant: str = "default",
+ device_variant: str = "ffmpeg",
transform_specs: str = "",
custom_frame_mappings: Optional[
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
@@ -322,7 +377,7 @@ def add_video_stream_abstract(
dimension_order: Optional[str] = None,
stream_index: Optional[int] = None,
device: str = "cpu",
- device_variant: str = "default",
+ device_variant: str = "ffmpeg",
transform_specs: str = "",
custom_frame_mappings: Optional[
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
@@ -496,3 +551,8 @@ def scan_all_streams_to_update_metadata_abstract(decoder: torch.Tensor) -> None:
def get_ffmpeg_library_versions():
versions_json = _get_json_ffmpeg_library_versions()
return json.loads(versions_json)
+
+
+@register_fake("torchcodec_ns::_get_backend_details")
+def _get_backend_details_abstract(decoder: torch.Tensor) -> str:
+ return ""
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 331c7ba79..130927c2e 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -147,16 +147,6 @@ def __init__(
device = str(device)
device_variant = _get_cuda_backend()
- if device_variant == "ffmpeg":
- # TODONVDEC P2 rename 'default' into 'ffmpeg' everywhere.
- device_variant = "default"
-
- # Legacy support for device="cuda:0:beta" syntax
- # TODONVDEC P2: remove support for this everywhere. This will require
- # updating our tests.
- if device == "cuda:0:beta":
- device = "cuda:0"
- device_variant = "beta"
core.add_video_stream(
self._decoder,
diff --git a/test/generate_reference_resources.py b/test/generate_reference_resources.py
index 5ae062111..953fb996e 100644
--- a/test/generate_reference_resources.py
+++ b/test/generate_reference_resources.py
@@ -6,17 +6,20 @@
import subprocess
from pathlib import Path
+from typing import Optional
import numpy as np
import torch
from PIL import Image
+from .utils import AV1_VIDEO, H265_VIDEO, NASA_VIDEO, TestVideo
+
# Run this script to update the resources used in unit tests. The resources are all derived
# from source media already checked into the repo.
-def convert_image_to_tensor(image_path):
+def convert_image_to_tensor(image_path: str) -> None:
image_path = Path(image_path)
if not image_path.exists():
return
@@ -31,26 +34,56 @@ def convert_image_to_tensor(image_path):
image_path.unlink()
-def get_frame_by_index(video_path, frame, output_path, stream):
+def generate_frame_by_index(
+ video: TestVideo,
+ *,
+ frame_index: int,
+ stream_index: int,
+ filters: Optional[str] = None,
+) -> None:
+ # Note that we are using 0-based index naming. As a result, we are
+ # generating files one-by-one, giving the actual file name that we want.
+ # ffmpeg does have an option to generate multiple files for us, but it uses
+ # 1-based indexing. We can't use 1-based indexing because we want to match
+ # the 0-based indexing in our tests.
+ base_path = video.get_base_path_by_index(
+ frame_index, stream_index=stream_index, filters=filters
+ )
+ output_bmp = f"{base_path}.bmp"
+
+ # Note that we have an exlicit format conversion to rgb24 in our filtergraph specification,
+ # which always happens BEFORE any of the filters that we receive as input. We do this to
+ # ensure that the color conversion happens BEFORE the filters, matching the behavior of the
+ # torchcodec filtergraph implementation.
+ #
+ # Not doing this would result in the color conversion happening AFTER the filters, which
+ # would result in different color values for the same frame.
+ filtergraph = f"select='eq(n\\,{frame_index})',format=rgb24"
+ if filters is not None:
+ filtergraph = filtergraph + f",{filters}"
+
cmd = [
"ffmpeg",
"-y",
"-i",
- video_path,
+ video.path,
"-map",
- f"0:{stream}",
+ f"0:{stream_index}",
"-vf",
- f"select=eq(n\\,{frame})",
- "-vsync",
- "vfr",
- "-q:v",
- "2",
- output_path,
+ filtergraph,
+ "-fps_mode",
+ "passthrough",
+ "-update",
+ "1",
+ output_bmp,
]
subprocess.run(cmd, check=True)
+ convert_image_to_tensor(output_bmp)
-def get_frame_by_timestamp(video_path, timestamp, output_path):
+def generate_frame_by_timestamp(
+ video_path: str, timestamp: float, output_path: str
+) -> None:
cmd = [
"ffmpeg",
"-y",
@@ -63,60 +96,58 @@ def get_frame_by_timestamp(video_path, timestamp, output_path):
output_path,
]
subprocess.run(cmd, check=True)
+ convert_image_to_tensor(output_path)
-def main():
- SCRIPT_DIR = Path(__file__).resolve().parent
- TORCHCODEC_PATH = SCRIPT_DIR.parent
- RESOURCES_DIR = TORCHCODEC_PATH / "test" / "resources"
- VIDEO_PATH = RESOURCES_DIR / "nasa_13013.mp4"
-
- # Last generated with ffmpeg version 4.3
- #
+def generate_nasa_13013_references():
# Note: The naming scheme used here must match the naming scheme used to load
# tensors in ./utils.py.
- STREAMS = [0, 3]
- FRAMES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 20, 25, 30, 35, 386, 387, 388, 389]
- for stream in STREAMS:
- for frame in FRAMES:
- # Note that we are using 0-based index naming. Asking ffmpeg to number output
- # frames would result in 1-based index naming. We enforce 0-based index naming
- # so that the name of reference frames matches the index when accessing that
- # frame in the Python decoder.
- output_bmp = f"{VIDEO_PATH}.stream{stream}.frame{frame:06d}.bmp"
- get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=stream)
- convert_image_to_tensor(output_bmp)
+ streams = [0, 3]
+ frames = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 20, 25, 30, 35, 386, 387, 388, 389]
+ for stream in streams:
+ for frame in frames:
+ generate_frame_by_index(NASA_VIDEO, frame_index=frame, stream_index=stream)
# Extract individual frames at specific timestamps, including the last frame of the video.
seek_timestamp = [6.0, 6.1, 10.0, 12.979633]
timestamp_name = [f"{seek_timestamp:06f}" for seek_timestamp in seek_timestamp]
for timestamp, name in zip(seek_timestamp, timestamp_name):
- output_bmp = f"{VIDEO_PATH}.time{name}.bmp"
- get_frame_by_timestamp(VIDEO_PATH, timestamp, output_bmp)
- convert_image_to_tensor(output_bmp)
+ output_bmp = f"{NASA_VIDEO.path}.time{name}.bmp"
+ generate_frame_by_timestamp(NASA_VIDEO.path, timestamp, output_bmp)
+
+ # Extract frames with specific filters. We have tests that assume these exact filters.
+ frames = [0, 15, 200, 389]
+ crop_filter = "crop=300:200:50:35:exact=1"
+ for frame in frames:
+ generate_frame_by_index(
+ NASA_VIDEO, frame_index=frame, stream_index=3, filters=crop_filter
+ )
+
+def generate_h265_video_references():
# This video was generated by running the following:
# conda install -c conda-forge x265
# ./configure --enable-nonfree --enable-gpl --prefix=$(readlink -f ../bin) --enable-libx265 --enable-rpath --extra-ldflags=-Wl,-rpath=$CONDA_PREFIX/lib --enable-filter=drawtext --enable-libfontconfig --enable-libfreetype --enable-libharfbuzz
# ffmpeg -f lavfi -i color=size=128x128:duration=1:rate=10:color=blue -vf "drawtext=fontsize=30:fontcolor=white:x=(w-text_w)/2:y=(h-text_h)/2:text='Frame %{frame_num}'" -vcodec libx265 -pix_fmt yuv420p -g 2 -crf 10 h265_video.mp4 -y
# Note that this video only has 1 stream, at index 0.
- VIDEO_PATH = RESOURCES_DIR / "h265_video.mp4"
- FRAMES = [5]
- for frame in FRAMES:
- output_bmp = f"{VIDEO_PATH}.stream0.frame{frame:06d}.bmp"
- get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=0)
- convert_image_to_tensor(output_bmp)
+ frames = [5]
+ for frame in frames:
+ generate_frame_by_index(H265_VIDEO, frame_index=frame, stream_index=0)
+
+def generate_av1_video_references():
# This video was generated by running the following:
# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv
# Note that this video only has 1 stream, at index 0.
- VIDEO_PATH = RESOURCES_DIR / "av1_video.mkv"
- FRAMES = [10]
+ frames = [10]
+ for frame in frames:
+ generate_frame_by_index(AV1_VIDEO, frame_index=frame, stream_index=0)
- for frame in FRAMES:
- output_bmp = f"{VIDEO_PATH}.stream0.frame{frame:06d}.bmp"
- get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=0)
- convert_image_to_tensor(output_bmp)
+
+def main():
+ generate_nasa_13013_references()
+ generate_h265_video_references()
+ generate_av1_video_references()
if __name__ == "__main__":
diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000000.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000000.pt
new file mode 100644
index 000000000..c69af7cee
Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000000.pt differ
diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000015.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000015.pt
new file mode 100644
index 000000000..c4f0bdc95
Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000015.pt differ
diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000200.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000200.pt
new file mode 100644
index 000000000..9849cfa99
Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000200.pt differ
diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000389.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000389.pt
new file mode 100644
index 000000000..3451f2a98
Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000389.pt differ
diff --git a/test/test_decoders.py b/test/test_decoders.py
index 300c953bf..6e08e05a4 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -35,6 +35,7 @@
H265_10BITS,
H265_VIDEO,
in_fbcode,
+ make_video_decoder,
NASA_AUDIO,
NASA_AUDIO_MP3,
NASA_AUDIO_MP3_44100,
@@ -51,7 +52,6 @@
TEST_SRC_2_720P_MPEG4,
TEST_SRC_2_720P_VP8,
TEST_SRC_2_720P_VP9,
- unsplit_device_str,
)
@@ -179,13 +179,12 @@ def test_create_fails(self):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode):
- decoder = VideoDecoder(
+ decoder, device = make_video_decoder(
NASA_VIDEO.path,
num_ffmpeg_threads=num_ffmpeg_threads,
device=device,
seek_mode=seek_mode,
)
- device, _ = unsplit_device_str(device)
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device)
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device)
@@ -230,8 +229,9 @@ def test_getitem_numpy_int(self):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_getitem_slice(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
# ensure that the degenerate case of a range of size 1 works
@@ -391,7 +391,9 @@ def test_device_instance(self):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_getitem_fails(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
with pytest.raises(IndexError, match="Invalid frame index"):
frame = decoder[1000] # noqa
@@ -408,8 +410,9 @@ def test_getitem_fails(self, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_iteration(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device)
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device)
@@ -456,8 +459,9 @@ def test_iteration_slow(self):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frame_at(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device)
frame9 = decoder.get_frame_at(9)
@@ -494,7 +498,7 @@ def test_get_frame_at(self, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
def test_get_frame_at_tuple_unpacking(self, device):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device)
+ decoder, _ = make_video_decoder(NASA_VIDEO.path, device=device)
frame = decoder.get_frame_at(50)
data, pts, duration = decoder.get_frame_at(50)
@@ -506,7 +510,9 @@ def test_get_frame_at_tuple_unpacking(self, device):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frame_at_fails(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
with pytest.raises(
IndexError,
@@ -520,8 +526,9 @@ def test_get_frame_at_fails(self, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_at(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
# test positive and negative frame index
frames = decoder.get_frames_at([35, 25, -1, -2])
@@ -572,7 +579,9 @@ def test_get_frames_at(self, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_at_fails(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
with pytest.raises(
IndexError,
@@ -596,8 +605,7 @@ def test_get_frame_at_av1(self, device):
if "cuda" in device and in_fbcode():
pytest.skip("decoding on CUDA is not supported internally")
- decoder = VideoDecoder(AV1_VIDEO.path, device=device)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(AV1_VIDEO.path, device=device)
ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10)
ref_frame_info10 = AV1_VIDEO.get_frame_info(10)
decoded_frame10 = decoder.get_frame_at(10)
@@ -608,8 +616,9 @@ def test_get_frame_at_av1(self, device):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frame_played_at(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device)
assert_frames_equal(
@@ -638,7 +647,9 @@ def test_get_frame_played_at_h265(self):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frame_played_at_fails(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
with pytest.raises(IndexError, match="Invalid pts in seconds"):
frame = decoder.get_frame_played_at(-1.0) # noqa
@@ -650,8 +661,9 @@ def test_get_frame_played_at_fails(self, device, seek_mode):
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
@pytest.mark.parametrize("input_type", ("list", "tensor"))
def test_get_frames_played_at(self, device, seek_mode, input_type):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
# Note: We know the frame at ~0.84s has index 25, the one at 1.16s has
# index 35. We use those indices as reference to test against.
@@ -693,7 +705,9 @@ def test_get_frames_played_at(self, device, seek_mode, input_type):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_played_at_fails(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
with pytest.raises(RuntimeError, match="must be greater than or equal to"):
decoder.get_frames_played_at([-1])
@@ -710,13 +724,12 @@ def test_get_frames_played_at_fails(self, device, seek_mode):
@pytest.mark.parametrize("stream_index", [0, 3, None])
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_in_range(self, stream_index, device, seek_mode):
- decoder = VideoDecoder(
+ decoder, device = make_video_decoder(
NASA_VIDEO.path,
stream_index=stream_index,
device=device,
seek_mode=seek_mode,
)
- device, _ = unsplit_device_str(device)
# test degenerate case where we only actually get 1 frame
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
@@ -815,13 +828,12 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode):
- decoder = VideoDecoder(
+ decoder, device = make_video_decoder(
NASA_VIDEO.path,
stream_index=3,
device=device,
seek_mode=seek_mode,
)
- device, _ = unsplit_device_str(device)
# high range ends get capped to num_frames
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000)
@@ -891,13 +903,12 @@ def test_get_frames_with_missing_num_frames_metadata(
# Set the return value of the mock to be the mock_stream_dict
mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict)
- decoder = VideoDecoder(
+ decoder, device = make_video_decoder(
NASA_VIDEO.path,
stream_index=3,
device=device,
seek_mode=seek_mode,
)
- device, _ = unsplit_device_str(device)
assert decoder.metadata.num_frames_from_header is None
assert decoder.metadata.num_frames_from_content is None
@@ -932,7 +943,7 @@ def test_get_frames_with_missing_num_frames_metadata(
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode):
- decoder = VideoDecoder(
+ decoder, _ = make_video_decoder(
NASA_VIDEO.path,
dimension_order=dimension_order,
device=device,
@@ -960,13 +971,12 @@ def test_dimension_order_fails(self):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
- decoder = VideoDecoder(
+ decoder, device = make_video_decoder(
NASA_VIDEO.path,
stream_index=stream_index,
device=device,
seek_mode=seek_mode,
)
- device, _ = unsplit_device_str(device)
# Note that we are comparing the results of VideoDecoder's method:
# get_frames_played_in_range()
@@ -1100,7 +1110,9 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode=seek_mode
+ )
with pytest.raises(ValueError, match="Invalid start seconds"):
frame = decoder.get_frames_played_in_range(100.0, 1.0) # noqa
@@ -1113,7 +1125,9 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
@pytest.mark.parametrize("device", all_supported_devices())
def test_get_key_frame_indices(self, device):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact")
+ decoder, _ = make_video_decoder(
+ NASA_VIDEO.path, device=device, seek_mode="exact"
+ )
key_frame_indices = decoder._get_key_frame_indices()
# The key frame indices were generated from the following command:
@@ -1134,7 +1148,9 @@ def test_get_key_frame_indices(self, device):
key_frame_indices, nasa_reference_key_frame_indices, atol=0, rtol=0
)
- decoder = VideoDecoder(AV1_VIDEO.path, device=device, seek_mode="exact")
+ decoder, _ = make_video_decoder(
+ AV1_VIDEO.path, device=device, seek_mode="exact"
+ )
key_frame_indices = decoder._get_key_frame_indices()
# $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
@@ -1144,7 +1160,9 @@ def test_get_key_frame_indices(self, device):
key_frame_indices, av1_reference_key_frame_indices, atol=0, rtol=0
)
- decoder = VideoDecoder(H265_VIDEO.path, device=device, seek_mode="exact")
+ decoder, _ = make_video_decoder(
+ H265_VIDEO.path, device=device, seek_mode="exact"
+ )
key_frame_indices = decoder._get_key_frame_indices()
# ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
@@ -1158,8 +1176,7 @@ def test_get_key_frame_indices(self, device):
@pytest.mark.skipif(in_fbcode(), reason="Compile test fails internally.")
@pytest.mark.parametrize("device", all_supported_devices())
def test_compile(self, device):
- decoder = VideoDecoder(NASA_VIDEO.path, device=device)
- device, _ = unsplit_device_str(device)
+ decoder, device = make_video_decoder(NASA_VIDEO.path, device=device)
@contextlib.contextmanager
def restore_capture_scalar_outputs():
@@ -1297,17 +1314,17 @@ def test_10bit_videos(self, device, asset):
# This just validates that we can decode 10-bit videos.
# TODO validate against the ref that the decoded frames are correct
- if device == "cuda:0:beta" and asset is H264_10BITS:
+ if device == "cuda:beta" and asset is H264_10BITS:
# This fails on the BETA interface with:
#
# RuntimeError: Codec configuration not supported on this GPU.
# Codec: 4, chroma format: 1, bit depth: 10
#
- # It works on the default interface because FFmpeg fallsback to the
+    # It works on the ffmpeg interface because FFmpeg falls back to the
# CPU, while the BETA interface doesn't.
pytest.skip("Asset not supported by NVDEC")
- decoder = VideoDecoder(asset.path, device=device)
+ decoder, _ = make_video_decoder(asset.path, device=device)
decoder.get_frame_at(10)
def setup_frame_mappings(tmp_path, file, stream_index):
@@ -1346,13 +1363,12 @@ def test_custom_frame_mappings_json_and_bytes(
if hasattr(custom_frame_mappings, "read")
else contextlib.nullcontext()
) as custom_frame_mappings:
- decoder = VideoDecoder(
+ decoder, device = make_video_decoder(
NASA_VIDEO.path,
stream_index=stream_index,
device=device,
custom_frame_mappings=custom_frame_mappings,
)
- device, _ = unsplit_device_str(device)
frame_0 = decoder.get_frame_at(0)
frame_5 = decoder.get_frame_at(5)
assert_frames_equal(
@@ -1483,9 +1499,8 @@ def test_beta_cuda_interface_get_frame_at(
pytest.skip("AV1 CUDA not supported internally")
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
- beta_decoder = VideoDecoder(
- asset.path, device="cuda:0:beta", seek_mode=seek_mode
- )
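+    # Note: set_cuda_backend only needs to wrap decoder *creation*; the
+    # decoder keeps using that backend for all subsequent decoding.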
+ with set_cuda_backend("beta"):
+ beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
assert ref_decoder.metadata == beta_decoder.metadata
@@ -1531,9 +1546,8 @@ def test_beta_cuda_interface_get_frames_at(
pytest.skip("AV1 CUDA not supported internally")
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
- beta_decoder = VideoDecoder(
- asset.path, device="cuda:0:beta", seek_mode=seek_mode
- )
+ with set_cuda_backend("beta"):
+ beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
assert ref_decoder.metadata == beta_decoder.metadata
@@ -1577,9 +1591,8 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
pytest.skip("AV1 CUDA not supported internally")
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
- beta_decoder = VideoDecoder(
- asset.path, device="cuda:0:beta", seek_mode=seek_mode
- )
+ with set_cuda_backend("beta"):
+ beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
assert ref_decoder.metadata == beta_decoder.metadata
@@ -1620,9 +1633,8 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
pytest.skip("AV1 CUDA not supported internally")
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
- beta_decoder = VideoDecoder(
- asset.path, device="cuda:0:beta", seek_mode=seek_mode
- )
+ with set_cuda_backend("beta"):
+ beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
assert ref_decoder.metadata == beta_decoder.metadata
@@ -1664,9 +1676,8 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode):
pytest.skip("AV1 CUDA not supported internally")
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
- beta_decoder = VideoDecoder(
- asset.path, device="cuda:0:beta", seek_mode=seek_mode
- )
+ with set_cuda_backend("beta"):
+ beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
assert ref_decoder.metadata == beta_decoder.metadata
@@ -1690,17 +1701,19 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode):
assert beta_frame.duration_seconds == ref_frame.duration_seconds
@needs_cuda
- def test_beta_cuda_interface_small_h265(self):
- # Test to illustrate current difference in behavior between the BETA and
- # the default interface: this video isn't supported by NVDEC, but in the
- # default interface, FFMPEG fallsback to the CPU while we don't.
+ def test_beta_cuda_interface_cpu_fallback(self):
+ # Non-regression test for the CPU fallback behavior of the BETA CUDA
+ # interface.
+    # We know that the H265_VIDEO asset isn't supported by NVDEC: its
+    # dimensions are too small. We also know that the FFmpeg CUDA interface
+    # falls back to the CPU path in such cases. We assert that we fall back
+    # to the CPU path, too.
+
+ ffmpeg = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
+ with set_cuda_backend("beta"):
+ beta = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
- VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)
- with pytest.raises(
- RuntimeError,
- match="Video is too small in at least one dimension. Provided: 128x128 vs supported:144x144",
- ):
- VideoDecoder(H265_VIDEO.path, device="cuda:0:beta").get_frame_at(0)
+ torch.testing.assert_close(ffmpeg.data, beta.data, rtol=0, atol=0)
@needs_cuda
def test_beta_cuda_interface_error(self):
@@ -1725,21 +1738,10 @@ def test_set_cuda_backend(self):
with set_cuda_backend("BETA"):
assert _get_cuda_backend() == "beta"
- def assert_decoder_uses(decoder, *, expected_backend):
- # Assert that a decoder instance is using a given backend.
- #
- # We know H265_VIDEO fails on the BETA backend while it works on the
- # ffmpeg one.
- if expected_backend == "ffmpeg":
- decoder.get_frame_at(0) # this would fail if this was BETA
- else:
- with pytest.raises(RuntimeError, match="Video is too small"):
- decoder.get_frame_at(0)
-
# Check that the default is the ffmpeg backend
assert _get_cuda_backend() == "ffmpeg"
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
- assert_decoder_uses(dec, expected_backend="ffmpeg")
+ assert _core._get_backend_details(dec._decoder).startswith("FFmpeg CUDA")
# Check the setting "beta" effectively uses the BETA backend.
     # We also show that the setting affects decoder creation only. When the decoder
@@ -1748,9 +1750,9 @@ def assert_decoder_uses(decoder, *, expected_backend):
with set_cuda_backend("beta"):
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
assert _get_cuda_backend() == "ffmpeg"
- assert_decoder_uses(dec, expected_backend="beta")
+ assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA")
with set_cuda_backend("ffmpeg"):
- assert_decoder_uses(dec, expected_backend="beta")
+ assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA")
# Hacky way to ensure passing "cuda:1" is supported by both backends. We
# just check that there's an error when passing cuda:N where N is too
diff --git a/test/test_encoders.py b/test/test_encoders.py
index f8b5b3519..c5946654d 100644
--- a/test/test_encoders.py
+++ b/test/test_encoders.py
@@ -16,6 +16,7 @@
from .utils import (
assert_tensor_close_on_at_least,
get_ffmpeg_major_version,
+ get_ffmpeg_minor_version,
in_fbcode,
IS_WINDOWS,
NASA_AUDIO_MP3,
@@ -23,6 +24,11 @@
TestContainerFile,
)
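+# True on Windows builds running FFmpeg 7.0 or older. Some encoder error
+# messages differ in that environment, so the tests below branch on this flag.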
+IS_WINDOWS_WITH_FFMPEG_LE_70 = IS_WINDOWS and (
+ get_ffmpeg_major_version() < 7
+ or (get_ffmpeg_major_version() == 7 and get_ffmpeg_minor_version() == 0)
+)
+
@pytest.fixture
def with_ffmpeg_debug_logs():
@@ -155,7 +161,11 @@ def test_bad_input_parametrized(self, method, tmp_path):
avcodec_open2_failed_msg = "avcodec_open2 failed: Invalid argument"
with pytest.raises(
RuntimeError,
- match=avcodec_open2_failed_msg if IS_WINDOWS else "invalid sample rate=10",
+ match=(
+ avcodec_open2_failed_msg
+ if IS_WINDOWS_WITH_FFMPEG_LE_70
+ else "invalid sample rate=10"
+ ),
):
getattr(decoder, method)(**valid_params)
@@ -164,14 +174,18 @@ def test_bad_input_parametrized(self, method, tmp_path):
)
with pytest.raises(
RuntimeError,
- match=avcodec_open2_failed_msg if IS_WINDOWS else "invalid sample rate=10",
+ match=(
+ avcodec_open2_failed_msg
+ if IS_WINDOWS_WITH_FFMPEG_LE_70
+ else "invalid sample rate=10"
+ ),
):
getattr(decoder, method)(sample_rate=10, **valid_params)
with pytest.raises(
RuntimeError,
match=(
avcodec_open2_failed_msg
- if IS_WINDOWS
+ if IS_WINDOWS_WITH_FFMPEG_LE_70
else "invalid sample rate=99999999"
),
):
@@ -192,7 +206,7 @@ def test_bad_input_parametrized(self, method, tmp_path):
for num_channels in (0, 3):
match = (
avcodec_open2_failed_msg
- if IS_WINDOWS
+ if IS_WINDOWS_WITH_FFMPEG_LE_70
else re.escape(
f"Desired number of channels ({num_channels}) is not supported"
)
@@ -316,7 +330,7 @@ def test_against_cli(
else:
rtol, atol = None, None
- if IS_WINDOWS and format == "mp3":
+ if IS_WINDOWS_WITH_FFMPEG_LE_70 and format == "mp3":
# We're getting a "Could not open input file" on Windows mp3 files when decoding.
# TODO: https://github.com/pytorch/torchcodec/issues/837
return
@@ -370,7 +384,7 @@ def test_against_to_file(
else:
raise ValueError(f"Unknown method: {method}")
- if not (IS_WINDOWS and format == "mp3"):
+ if not (IS_WINDOWS_WITH_FFMPEG_LE_70 and format == "mp3"):
# We're getting a "Could not open input file" on Windows mp3 files when decoding.
# TODO: https://github.com/pytorch/torchcodec/issues/837
torch.testing.assert_close(
diff --git a/test/test_ops.py b/test/test_ops.py
index fddd4043c..627829689 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-import contextlib
import io
import os
from functools import partial
@@ -29,6 +28,8 @@
create_from_tensor,
encode_audio_to_file,
encode_video_to_file,
+ encode_video_to_file_like,
+ encode_video_to_tensor,
get_ffmpeg_library_versions,
get_frame_at_index,
get_frame_at_pts,
@@ -41,6 +42,7 @@
get_next_frame,
seek_to_pts,
)
+from torchcodec.decoders import VideoDecoder
from .utils import (
all_supported_devices,
@@ -607,103 +609,6 @@ def test_color_conversion_library(self, color_conversion_library):
)
assert_frames_equal(frame_time6, reference_frame_time6)
- # We choose arbitrary values for width and height scaling to get better
- # test coverage. Some pairs upscale the image while others downscale it.
- @pytest.mark.parametrize(
- "width_scaling_factor,height_scaling_factor",
- ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
- )
- @pytest.mark.parametrize("input_video", [NASA_VIDEO])
- def test_color_conversion_library_with_scaling(
- self, input_video, width_scaling_factor, height_scaling_factor
- ):
- decoder = create_from_file(str(input_video.path))
- add_video_stream(decoder)
- metadata = get_json_metadata(decoder)
- metadata_dict = json.loads(metadata)
- assert metadata_dict["width"] == input_video.width
- assert metadata_dict["height"] == input_video.height
-
- target_height = int(input_video.height * height_scaling_factor)
- target_width = int(input_video.width * width_scaling_factor)
- if width_scaling_factor != 1.0:
- assert target_width != input_video.width
- if height_scaling_factor != 1.0:
- assert target_height != input_video.height
-
- filtergraph_decoder = create_from_file(str(input_video.path))
- _add_video_stream(
- filtergraph_decoder,
- transform_specs=f"resize, {target_height}, {target_width}",
- color_conversion_library="filtergraph",
- )
- filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
-
- swscale_decoder = create_from_file(str(input_video.path))
- _add_video_stream(
- swscale_decoder,
- transform_specs=f"resize, {target_height}, {target_width}",
- color_conversion_library="swscale",
- )
- swscale_frame0, _, _ = get_next_frame(swscale_decoder)
- assert_frames_equal(filtergraph_frame0, swscale_frame0)
- assert filtergraph_frame0.shape == (3, target_height, target_width)
-
- @needs_cuda
- def test_scaling_on_cuda_fails(self):
- decoder = create_from_file(str(NASA_VIDEO.path))
- with pytest.raises(
- RuntimeError,
- match="Transforms are only supported for CPU devices.",
- ):
- add_video_stream(decoder, device="cuda", transform_specs="resize, 100, 100")
-
- def test_transform_fails(self):
- decoder = create_from_file(str(NASA_VIDEO.path))
- with pytest.raises(
- RuntimeError,
- match="Invalid transform spec",
- ):
- add_video_stream(decoder, transform_specs=";")
-
- with pytest.raises(
- RuntimeError,
- match="Invalid transform name",
- ):
- add_video_stream(decoder, transform_specs="invalid, 1, 2")
-
- def test_resize_transform_fails(self):
- decoder = create_from_file(str(NASA_VIDEO.path))
- with pytest.raises(
- RuntimeError,
- match="must have 3 elements",
- ):
- add_video_stream(decoder, transform_specs="resize, 100, 100, 100")
-
- with pytest.raises(
- RuntimeError,
- match="must be a positive integer",
- ):
- add_video_stream(decoder, transform_specs="resize, -10, 100")
-
- with pytest.raises(
- RuntimeError,
- match="must be a positive integer",
- ):
- add_video_stream(decoder, transform_specs="resize, 100, 0")
-
- with pytest.raises(
- RuntimeError,
- match="cannot be converted to an int",
- ):
- add_video_stream(decoder, transform_specs="resize, blah, 100")
-
- with pytest.raises(
- RuntimeError,
- match="out of range",
- ):
- add_video_stream(decoder, transform_specs="resize, 100, 1000000000000")
-
@pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
@pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
def test_color_conversion_library_with_dimension_order(
@@ -743,86 +648,6 @@ def test_color_conversion_library_with_dimension_order(
assert frames.shape[1:] == expected_shape
assert_frames_equal(frames[0], frame0_ref)
- @pytest.mark.parametrize(
- "width_scaling_factor,height_scaling_factor",
- ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
- )
- @pytest.mark.parametrize("width", [30, 32, 300])
- @pytest.mark.parametrize("height", [128])
- def test_color_conversion_library_with_generated_videos(
- self, tmp_path, width, height, width_scaling_factor, height_scaling_factor
- ):
-
- # We consider filtergraph to be the reference color conversion library.
- # However the video decoder sometimes uses swscale as that is faster.
- # The exact color conversion library used is an implementation detail
- # of the video decoder and depends on the video's width.
- #
- # In this test we compare the output of filtergraph (which is the
- # reference) with the output of the video decoder (which may use
- # swscale if it chooses for certain video widths) to make sure they are
- # always the same.
- video_path = f"{tmp_path}/frame_numbers_{width}x{height}.mp4"
- # We don't specify a particular encoder because the ffmpeg binary could
- # be configured with different encoders. For the purposes of this test,
- # the actual encoder is irrelevant.
- with contextlib.ExitStack() as stack:
- ffmpeg_cli = "ffmpeg"
-
- if os.environ.get("IN_FBCODE_TORCHCODEC") == "1":
- import importlib.resources
-
- ffmpeg_cli = stack.enter_context(
- importlib.resources.path(__package__, "ffmpeg")
- )
-
- command = [
- ffmpeg_cli,
- "-y",
- "-f",
- "lavfi",
- "-i",
- "color=blue",
- "-pix_fmt",
- "yuv420p",
- "-s",
- f"{width}x{height}",
- "-frames:v",
- "1",
- video_path,
- ]
- subprocess.check_call(command)
-
- decoder = create_from_file(str(video_path))
- add_video_stream(decoder)
- metadata = get_json_metadata(decoder)
- metadata_dict = json.loads(metadata)
- assert metadata_dict["width"] == width
- assert metadata_dict["height"] == height
-
- target_height = int(height * height_scaling_factor)
- target_width = int(width * width_scaling_factor)
- if width_scaling_factor != 1.0:
- assert target_width != width
- if height_scaling_factor != 1.0:
- assert target_height != height
-
- filtergraph_decoder = create_from_file(str(video_path))
- _add_video_stream(
- filtergraph_decoder,
- transform_specs=f"resize, {target_height}, {target_width}",
- color_conversion_library="filtergraph",
- )
- filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
-
- auto_decoder = create_from_file(str(video_path))
- add_video_stream(
- auto_decoder,
- transform_specs=f"resize, {target_height}, {target_width}",
- )
- auto_frame0, _, _ = get_next_frame(auto_decoder)
- assert_frames_equal(filtergraph_frame0, auto_frame0)
-
@needs_cuda
def test_cuda_decoder(self):
decoder = create_from_file(str(NASA_VIDEO.path))
@@ -1327,7 +1152,8 @@ def test_bad_input(self, tmp_path):
class TestVideoEncoderOps:
-
+ # TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
+ # TODO-VideoEncoder: Parametrize test after moving to test_encoders
def test_bad_input(self, tmp_path):
output_file = str(tmp_path / ".mp4")
@@ -1378,15 +1204,25 @@ def test_bad_input(self, tmp_path):
filename="./bad/path.mp3",
)
- def decode(self, file_path) -> torch.Tensor:
- decoder = create_from_file(str(file_path), seek_mode="approximate")
- add_video_stream(decoder)
- frames, *_ = get_frames_in_range(decoder, start=0, stop=60)
- return frames
+ with pytest.raises(
+ RuntimeError,
+ match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
+ ):
+ encode_video_to_tensor(
+ frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
+ frame_rate=10,
+ format="bad_format",
+ )
+
+    def decode(self, source):
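+        # This returns a FrameBatch; callers access the raw frames via `.data`.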
+ return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
- @pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm"))
- def test_video_encoder_round_trip(self, tmp_path, format):
- # Test that decode(encode(decode(asset))) == decode(asset)
+ @pytest.mark.parametrize(
+ "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
+ )
+ @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
+ def test_video_encoder_round_trip(self, tmp_path, format, method):
+ # Test that decode(encode(decode(frames))) == decode(frames)
ffmpeg_version = get_ffmpeg_major_version()
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
# As a result, we skip the round trip test.
@@ -1398,15 +1234,36 @@ def test_video_encoder_round_trip(self, tmp_path, format):
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
):
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
- asset = TEST_SRC_2_720P
- source_frames = self.decode(str(asset.path)).data
+ source_frames = self.decode(TEST_SRC_2_720P.path).data
+
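+    # crf=0 asks the encoder for its highest quality (lossless for some
+    # codecs); codecs that don't support CRF ignore it.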
+    # The frame rate is fixed; the number of frames comes from decode().
+    params = dict(frame_rate=30, crf=0)
+ if method == "to_file":
+ encoded_path = str(tmp_path / f"encoder_output.{format}")
+ encode_video_to_file(
+ frames=source_frames,
+ filename=encoded_path,
+ **params,
+ )
+ round_trip_frames = self.decode(encoded_path).data
+ elif method == "to_tensor":
+ encoded_tensor = encode_video_to_tensor(
+ source_frames, format=format, **params
+ )
+ round_trip_frames = self.decode(encoded_tensor).data
+ elif method == "to_file_like":
+ file_like = io.BytesIO()
+ encode_video_to_file_like(
+ frames=source_frames,
+ format=format,
+ file_like=file_like,
+ **params,
+ )
+ round_trip_frames = self.decode(file_like.getvalue()).data
+ else:
+ raise ValueError(f"Unknown method: {method}")
- encoded_path = str(tmp_path / f"encoder_output.{format}")
- frame_rate = 30 # Frame rate is fixed with num frames decoded
- encode_video_to_file(
- frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0
- )
- round_trip_frames = self.decode(encoded_path).data
assert source_frames.shape == round_trip_frames.shape
assert source_frames.dtype == round_trip_frames.dtype
@@ -1422,24 +1279,75 @@ def test_video_encoder_round_trip(self, tmp_path, format):
assert psnr(s_frame, rt_frame) > 30
assert_close(s_frame, rt_frame, atol=atol, rtol=0)
+ @pytest.mark.parametrize(
+ "format",
+ (
+ "mov",
+ "mp4",
+ "avi",
+ "mkv",
+ "flv",
+ "gif",
+ pytest.param("webm", marks=pytest.mark.slow),
+ ),
+ )
+ @pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
+ def test_against_to_file(self, tmp_path, format, method):
+ # Test that to_file, to_tensor, and to_file_like produce the same results
+ ffmpeg_version = get_ffmpeg_major_version()
+ if format == "webm" and (
+ ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
+ ):
+ pytest.skip("Codec for webm is not available in this FFmpeg installation.")
+
+ source_frames = self.decode(TEST_SRC_2_720P.path).data
+ params = dict(frame_rate=30, crf=0)
+
+ encoded_file = tmp_path / f"output.{format}"
+ encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params)
+
+ if method == "to_tensor":
+ encoded_output = encode_video_to_tensor(
+ source_frames, format=format, **params
+ )
+ else: # to_file_like
+ file_like = io.BytesIO()
+ encode_video_to_file_like(
+ frames=source_frames,
+ file_like=file_like,
+ format=format,
+ **params,
+ )
+ encoded_output = file_like.getvalue()
+
+ torch.testing.assert_close(
+ self.decode(encoded_file).data,
+ self.decode(encoded_output).data,
+ atol=0,
+ rtol=0,
+ )
+
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
@pytest.mark.parametrize(
- "format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif")
+ "format",
+ (
+ "mov",
+ "mp4",
+ "avi",
+ "mkv",
+ "flv",
+ "gif",
+ pytest.param("webm", marks=pytest.mark.slow),
+ ),
)
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
ffmpeg_version = get_ffmpeg_major_version()
- if format == "webm":
- if ffmpeg_version == 4:
- pytest.skip(
- "Codec for webm is not available in the FFmpeg4 installation."
- )
- if IS_WINDOWS and ffmpeg_version in (6, 7):
- pytest.skip(
- "Codec for webm is not available in the FFmpeg6/7 installation on Windows."
- )
- asset = TEST_SRC_2_720P
- source_frames = self.decode(str(asset.path)).data
- frame_rate = 30
+ if format == "webm" and (
+ ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
+ ):
+ pytest.skip("Codec for webm is not available in this FFmpeg installation.")
+
+ source_frames = self.decode(TEST_SRC_2_720P.path).data
# Encode with FFmpeg CLI
temp_raw_path = str(tmp_path / "temp_input.raw")
@@ -1447,8 +1355,8 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())
ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}")
+ frame_rate = 30
crf = 0
- quality_params = ["-crf", str(crf)]
# Some codecs (ex. MPEG4) do not support CRF.
# Flags not supported by the selected codec will be ignored.
ffmpeg_cmd = [
@@ -1464,7 +1372,8 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
str(frame_rate),
"-i",
temp_raw_path,
- *quality_params,
+ "-crf",
+ str(crf),
ffmpeg_encoded_path,
]
subprocess.run(ffmpeg_cmd, check=True)
@@ -1496,6 +1405,82 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
ff_frame, enc_frame, percentage=percentage, atol=2
)
+ def test_to_file_like_custom_file_object(self):
+ """Test with a custom file-like object that implements write and seek."""
+
+ class CustomFileObject:
+ def __init__(self):
+ self._file = io.BytesIO()
+
+ def write(self, data):
+ return self._file.write(data)
+
+ def seek(self, offset, whence=0):
+ return self._file.seek(offset, whence)
+
+ def get_encoded_data(self):
+ return self._file.getvalue()
+
+ source_frames = self.decode(TEST_SRC_2_720P.path).data
+ file_like = CustomFileObject()
+ encode_video_to_file_like(
+ source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
+ )
+ decoded_samples = self.decode(file_like.get_encoded_data())
+
+ torch.testing.assert_close(
+ decoded_samples.data,
+ source_frames,
+ atol=2,
+ rtol=0,
+ )
+
+ def test_to_file_like_real_file(self, tmp_path):
+ """Test to_file_like with a real file opened in binary write mode."""
+ source_frames = self.decode(TEST_SRC_2_720P.path).data
+ file_path = tmp_path / "test_file_like.mp4"
+
+ with open(file_path, "wb") as file_like:
+ encode_video_to_file_like(
+ source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
+ )
+ decoded_samples = self.decode(str(file_path))
+
+ torch.testing.assert_close(
+ decoded_samples.data,
+ source_frames,
+ atol=2,
+ rtol=0,
+ )
+
+ def test_to_file_like_bad_methods(self):
+ source_frames = self.decode(TEST_SRC_2_720P.path).data
+
+ class NoWriteMethod:
+ def seek(self, offset, whence=0):
+ return 0
+
+ with pytest.raises(
+ RuntimeError, match="File like object must implement a write method"
+ ):
+ encode_video_to_file_like(
+ source_frames,
+ frame_rate=30,
+ format="mp4",
+ file_like=NoWriteMethod(),
+ )
+
+ class NoSeekMethod:
+ def write(self, data):
+ return len(data)
+
+ with pytest.raises(
+ RuntimeError, match="File like object must implement a seek method"
+ ):
+ encode_video_to_file_like(
+ source_frames, frame_rate=30, format="mp4", file_like=NoSeekMethod()
+ )
+
if __name__ == "__main__":
pytest.main()
diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py
new file mode 100644
index 000000000..8d1ba5e53
--- /dev/null
+++ b/test/test_transform_ops.py
@@ -0,0 +1,279 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+
+import json
+import os
+import subprocess
+
+import pytest
+
+import torch
+
+from torchcodec._core import (
+ _add_video_stream,
+ add_video_stream,
+ create_from_file,
+ get_frame_at_index,
+ get_json_metadata,
+ get_next_frame,
+)
+
+from torchvision.transforms import v2
+
+from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda
+
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+
+class TestVideoDecoderTransformOps:
+ # We choose arbitrary values for width and height scaling to get better
+ # test coverage. Some pairs upscale the image while others downscale it.
+ @pytest.mark.parametrize(
+ "width_scaling_factor,height_scaling_factor",
+ ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
+ )
+ @pytest.mark.parametrize("input_video", [NASA_VIDEO])
+ def test_color_conversion_library_with_scaling(
+ self, input_video, width_scaling_factor, height_scaling_factor
+ ):
+ decoder = create_from_file(str(input_video.path))
+ add_video_stream(decoder)
+ metadata = get_json_metadata(decoder)
+ metadata_dict = json.loads(metadata)
+ assert metadata_dict["width"] == input_video.width
+ assert metadata_dict["height"] == input_video.height
+
+ target_height = int(input_video.height * height_scaling_factor)
+ target_width = int(input_video.width * width_scaling_factor)
+ if width_scaling_factor != 1.0:
+ assert target_width != input_video.width
+ if height_scaling_factor != 1.0:
+ assert target_height != input_video.height
+
+ filtergraph_decoder = create_from_file(str(input_video.path))
+ _add_video_stream(
+ filtergraph_decoder,
+ transform_specs=f"resize, {target_height}, {target_width}",
+ color_conversion_library="filtergraph",
+ )
+ filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
+
+ swscale_decoder = create_from_file(str(input_video.path))
+ _add_video_stream(
+ swscale_decoder,
+ transform_specs=f"resize, {target_height}, {target_width}",
+ color_conversion_library="swscale",
+ )
+ swscale_frame0, _, _ = get_next_frame(swscale_decoder)
+ assert_frames_equal(filtergraph_frame0, swscale_frame0)
+ assert filtergraph_frame0.shape == (3, target_height, target_width)
+
+ @pytest.mark.parametrize(
+ "width_scaling_factor,height_scaling_factor",
+ ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
+ )
+ @pytest.mark.parametrize("width", [30, 32, 300])
+ @pytest.mark.parametrize("height", [128])
+ def test_color_conversion_library_with_generated_videos(
+ self, tmp_path, width, height, width_scaling_factor, height_scaling_factor
+ ):
+ # We consider filtergraph to be the reference color conversion library.
+ # However the video decoder sometimes uses swscale as that is faster.
+ # The exact color conversion library used is an implementation detail
+ # of the video decoder and depends on the video's width.
+ #
+ # In this test we compare the output of filtergraph (which is the
+ # reference) with the output of the video decoder (which may use
+ # swscale if it chooses for certain video widths) to make sure they are
+ # always the same.
+ video_path = f"{tmp_path}/frame_numbers_{width}x{height}.mp4"
+ # We don't specify a particular encoder because the ffmpeg binary could
+ # be configured with different encoders. For the purposes of this test,
+ # the actual encoder is irrelevant.
+ with contextlib.ExitStack() as stack:
+ ffmpeg_cli = "ffmpeg"
+
+ if os.environ.get("IN_FBCODE_TORCHCODEC") == "1":
+ import importlib.resources
+
+ ffmpeg_cli = stack.enter_context(
+ importlib.resources.path(__package__, "ffmpeg")
+ )
+
+ command = [
+ ffmpeg_cli,
+ "-y",
+ "-f",
+ "lavfi",
+ "-i",
+ "color=blue",
+ "-pix_fmt",
+ "yuv420p",
+ "-s",
+ f"{width}x{height}",
+ "-frames:v",
+ "1",
+ video_path,
+ ]
+ subprocess.check_call(command)
+
+ decoder = create_from_file(str(video_path))
+ add_video_stream(decoder)
+ metadata = get_json_metadata(decoder)
+ metadata_dict = json.loads(metadata)
+ assert metadata_dict["width"] == width
+ assert metadata_dict["height"] == height
+
+ target_height = int(height * height_scaling_factor)
+ target_width = int(width * width_scaling_factor)
+ if width_scaling_factor != 1.0:
+ assert target_width != width
+ if height_scaling_factor != 1.0:
+ assert target_height != height
+
+ filtergraph_decoder = create_from_file(str(video_path))
+ _add_video_stream(
+ filtergraph_decoder,
+ transform_specs=f"resize, {target_height}, {target_width}",
+ color_conversion_library="filtergraph",
+ )
+ filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder)
+
+ auto_decoder = create_from_file(str(video_path))
+ add_video_stream(
+ auto_decoder,
+ transform_specs=f"resize, {target_height}, {target_width}",
+ )
+ auto_frame0, _, _ = get_next_frame(auto_decoder)
+ assert_frames_equal(filtergraph_frame0, auto_frame0)
+
+ @needs_cuda
+ def test_scaling_on_cuda_fails(self):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ with pytest.raises(
+ RuntimeError,
+ match="Transforms are only supported for CPU devices.",
+ ):
+ add_video_stream(decoder, device="cuda", transform_specs="resize, 100, 100")
+
+ def test_transform_fails(self):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ with pytest.raises(
+ RuntimeError,
+ match="Invalid transform spec",
+ ):
+ add_video_stream(decoder, transform_specs=";")
+
+ with pytest.raises(
+ RuntimeError,
+ match="Invalid transform name",
+ ):
+ add_video_stream(decoder, transform_specs="invalid, 1, 2")
+
+ def test_resize_transform_fails(self):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ with pytest.raises(
+ RuntimeError,
+ match="must have 3 elements",
+ ):
+ add_video_stream(decoder, transform_specs="resize, 100, 100, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="must be a positive integer",
+ ):
+ add_video_stream(decoder, transform_specs="resize, -10, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="must be a positive integer",
+ ):
+ add_video_stream(decoder, transform_specs="resize, 100, 0")
+
+ with pytest.raises(
+ RuntimeError,
+ match="cannot be converted to an int",
+ ):
+ add_video_stream(decoder, transform_specs="resize, blah, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="out of range",
+ ):
+ add_video_stream(decoder, transform_specs="resize, 100, 1000000000000")
+
+ def test_crop_transform(self):
+ # Note that filtergraph accepts dimensions as (w, h) and we accept them as (h, w).
+ width = 300
+ height = 200
+ x = 50
+ y = 35
+ crop_spec = f"crop, {height}, {width}, {x}, {y}"
+ crop_filtergraph = f"crop={width}:{height}:{x}:{y}:exact=1"
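+        # For these values: crop_spec == "crop, 200, 300, 50, 35" and
+        # crop_filtergraph == "crop=300:200:50:35:exact=1".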
+ expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width)
+
+ decoder_crop = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder_crop, transform_specs=crop_spec)
+
+ decoder_full = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder_full)
+
+ for frame_index in [0, 15, 200, 389]:
+ frame, *_ = get_frame_at_index(decoder_crop, frame_index=frame_index)
+ frame_ref = NASA_VIDEO.get_frame_data_by_index(
+ frame_index, filters=crop_filtergraph
+ )
+
+ frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index)
+ frame_tv = v2.functional.crop(
+ frame_full, top=y, left=x, height=height, width=width
+ )
+
+ assert frame.shape == expected_shape
+ assert frame_ref.shape == expected_shape
+ assert frame_tv.shape == expected_shape
+
+ assert_frames_equal(frame, frame_tv)
+ assert_frames_equal(frame, frame_ref)
+
+ def test_crop_transform_fails(self):
+
+ with pytest.raises(
+ RuntimeError,
+ match="must have 5 elements",
+ ):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder, transform_specs="crop, 100, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="must be a positive integer",
+ ):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder, transform_specs="crop, -10, 100, 100, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="cannot be converted to an int",
+ ):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder, transform_specs="crop, 100, 100, blah, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="x position out of bounds",
+ ):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100")
+
+ with pytest.raises(
+ RuntimeError,
+ match="y position out of bounds",
+ ):
+ decoder = create_from_file(str(NASA_VIDEO.path))
+ add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100")
diff --git a/test/utils.py b/test/utils.py
index 7c91f307c..cbd6a5bf4 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -14,6 +14,7 @@
import torch
from torchcodec._core import get_ffmpeg_library_versions
+from torchcodec.decoders import set_cuda_backend, VideoDecoder
from torchcodec.decoders._video_decoder import _read_custom_frame_mappings
IS_WINDOWS = sys.platform in ("win32", "cygwin")
@@ -26,40 +27,76 @@ def needs_cuda(test_item):
return pytest.mark.needs_cuda(test_item)
+# This is a special device string that we use to test the "beta" CUDA backend.
+# It only exists here, in this test utils file. Public and core APIs have no
+# idea that this is how we're testing them. That is, it's not a supported
+# `device` parameter for the VideoDecoder or for the _core APIs.
+# Tests using all_supported_devices() will get this device string, and they
+# need to clean it up by calling either make_video_decoder for VideoDecoder, or
+# unsplit_device_str for core APIs.
+_CUDA_BETA_DEVICE_STR = "cuda:beta"
+
+
def all_supported_devices():
return (
"cpu",
pytest.param("cuda", marks=pytest.mark.needs_cuda),
- pytest.param("cuda:0:beta", marks=pytest.mark.needs_cuda),
+ pytest.param(_CUDA_BETA_DEVICE_STR, marks=pytest.mark.needs_cuda),
)
def unsplit_device_str(device_str: str) -> str:
# helper meant to be used as
# device, device_variant = unsplit_device_str(device)
- # when `device` comes from all_supported_devices() and may be "cuda:0:beta".
+ # when `device` comes from all_supported_devices() and may be _CUDA_BETA_DEVICE_STR.
# It is used:
- # - before calling `.to(device)` where device can't be "cuda:0:beta"
+ # - before calling `.to(device)` where device can't be _CUDA_BETA_DEVICE_STR.
# - before calling add_video_stream(device=device, device_variant=device_variant)
- #
- # TODONVDEC P2: Find a less clunky way to test the BETA CUDA interface. It
- # will ultimately depend on how we want to publicly expose it.
- if device_str == "cuda:0:beta":
+ if device_str == _CUDA_BETA_DEVICE_STR:
return "cuda", "beta"
else:
- return device_str, "default"
+ return device_str, "ffmpeg"
-def get_ffmpeg_major_version():
+def make_video_decoder(*args, **kwargs) -> tuple[VideoDecoder, str]:
+ # Helper to create a VideoDecoder with the right cuda backend if needed.
+ # kwargs is expected to have a "device" key which comes from
+ # all_supported_devices(), and can be _CUDA_BETA_DEVICE_STR.
+ device = kwargs.pop("device", "cpu")
+ if device == _CUDA_BETA_DEVICE_STR:
+ clean_device, backend = "cuda", "beta"
+ else:
+ clean_device, backend = device, "ffmpeg"
+
+ # set_cuda_backend is a no-op if the device is "cpu", so we can use it
+ # unconditionally.
+ with set_cuda_backend(backend):
+ dec = VideoDecoder(*args, **kwargs, device=clean_device)
+
+ return dec, clean_device
+
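+# Typical use in a test, where `device` comes from all_supported_devices():
+#
+#     decoder, device = make_video_decoder(NASA_VIDEO.path, device=device)
+#     ref_frame = NASA_VIDEO.get_frame_data_by_index(0).to(device)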
+
+def _get_ffmpeg_version_string():
ffmpeg_version = get_ffmpeg_library_versions()["ffmpeg_version"]
# When building FFmpeg from source there can be a `n` prefix in the version
# string. This is quite brittle as we're using av_version_info(), which has
# no stable format. See https://github.com/pytorch/torchcodec/issues/100
if ffmpeg_version.startswith("n"):
ffmpeg_version = ffmpeg_version.removeprefix("n")
+
+ return ffmpeg_version
+
+
+def get_ffmpeg_major_version():
+ ffmpeg_version = _get_ffmpeg_version_string()
return int(ffmpeg_version.split(".")[0])
+def get_ffmpeg_minor_version():
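+    # e.g. returns 0 for an FFmpeg version string like "7.0.1"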
+ ffmpeg_version = _get_ffmpeg_version_string()
+ return int(ffmpeg_version.split(".")[1])
+
+
def cuda_version_used_for_building_torch() -> Optional[tuple[int, int]]:
# Return the CUDA version that was used to build PyTorch. That's not always
# the same as the CUDA version that is currently installed on the running
@@ -130,6 +167,12 @@ def assert_tensor_close_on_at_least(
)
+# We embed filtergraph expressions in filenames, but they contain characters that
+# some filesystems don't like. We turn all special characters into underscores.
+def sanitize_filtergraph_expression(expression: str) -> str:
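+    # e.g. "crop=300:200:50:35:exact=1" -> "crop_300_200_50_35_exact_1"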
+ return "".join(c if c.isalnum() else "_" for c in expression)
+
+
def in_fbcode() -> bool:
return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"
@@ -328,16 +371,32 @@ def empty_duration_seconds(self) -> torch.Tensor:
class TestVideo(TestContainerFile):
"""Base class for the *video* streams of a video container"""
+ def get_base_path_by_index(
+ self, idx: int, *, stream_index: int, filters: Optional[str] = None
+ ) -> pathlib.Path:
+ stream_and_frame = f"stream{stream_index}.frame{idx:06d}"
+ if filters is not None:
+ full_name = f"{self.filename}.{sanitize_filtergraph_expression(filters)}.{stream_and_frame}"
+ else:
+ full_name = f"{self.filename}.{stream_and_frame}"
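+        # e.g. "video.mp4.crop_300_200_50_35_exact_1.stream0.frame000005" with
+        # filters, "video.mp4.stream0.frame000005" without ("video.mp4" being an
+        # illustrative filename).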
+
+ return _get_file_path(full_name)
+
def get_frame_data_by_index(
- self, idx: int, *, stream_index: Optional[int] = None
+ self,
+ idx: int,
+ *,
+ stream_index: Optional[int] = None,
+ filters: Optional[str] = None,
) -> torch.Tensor:
if stream_index is None:
stream_index = self.default_stream_index
- file_path = _get_file_path(
- f"{self.filename}.stream{stream_index}.frame{idx:06d}.pt"
+ base_path = self.get_base_path_by_index(
+ idx, stream_index=stream_index, filters=filters
)
- return torch.load(file_path, weights_only=True).permute(2, 0, 1)
+ tensor_file_path = f"{base_path}.pt"
+ return torch.load(tensor_file_path, weights_only=True).permute(2, 0, 1)
def get_frame_data_by_range(
self,
@@ -434,6 +493,114 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
)
+H265_VIDEO = TestVideo(
+ filename="h265_video.mp4",
+ default_stream_index=0,
+ # This metadata is extracted manually.
+ # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/h265_video.mp4 > out.json
+ stream_infos={
+ 0: TestVideoStreamInfo(width=128, height=128, num_color_channels=3),
+ },
+ frames={
+ 0: {
+ 6: TestFrameInfo(pts_seconds=0.6, duration_seconds=0.1),
+ },
+ },
+)
+
+AV1_VIDEO = TestVideo(
+ filename="av1_video.mkv",
+ default_stream_index=0,
+ # This metadata is extracted manually.
+ # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json
+ stream_infos={
+ 0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3),
+ },
+ frames={
+ 0: {
+ 10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000),
+ },
+ },
+)
+
+
+# This is a BT.709 full range video, generated with:
+# ffmpeg -f lavfi -i testsrc2=duration=1:size=1920x720:rate=30 \
+# -c:v libx264 -pix_fmt yuv420p -color_primaries bt709 -color_trc bt709 \
+# -colorspace bt709 -color_range pc bt709_full_range.mp4
+#
+# We can confirm the color space and color range with:
+# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt709_full_range.mp4
+# color_range=pc
+# color_space=bt709
+# color_transfer=bt709
+# color_primaries=bt709
+BT709_FULL_RANGE = TestVideo(
+ filename="bt709_full_range.mp4",
+ default_stream_index=0,
+ stream_infos={
+ 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+ },
+ frames={0: {}}, # Not needed for now
+)
+
+# ffmpeg -f lavfi -i testsrc2=duration=2:size=1280x720:rate=30 -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p -b:v 2500k -r 30 -movflags +faststart output_720p_2s.mp4
+TEST_SRC_2_720P = TestVideo(
+ filename="testsrc2.mp4",
+ default_stream_index=0,
+ stream_infos={
+ 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+ },
+ frames={0: {}}, # Not needed for now
+)
+# ffmpeg -f lavfi -i testsrc2=duration=10:size=1280x720:rate=30 -c:v libx265 -crf 23 -preset medium output.mp4
+TEST_SRC_2_720P_H265 = TestVideo(
+ filename="testsrc2_h265.mp4",
+ default_stream_index=0,
+ stream_infos={
+ 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+ },
+ frames={0: {}}, # Not needed for now
+)
+
+# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx-vp9 -b:v 1M output_vp9.webm
+TEST_SRC_2_720P_VP9 = TestVideo(
+ filename="testsrc2_vp9.webm",
+ default_stream_index=0,
+ stream_infos={
+ 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+ },
+ frames={0: {}}, # Not needed for now
+)
+
+# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx -b:v 1M output_vp8.webm
+TEST_SRC_2_720P_VP8 = TestVideo(
+ filename="testsrc2_vp8.webm",
+ default_stream_index=0,
+ stream_infos={
+ 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+ },
+ frames={0: {}}, # Not needed for now
+)
+
+# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v mpeg4 -q:v 5 output_mpeg4.avi
+TEST_SRC_2_720P_MPEG4 = TestVideo(
+ filename="testsrc2_mpeg4.avi",
+ default_stream_index=0,
+ stream_infos={
+ 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
+ },
+ frames={0: {}}, # Not needed for now
+)
+
+
+def supports_approximate_mode(asset: TestVideo) -> bool:
+ # Those are missing the `duration` field so they fail in approximate mode (on all devices).
+ # TODO: we should address this, see
+ # https://github.com/meta-pytorch/torchcodec/issues/945
+ return asset not in (AV1_VIDEO, TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8)
+
+
@dataclass
class TestAudio(TestContainerFile):
"""Base class for the *audio* streams of a container (potentially a video),
@@ -647,110 +814,3 @@ def sample_format(self) -> str:
)
},
)
-
-H265_VIDEO = TestVideo(
- filename="h265_video.mp4",
- default_stream_index=0,
- # This metadata is extracted manually.
- # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/h265_video.mp4 > out.json
- stream_infos={
- 0: TestVideoStreamInfo(width=128, height=128, num_color_channels=3),
- },
- frames={
- 0: {
- 6: TestFrameInfo(pts_seconds=0.6, duration_seconds=0.1),
- },
- },
-)
-
-AV1_VIDEO = TestVideo(
- filename="av1_video.mkv",
- default_stream_index=0,
- # This metadata is extracted manually.
- # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json
- stream_infos={
- 0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3),
- },
- frames={
- 0: {
- 10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000),
- },
- },
-)
-
-
-# This is a BT.709 full range video, generated with:
-# ffmpeg -f lavfi -i testsrc2=duration=1:size=1920x720:rate=30 \
-# -c:v libx264 -pix_fmt yuv420p -color_primaries bt709 -color_trc bt709 \
-# -colorspace bt709 -color_range pc bt709_full_range.mp4
-#
-# We can confirm the color space and color range with:
-# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt709_full_range.mp4
-# color_range=pc
-# color_space=bt709
-# color_transfer=bt709
-# color_primaries=bt709
-BT709_FULL_RANGE = TestVideo(
- filename="bt709_full_range.mp4",
- default_stream_index=0,
- stream_infos={
- 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
- },
- frames={0: {}}, # Not needed for now
-)
-
-# ffmpeg -f lavfi -i testsrc2=duration=2:size=1280x720:rate=30 -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p -b:v 2500k -r 30 -movflags +faststart output_720p_2s.mp4
-TEST_SRC_2_720P = TestVideo(
- filename="testsrc2.mp4",
- default_stream_index=0,
- stream_infos={
- 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
- },
- frames={0: {}}, # Not needed for now
-)
-# ffmpeg -f lavfi -i testsrc2=duration=10:size=1280x720:rate=30 -c:v libx265 -crf 23 -preset medium output.mp4
-TEST_SRC_2_720P_H265 = TestVideo(
- filename="testsrc2_h265.mp4",
- default_stream_index=0,
- stream_infos={
- 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
- },
- frames={0: {}}, # Not needed for now
-)
-
-# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx-vp9 -b:v 1M output_vp9.webm
-TEST_SRC_2_720P_VP9 = TestVideo(
- filename="testsrc2_vp9.webm",
- default_stream_index=0,
- stream_infos={
- 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
- },
- frames={0: {}}, # Not needed for now
-)
-
-# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx -b:v 1M output_vp8.webm
-TEST_SRC_2_720P_VP8 = TestVideo(
- filename="testsrc2_vp8.webm",
- default_stream_index=0,
- stream_infos={
- 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
- },
- frames={0: {}}, # Not needed for now
-)
-
-# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v mpeg4 -q:v 5 output_mpeg4.avi
-TEST_SRC_2_720P_MPEG4 = TestVideo(
- filename="testsrc2_mpeg4.avi",
- default_stream_index=0,
- stream_infos={
- 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
- },
- frames={0: {}}, # Not needed for now
-)
-
-
-def supports_approximate_mode(asset: TestVideo) -> bool:
- # TODONVDEC P2: open an issue about his. That's actually not related to
- # NVDEC at all, those don't support approximate mode because they don't set
- # a duration. CPU decoder fails too!
- return asset not in (AV1_VIDEO, TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8)
diff --git a/version.txt b/version.txt
index a3df0a695..c18d72be3 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.8.0
+0.8.1
\ No newline at end of file