diff --git a/.github/workflows/linux_wheel.yaml b/.github/workflows/linux_wheel.yaml index 099a905c4..25a5a564e 100644 --- a/.github/workflows/linux_wheel.yaml +++ b/.github/workflows/linux_wheel.yaml @@ -72,10 +72,14 @@ jobs: name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 path: pytorch/torchcodec/dist/ - name: Setup conda env - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - miniconda-version: "latest" + # Using miniforge instead of miniconda ensures that the default + # conda channel is conda-forge instead of main/default. This ensures + # ABI consistency between dependencies: + # https://conda-forge.org/docs/user/transitioning_from_defaults/ + miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} - name: Update pip diff --git a/.github/workflows/reference_resources.yaml b/.github/workflows/reference_resources.yaml index 25353d70c..8f97378f1 100644 --- a/.github/workflows/reference_resources.yaml +++ b/.github/workflows/reference_resources.yaml @@ -14,7 +14,40 @@ defaults: shell: bash -l -eo pipefail {0} jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-xpu: disable + with-rocm: disable + with-cuda: disable + build-python-only: "disable" + + build: + needs: generate-matrix + strategy: + fail-fast: false + name: Build and Upload Linux wheel + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + with: + repository: meta-pytorch/torchcodec + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: packaging/pre_build_script.sh + post-script: packaging/post_build_script.sh + smoke-test-script: packaging/fake_smoke_test.py + package-name: torchcodec + trigger-event: ${{ github.event_name }} + build-platform: "python-build-package" + build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation" + test-reference-resource-generation: + needs: build runs-on: ubuntu-latest strategy: fail-fast: false @@ -22,6 +55,10 @@ jobs: python-version: ['3.10'] ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1'] steps: + - uses: actions/download-artifact@v4 + with: + name: meta-pytorch_torchcodec__${{ matrix.python-version }}_cpu_x86_64 + path: pytorch/torchcodec/dist/ - name: Setup conda env uses: conda-incubator/setup-miniconda@v2 with: @@ -43,11 +80,16 @@ jobs: # Note that we're installing stable - this is for running a script where we're a normal PyTorch # user, not for building TorhCodec. 
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu - python -m pip install numpy pillow + python -m pip install numpy pillow pytest + - name: Install torchcodec from the wheel + run: | + wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"` + echo Installing $wheel_path + python -m pip install $wheel_path -vvv - name: Check out repo uses: actions/checkout@v3 - name: Run generation reference resources run: | - python test/generate_reference_resources.py + python -m test.generate_reference_resources diff --git a/.github/workflows/windows_wheel.yaml b/.github/workflows/windows_wheel.yaml index 39247f770..2fd773fec 100644 --- a/.github/workflows/windows_wheel.yaml +++ b/.github/workflows/windows_wheel.yaml @@ -71,8 +71,7 @@ jobs: # TODO: FFmpeg 5 on Windows segfaults in avcodec_open2() when passing # bad parameters. # See https://github.com/pytorch/torchcodec/pull/806 - # TODO: Support FFmpeg 8 on Windows - ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1'] + ffmpeg-version-for-tests: ['4.4.2', '5.1.2', '6.1.1', '7.0.1', '8.0'] needs: build steps: - uses: actions/download-artifact@v4 @@ -83,7 +82,11 @@ jobs: uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true - miniconda-version: "latest" + # Using miniforge instead of miniconda ensures that the default + # conda channel is conda-forge instead of main/default. This ensures + # ABI consistency between dependencies: + # https://conda-forge.org/docs/user/transitioning_from_defaults/ + miniforge-version: latest activate-environment: test python-version: ${{ matrix.python-version }} - name: Update pip diff --git a/README.md b/README.md index 8050cf2a3..6c1721036 100644 --- a/README.md +++ b/README.md @@ -107,8 +107,8 @@ ffmpeg -f lavfi -i \ `torch` and `torchcodec`. 2. Install FFmpeg, if it's not already installed. Linux distributions usually - come with FFmpeg pre-installed. TorchCodec supports all major FFmpeg versions - in [4, 7]. + come with FFmpeg pre-installed. TorchCodec supports major FFmpeg versions + in [4, 7] on all platforms, and FFmpeg version 8 is supported on Mac and Linux. If FFmpeg is not already installed, or you need a more recent version, an easy way to install it is to use `conda`: @@ -131,6 +131,7 @@ The following table indicates the compatibility between versions of | `torchcodec` | `torch` | Python | | ------------------ | ------------------ | ------------------- | | `main` / `nightly` | `main` / `nightly` | `>=3.10`, `<=3.13` | +| `0.8` | `2.9` | `>=3.10`, `<=3.13` | | `0.7` | `2.8` | `>=3.9`, `<=3.13` | | `0.6` | `2.8` | `>=3.9`, `<=3.13` | | `0.5` | `2.7` | `>=3.9`, `<=3.13` | @@ -147,7 +148,8 @@ format you want. Refer to Nvidia's GPU support matrix for more details [here](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new). 1. Install FFmpeg with NVDEC support. - TorchCodec with CUDA should work with FFmpeg versions in [4, 7]. + TorchCodec with CUDA should work with FFmpeg versions in [4, 7] on all platforms, + and FFmpeg version 8 is supported on Linux. 
If FFmpeg is not already installed, or you need a more recent version, an easy way to install it is to use `conda`: diff --git a/benchmarks/decoders/benchmark_transforms.py b/benchmarks/decoders/benchmark_transforms.py new file mode 100644 index 000000000..75a49d63b --- /dev/null +++ b/benchmarks/decoders/benchmark_transforms.py @@ -0,0 +1,164 @@ +import math +from argparse import ArgumentParser +from pathlib import Path +from time import perf_counter_ns + +import torch +from torch import Tensor +from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts +from torchcodec.decoders import VideoDecoder +from torchvision.transforms import v2 + +DEFAULT_NUM_EXP = 20 + + +def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor: + + for _ in range(warmup): + f(*args) + + times = [] + for _ in range(num_exp): + start = perf_counter_ns() + f(*args) + end = perf_counter_ns() + times.append(end - start) + return torch.tensor(times).float() + + +def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float: + mul = { + "ns": 1, + "µs": 1e-3, + "ms": 1e-6, + "s": 1e-9, + }[unit] + times = times * mul + std = times.std().item() + med = times.median().item() + mean = times.mean().item() + min = times.min().item() + max = times.max().item() + print( + f"{prefix:<45} {med = :.2f}, {mean = :.2f} +- {std:.2f}, {min = :.2f}, {max = :.2f} - in {unit}" + ) + + +def torchvision_resize( + path: Path, pts_seconds: list[float], dims: tuple[int, int] +) -> None: + decoder = create_from_file(str(path), seek_mode="approximate") + add_video_stream(decoder) + raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds) + return v2.functional.resize(raw_frames, size=dims) + + +def torchvision_crop( + path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int +) -> None: + decoder = create_from_file(str(path), seek_mode="approximate") + add_video_stream(decoder) + raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds) + return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1]) + + +def decoder_native_resize( + path: Path, pts_seconds: list[float], dims: tuple[int, int] +) -> None: + decoder = create_from_file(str(path), seek_mode="approximate") + add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}") + return get_frames_by_pts(decoder, timestamps=pts_seconds)[0] + + +def decoder_native_crop( + path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int +) -> None: + decoder = create_from_file(str(path), seek_mode="approximate") + add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}") + return get_frames_by_pts(decoder, timestamps=pts_seconds)[0] + + +def main(): + parser = ArgumentParser() + parser.add_argument("--path", type=str, help="path to file", required=True) + parser.add_argument( + "--num-exp", + type=int, + default=DEFAULT_NUM_EXP, + help="number of runs to average over", + ) + + args = parser.parse_args() + path = Path(args.path) + + metadata = VideoDecoder(path).metadata + duration = metadata.duration_seconds + + print( + f"Benchmarking {path.name}, duration: {duration}, codec: {metadata.codec}, averaging over {args.num_exp} runs:" + ) + + input_height = metadata.height + input_width = metadata.width + fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1] + fraction_of_input_dimensions = [0.5, 0.25, 0.125] + + for num_fraction in fraction_of_total_frames_to_sample: + num_frames_to_sample = math.ceil(metadata.num_frames * 
num_fraction) + print( + f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames" + ) + uniform_timestamps = [ + i * duration / num_frames_to_sample for i in range(num_frames_to_sample) + ] + + for dims_fraction in fraction_of_input_dimensions: + dims = (int(input_height * dims_fraction), int(input_width * dims_fraction)) + + times = bench( + torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp + ) + report_stats(times, prefix=f"torchvision_resize({dims})") + + times = bench( + decoder_native_resize, + path, + uniform_timestamps, + dims, + num_exp=args.num_exp, + ) + report_stats(times, prefix=f"decoder_native_resize({dims})") + print() + + center_x = (input_height - dims[0]) // 2 + center_y = (input_width - dims[1]) // 2 + times = bench( + torchvision_crop, + path, + uniform_timestamps, + dims, + center_x, + center_y, + num_exp=args.num_exp, + ) + report_stats( + times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})" + ) + + times = bench( + decoder_native_crop, + path, + uniform_timestamps, + dims, + center_x, + center_y, + num_exp=args.num_exp, + ) + report_stats( + times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})" + ) + print() + + +if __name__ == "__main__": + main() diff --git a/docs/source/index.rst b/docs/source/index.rst index 8dea1dc8b..85f9a067c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,7 +11,7 @@ We achieve these capabilities through: * Pythonic APIs that mirror Python and PyTorch conventions. * Relying on `FFmpeg `_ to do the decoding / encoding. - TorchCodec uses the version of FFmpeg you already have installed. FMPEG is a + TorchCodec uses the version of FFmpeg you already have installed. FFmpeg is a mature library with broad coverage available on most systems. It is, however, not easy to use. TorchCodec abstracts FFmpeg's complexity to ensure it is used correctly and efficiently. 
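The benchmark above exercises two decoding paths through the private torchcodec._core API. As a minimal sketch of what it measures — assuming a placeholder "video.mp4" and a hard-coded output size; the transform_specs string format ("resize, <height>, <width>") is taken from the script above, not from public documentation:

    from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
    from torchvision.transforms import v2

    timestamps = [0.0, 1.0, 2.0]  # seconds; placeholder sample points

    # Path 1: decode at full resolution, then resize with torchvision.
    decoder = create_from_file("video.mp4", seek_mode="approximate")
    add_video_stream(decoder)
    frames, *_ = get_frames_by_pts(decoder, timestamps=timestamps)
    resized = v2.functional.resize(frames, size=(270, 480))

    # Path 2: let the decoder resize natively via transform_specs, so the
    # resize happens inside the decoding loop. This is what the benchmark
    # times as "decoder_native_resize".
    decoder = create_from_file("video.mp4", seek_mode="approximate")
    add_video_stream(decoder, transform_specs="resize, 270, 480")
    resized_native, *_ = get_frames_by_pts(decoder, timestamps=timestamps)

The decoder-native path avoids materializing full-resolution frames before resizing, which is the effect the fractions-of-input-dimensions loop in the benchmark is designed to expose.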
diff --git a/src/torchcodec/_core/AVIOTensorContext.cpp b/src/torchcodec/_core/AVIOTensorContext.cpp index 3f45f5be5..238475761 100644 --- a/src/torchcodec/_core/AVIOTensorContext.cpp +++ b/src/torchcodec/_core/AVIOTensorContext.cpp @@ -18,15 +18,15 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB int read(void* opaque, uint8_t* buf, int buf_size) { auto tensorContext = static_cast(opaque); TORCH_CHECK( - tensorContext->current <= tensorContext->data.numel(), - "Tried to read outside of the buffer: current=", - tensorContext->current, + tensorContext->current_pos <= tensorContext->data.numel(), + "Tried to read outside of the buffer: current_pos=", + tensorContext->current_pos, ", size=", tensorContext->data.numel()); int64_t numBytesRead = std::min( static_cast(buf_size), - tensorContext->data.numel() - tensorContext->current); + tensorContext->data.numel() - tensorContext->current_pos); TORCH_CHECK( numBytesRead >= 0, @@ -34,8 +34,8 @@ int read(void* opaque, uint8_t* buf, int buf_size) { numBytesRead, ", size=", tensorContext->data.numel(), - ", current=", - tensorContext->current); + ", current_pos=", + tensorContext->current_pos); if (numBytesRead == 0) { return AVERROR_EOF; @@ -43,9 +43,9 @@ int read(void* opaque, uint8_t* buf, int buf_size) { std::memcpy( buf, - tensorContext->data.data_ptr() + tensorContext->current, + tensorContext->data.data_ptr() + tensorContext->current_pos, numBytesRead); - tensorContext->current += numBytesRead; + tensorContext->current_pos += numBytesRead; return numBytesRead; } @@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) { auto tensorContext = static_cast(opaque); int64_t bufSize = static_cast(buf_size); - if (tensorContext->current + bufSize > tensorContext->data.numel()) { + if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) { TORCH_CHECK( tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE, "We tried to allocate an output encoded tensor larger than ", @@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) { } TORCH_CHECK( - tensorContext->current + bufSize <= tensorContext->data.numel(), + tensorContext->current_pos + bufSize <= tensorContext->data.numel(), "Re-allocation of the output tensor didn't work. 
", "This should not happen, please report on TorchCodec bug tracker"); uint8_t* outputTensorData = tensorContext->data.data_ptr(); - std::memcpy(outputTensorData + tensorContext->current, buf, bufSize); - tensorContext->current += bufSize; + std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize); + tensorContext->current_pos += bufSize; + // Track the maximum position written so getOutputTensor's narrow() does not + // truncate the file if final seek was backwards + tensorContext->max_pos = + std::max(tensorContext->current_pos, tensorContext->max_pos); return buf_size; } @@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) { ret = tensorContext->data.numel(); break; case SEEK_SET: - tensorContext->current = offset; + tensorContext->current_pos = offset; ret = offset; break; default: @@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) { } // namespace AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) - : tensorContext_{data, 0} { + : tensorContext_{data, 0, 0} { TORCH_CHECK(data.numel() > 0, "data must not be empty"); TORCH_CHECK(data.is_contiguous(), "data must be contiguous"); TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8"); @@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data) } AVIOToTensorContext::AVIOToTensorContext() - : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} { + : tensorContext_{ + torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), + 0, + 0} { createAVIOContext( nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true); } torch::Tensor AVIOToTensorContext::getOutputTensor() { return tensorContext_.data.narrow( - /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current); + /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOTensorContext.h b/src/torchcodec/_core/AVIOTensorContext.h index 15f97da55..bcd97052b 100644 --- a/src/torchcodec/_core/AVIOTensorContext.h +++ b/src/torchcodec/_core/AVIOTensorContext.h @@ -15,7 +15,8 @@ namespace detail { struct TensorContext { torch::Tensor data; - int64_t current; + int64_t current_pos; + int64_t max_pos; }; } // namespace detail diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 78fa8d635..b0caa9705 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -15,7 +15,7 @@ #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/NVDECCache.h" -// #include // For cudaStreamSynchronize +#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h" #include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" #include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" @@ -53,74 +53,6 @@ pfnDisplayPictureCallback(void* pUserData, CUVIDPARSERDISPINFO* dispInfo) { } static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { - // Check decoder capabilities - same checks as DALI - auto caps = CUVIDDECODECAPS{}; - caps.eCodecType = videoFormat->codec; - caps.eChromaFormat = videoFormat->chroma_format; - caps.nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8; - CUresult result = cuvidGetDecoderCaps(&caps); - TORCH_CHECK(result == CUDA_SUCCESS, "Failed to get decoder caps: ", result); - - TORCH_CHECK( - caps.bIsSupported, - "Codec configuration not supported on this GPU. 
" - "Codec: ", - static_cast(videoFormat->codec), - ", chroma format: ", - static_cast(videoFormat->chroma_format), - ", bit depth: ", - videoFormat->bit_depth_luma_minus8 + 8); - - TORCH_CHECK( - videoFormat->coded_width >= caps.nMinWidth && - videoFormat->coded_height >= caps.nMinHeight, - "Video is too small in at least one dimension. Provided: ", - videoFormat->coded_width, - "x", - videoFormat->coded_height, - " vs supported:", - caps.nMinWidth, - "x", - caps.nMinHeight); - - TORCH_CHECK( - videoFormat->coded_width <= caps.nMaxWidth && - videoFormat->coded_height <= caps.nMaxHeight, - "Video is too large in at least one dimension. Provided: ", - videoFormat->coded_width, - "x", - videoFormat->coded_height, - " vs supported:", - caps.nMaxWidth, - "x", - caps.nMaxHeight); - - // See nMaxMBCount in cuviddec.h - constexpr unsigned int macroblockConstant = 256; - TORCH_CHECK( - videoFormat->coded_width * videoFormat->coded_height / - macroblockConstant <= - caps.nMaxMBCount, - "Video is too large (too many macroblocks). " - "Provided (width * height / ", - macroblockConstant, - "): ", - videoFormat->coded_width * videoFormat->coded_height / macroblockConstant, - " vs supported:", - caps.nMaxMBCount); - - // Below we'll set the decoderParams.OutputFormat to NV12, so we need to make - // sure it's actually supported. - TORCH_CHECK( - (caps.nOutputFormatMask >> cudaVideoSurfaceFormat_NV12) & 1, - "NV12 output format is not supported for this configuration. ", - "Codec: ", - static_cast(videoFormat->codec), - ", chroma format: ", - static_cast(videoFormat->chroma_format), - ", bit depth: ", - videoFormat->bit_depth_luma_minus8 + 8); - // Decoder creation parameters, most are taken from DALI CUVIDDECODECREATEINFO decoderParams = {}; decoderParams.bitDepthMinus8 = videoFormat->bit_depth_luma_minus8; @@ -129,7 +61,7 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { // automatically converted to 8bits by NVDEC itself. That is, the raw frames // we get back from cuvidMapVideoFrame will already be in 8bit format. We // won't need to do the conversion ourselves, so that's a lot easier. - // In the default interface, we have to do the 10 -> 8bits conversion + // In the ffmpeg CUDA interface, we have to do the 10 -> 8bits conversion // ourselves later in convertAVFrameToFrameOutput(), because FFmpeg explicitly // requests 10 or 16bits output formats for >8-bit videos! // https://github.com/FFmpeg/FFmpeg/blob/e05f8acabff468c1382277c1f31fa8e9d90c3202/libavcodec/nvdec.c#L376-L403 @@ -157,13 +89,39 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { decoderParams.display_area.bottom = videoFormat->display_area.bottom; CUvideodecoder* decoder = new CUvideodecoder(); - result = cuvidCreateDecoder(decoder, &decoderParams); + CUresult result = cuvidCreateDecoder(decoder, &decoderParams); TORCH_CHECK( result == CUDA_SUCCESS, "Failed to create NVDEC decoder: ", result); return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{}); } -cudaVideoCodec validateCodecSupport(AVCodecID codecId) { +std::optional validateChromaSupport( + const AVPixFmtDescriptor* desc) { + // Return the corresponding cudaVideoChromaFormat if supported, std::nullopt + // otherwise. 
+  TORCH_CHECK(desc != nullptr, "desc can't be null");
+
+  if (desc->nb_components == 1) {
+    return cudaVideoChromaFormat_Monochrome;
+  } else if (desc->nb_components >= 3 && !(desc->flags & AV_PIX_FMT_FLAG_RGB)) {
+    // Make sure it's YUV: has chroma planes and isn't RGB
+    if (desc->log2_chroma_w == 0 && desc->log2_chroma_h == 0) {
+      return cudaVideoChromaFormat_444; // 1x1 subsampling = 4:4:4
+    } else if (desc->log2_chroma_w == 1 && desc->log2_chroma_h == 1) {
+      return cudaVideoChromaFormat_420; // 2x2 subsampling = 4:2:0
+    } else if (desc->log2_chroma_w == 1 && desc->log2_chroma_h == 0) {
+      return cudaVideoChromaFormat_422; // 2x1 subsampling = 4:2:2
+    }
+  }
+
+  return std::nullopt;
+}
+
+std::optional<cudaVideoCodec> validateCodecSupport(AVCodecID codecId) {
+  // Return the corresponding cudaVideoCodec if supported, std::nullopt
+  // otherwise.
+  // Note that we currently return nullopt (and thus fall back to CPU) for
+  // some codecs that are technically supported by NVDEC, see comment below.
   switch (codecId) {
     case AV_CODEC_ID_H264:
       return cudaVideoCodec_H264;
@@ -189,12 +147,72 @@ cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
     //   return cudaVideoCodec_JPEG;
     // case AV_CODEC_ID_VC1:
     //   return cudaVideoCodec_VC1;
-    default: {
-      TORCH_CHECK(false, "Unsupported codec type: ", avcodec_get_name(codecId));
-    }
+    default:
+      return std::nullopt;
   }
 }
 
+bool nativeNVDECSupport(const SharedAVCodecContext& codecContext) {
+  // Return true iff the input video stream is supported by our NVDEC
+  // implementation.
+
+  auto codecType = validateCodecSupport(codecContext->codec_id);
+  if (!codecType.has_value()) {
+    return false;
+  }
+
+  const AVPixFmtDescriptor* desc = av_pix_fmt_desc_get(codecContext->pix_fmt);
+  if (!desc) {
+    return false;
+  }
+
+  auto chromaFormat = validateChromaSupport(desc);
+  if (!chromaFormat.has_value()) {
+    return false;
+  }
+
+  auto caps = CUVIDDECODECAPS{};
+  caps.eCodecType = codecType.value();
+  caps.eChromaFormat = chromaFormat.value();
+  caps.nBitDepthMinus8 = desc->comp[0].depth - 8;
+
+  CUresult result = cuvidGetDecoderCaps(&caps);
+  if (result != CUDA_SUCCESS) {
+    return false;
+  }
+
+  if (!caps.bIsSupported) {
+    return false;
+  }
+
+  auto coded_width = static_cast<unsigned int>(codecContext->coded_width);
+  auto coded_height = static_cast<unsigned int>(codecContext->coded_height);
+  if (coded_width < static_cast<unsigned int>(caps.nMinWidth) ||
+      coded_height < static_cast<unsigned int>(caps.nMinHeight) ||
+      coded_width > caps.nMaxWidth || coded_height > caps.nMaxHeight) {
+    return false;
+  }
+
+  // See nMaxMBCount in cuviddec.h
+  constexpr unsigned int macroblockConstant = 256;
+  if (coded_width * coded_height / macroblockConstant > caps.nMaxMBCount) {
+    return false;
+  }
+
+  // We'll set the decoderParams.OutputFormat to NV12, so we need to make
+  // sure it's actually supported.
+  // TODO: If this fails, we could consider decoding to something other than
+  // NV12 (like cudaVideoSurfaceFormat_P016) instead of falling back to CPU.
+  // This is what FFmpeg does.
+ bool supportsNV12Output = + (caps.nOutputFormatMask >> cudaVideoSurfaceFormat_NV12) & 1; + if (!supportsNV12Output) { + return false; + } + + return true; +} + } // namespace BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) @@ -205,6 +223,8 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) initializeCudaContextWithPytorch(device_); nppCtx_ = getNppStreamContext(device_); + + nvcuvidAvailable_ = loadNVCUVIDLibrary(); } BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { @@ -216,12 +236,11 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { // unclear. flush(); unmapPreviousFrame(); - NVDECCache::getCache(device_.index()) - .returnDecoder(&videoFormat_, std::move(decoder_)); + NVDECCache::getCache(device_).returnDecoder( + &videoFormat_, std::move(decoder_)); } if (videoParser_) { - // TODONVDEC P2: consider caching this? Does DALI do that? cuvidDestroyVideoParser(videoParser_); videoParser_ = nullptr; } @@ -231,7 +250,21 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { void BetaCudaDeviceInterface::initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) { + const UniqueDecodingAVFormatContext& avFormatCtx, + [[maybe_unused]] const SharedAVCodecContext& codecContext) { + if (!nvcuvidAvailable_ || !nativeNVDECSupport(codecContext)) { + cpuFallback_ = createDeviceInterface(torch::kCPU); + TORCH_CHECK( + cpuFallback_ != nullptr, "Failed to create CPU device interface"); + cpuFallback_->initialize(avStream, avFormatCtx, codecContext); + cpuFallback_->initializeVideo( + VideoStreamOptions(), + {}, + /*resizedOutputDims=*/std::nullopt); + // We'll always use the CPU fallback from now on, so we can return early. + return; + } + TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); timeBase_ = avStream->time_base; frameRateAvgFromFFmpeg_ = avStream->r_frame_rate; @@ -243,7 +276,11 @@ void BetaCudaDeviceInterface::initialize( // Create parser. Default values that aren't obvious are taken from DALI. CUVIDPARSERPARAMS parserParams = {}; - parserParams.CodecType = validateCodecSupport(codecPar->codec_id); + auto codecType = validateCodecSupport(codecPar->codec_id); + TORCH_CHECK( + codecType.has_value(), + "This should never happen, we should be using the CPU fallback by now. Please report a bug."); + parserParams.CodecType = codecType.value(); parserParams.ulMaxNumDecodeSurfaces = 8; parserParams.ulMaxDisplayDelay = 0; // Callback setup, all are triggered by the parser within a call @@ -362,11 +399,12 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) { } if (!decoder_) { - decoder_ = NVDECCache::getCache(device_.index()).getDecoder(videoFormat); + decoder_ = NVDECCache::getCache(device_).getDecoder(videoFormat); if (!decoder_) { // TODONVDEC P2: consider re-configuring an existing decoder instead of - // re-creating one. See docs, see DALI. + // re-creating one. See docs, see DALI. Re-configuration doesn't seem to + // be enabled in DALI by default. decoder_ = createDecoder(videoFormat); } @@ -382,6 +420,10 @@ int BetaCudaDeviceInterface::streamPropertyChange(CUVIDEOFORMAT* videoFormat) { // Moral equivalent of avcodec_send_packet(). Here, we pass the AVPacket down to // the NVCUVID parser. 
int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) { + if (cpuFallback_) { + return cpuFallback_->sendPacket(packet); + } + TORCH_CHECK( packet.get() && packet->data && packet->size > 0, "sendPacket received an empty packet, this is unexpected, please report."); @@ -405,6 +447,10 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) { } int BetaCudaDeviceInterface::sendEOFPacket() { + if (cpuFallback_) { + return cpuFallback_->sendEOFPacket(); + } + CUVIDSOURCEDATAPACKET cuvidPacket = {}; cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM; eofSent_ = true; @@ -466,6 +512,10 @@ int BetaCudaDeviceInterface::frameReadyInDisplayOrder( // Moral equivalent of avcodec_receive_frame(). int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) { + if (cpuFallback_) { + return cpuFallback_->receiveFrame(avFrame); + } + if (readyFrames_.empty()) { // No frame found, instruct caller to try again later after sending more // packets, or to stop if EOF was already sent. @@ -480,8 +530,7 @@ int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) { procParams.top_field_first = dispInfo.top_field_first; procParams.unpaired_field = dispInfo.repeat_first_field < 0; // We set the NVDEC stream to the current stream. It will be waited upon by - // the NPP stream before any color conversion. Currently, that syncing logic - // is in the default interface. + // the NPP stream before any color conversion. // Re types: we get a cudaStream_t from PyTorch but it's interchangeable with // CUstream procParams.output_stream = reinterpret_cast( @@ -601,6 +650,11 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame( } void BetaCudaDeviceInterface::flush() { + if (cpuFallback_) { + cpuFallback_->flush(); + return; + } + // The NVCUVID docs mention that after seeking, i.e. when flush() is called, // we should send a packet with the CUVID_PKT_DISCONTINUITY flag. The docs // don't say whether this should be an empty packet, or whether it should be a @@ -618,8 +672,23 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - // TODONVDEC P2: we may need to handle 10bit videos the same way the default - // interface does it with maybeConvertAVFrameToNV12OrRGB24(). + if (cpuFallback_) { + // CPU decoded frame - need to do CPU color conversion then transfer to GPU + FrameOutput cpuFrameOutput; + cpuFallback_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); + + // Transfer CPU frame to GPU + if (preAllocatedOutputTensor.has_value()) { + preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data); + frameOutput.data = preAllocatedOutputTensor.value(); + } else { + frameOutput.data = cpuFrameOutput.data.to(device_); + } + return; + } + + // TODONVDEC P2: we may need to handle 10bit videos the same way the CUDA + // ffmpeg interface does it with maybeConvertAVFrameToNV12OrRGB24(). 
TORCH_CHECK( avFrame->format == AV_PIX_FMT_CUDA, "Expected CUDA format frame from BETA CUDA interface"); @@ -633,4 +702,17 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } +std::string BetaCudaDeviceInterface::getDetails() { + std::string details = "Beta CUDA Device Interface."; + if (cpuFallback_) { + details += " Using CPU fallback."; + if (!nvcuvidAvailable_) { + details += " NVCUVID not available!"; + } + } else { + details += " Using NVDEC."; + } + return details; +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index 0bf9951d6..29511df50 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -40,7 +40,8 @@ class BetaCudaDeviceInterface : public DeviceInterface { void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) override; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override; void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, @@ -48,10 +49,6 @@ class BetaCudaDeviceInterface : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; - bool canDecodePacketDirectly() const override { - return true; - } - int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; int receiveFrame(UniqueAVFrame& avFrame) override; @@ -62,6 +59,8 @@ class BetaCudaDeviceInterface : public DeviceInterface { int frameReadyForDecoding(CUVIDPICPARAMS* picParams); int frameReadyInDisplayOrder(CUVIDPARSERDISPINFO* dispInfo); + std::string getDetails() override; + private: int sendCuvidPacket(CUVIDSOURCEDATAPACKET& cuvidPacket); @@ -97,6 +96,9 @@ class BetaCudaDeviceInterface : public DeviceInterface { // NPP context for color conversion UniqueNppContext nppCtx_; + + std::unique_ptr cpuFallback_; + bool nvcuvidAvailable_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 75d1b036c..6b4ccb5d4 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -99,7 +99,7 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp) endif() set(core_library_dependencies @@ -108,27 +108,9 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - # Try to find NVCUVID. Try the normal way first. This should work locally. - find_library(NVCUVID_LIBRARY NAMES nvcuvid) - # If not found, try with version suffix, or hardcoded path. Appears - # to be necessary on the CI. 
- if(NOT NVCUVID_LIBRARY) - find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib) - endif() - if(NOT NVCUVID_LIBRARY) - set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1") - endif() - - if(NVCUVID_LIBRARY) - message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}") - else() - message(FATAL_ERROR "Could not find NVCUVID library") - endif() - list(APPEND core_library_dependencies ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} - ${NVCUVID_LIBRARY} ) endif() diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index 4f3664031..4532e3c76 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -5,14 +5,12 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/_core/CUDACommon.h" +#include "src/torchcodec/_core/Cache.h" // for PerGpuCache namespace facebook::torchcodec { namespace { -// Pytorch can only handle up to 128 GPUs. -// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 -const int MAX_CUDA_GPUS = 128; // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. // Set to a positive number to have a cache of that size. const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; @@ -249,7 +247,7 @@ torch::Tensor convertNV12FrameToRGB( } UniqueNppContext getNppStreamContext(const torch::Device& device) { - torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); UniqueNppContext nppCtx = g_cached_npp_ctxs.get(device); if (nppCtx) { @@ -266,13 +264,13 @@ UniqueNppContext getNppStreamContext(const torch::Device& device) { nppCtx = std::make_unique(); cudaDeviceProp prop{}; - cudaError_t err = cudaGetDeviceProperties(&prop, nonNegativeDeviceIndex); + cudaError_t err = cudaGetDeviceProperties(&prop, deviceIndex); TORCH_CHECK( err == cudaSuccess, "cudaGetDeviceProperties failed: ", cudaGetErrorString(err)); - nppCtx->nCudaDeviceId = nonNegativeDeviceIndex; + nppCtx->nCudaDeviceId = deviceIndex; nppCtx->nMultiProcessorCount = prop.multiProcessorCount; nppCtx->nMaxThreadsPerMultiProcessor = prop.maxThreadsPerMultiProcessor; nppCtx->nMaxThreadsPerBlock = prop.maxThreadsPerBlock; @@ -312,4 +310,21 @@ void validatePreAllocatedTensorShape( } } +int getDeviceIndex(const torch::Device& device) { + // PyTorch uses int8_t as its torch::DeviceIndex, but FFmpeg and CUDA + // libraries use int. So we use int, too. + int deviceIndex = static_cast(device.index()); + TORCH_CHECK( + deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS, + "Invalid device index = ", + deviceIndex); + + if (deviceIndex == -1) { + TORCH_CHECK( + cudaGetDevice(&deviceIndex) == cudaSuccess, + "Failed to get current CUDA device."); + } + return deviceIndex; +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index b935cd4bf..588f60e49 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -11,7 +11,6 @@ #include #include -#include "src/torchcodec/_core/Cache.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/Frame.h" @@ -22,6 +21,10 @@ extern "C" { namespace facebook::torchcodec { +// Pytorch can only handle up to 128 GPUs. 
+// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 +constexpr int MAX_CUDA_GPUS = 128; + void initializeCudaContextWithPytorch(const torch::Device& device); // Unique pointer type for NPP stream context @@ -43,4 +46,6 @@ void validatePreAllocatedTensorShape( const std::optional& preAllocatedOutputTensor, const UniqueAVFrame& avFrame); +int getDeviceIndex(const torch::Device& device); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Cache.h b/src/torchcodec/_core/Cache.h index 7b088a145..b2c93e8ea 100644 --- a/src/torchcodec/_core/Cache.h +++ b/src/torchcodec/_core/Cache.h @@ -95,30 +95,16 @@ class PerGpuCache { std::vector>> cache_; }; -// Note: this function is inline for convenience, not performance. Because the -// rest of this file is template functions, they must all be defined in this -// header. This function is not a template function, and should, in principle, -// be defined in a .cpp file to preserve the One Definition Rule. That's -// annoying for such a small amount of code, so we just inline it. If this file -// grows, and there are more such functions, we should break them out into a -// .cpp file. -inline torch::DeviceIndex getNonNegativeDeviceIndex( - const torch::Device& device) { - torch::DeviceIndex deviceIndex = device.index(); - // For single GPU machines libtorch returns -1 for the device index. So for - // that case we set the device index to 0. That's used in per-gpu cache - // implementation and during initialization of CUDA and FFmpeg contexts - // which require non negative indices. - deviceIndex = std::max(deviceIndex, 0); - TORCH_CHECK(deviceIndex >= 0, "Device index out of range"); - return deviceIndex; -} +// Forward declaration of getDeviceIndex which exists in CUDACommon.h +// This avoids circular dependency between Cache.h and CUDACommon.cpp which also +// needs to include Cache.h +int getDeviceIndex(const torch::Device& device); template bool PerGpuCache::addIfCacheHasCapacity( const torch::Device& device, element_type&& obj) { - torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); TORCH_CHECK( static_cast(deviceIndex) < cache_.size(), "Device index out of range"); @@ -128,7 +114,7 @@ bool PerGpuCache::addIfCacheHasCapacity( template typename PerGpuCache::element_type PerGpuCache::get( const torch::Device& device) { - torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); TORCH_CHECK( static_cast(deviceIndex) < cache_.size(), "Device index out of range"); diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index e6b96e3e4..5aa20b09e 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -48,8 +48,10 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) void CpuDeviceInterface::initialize( const AVStream* avStream, - [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) { + [[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) { TORCH_CHECK(avStream != nullptr, "avStream is null"); + codecContext_ = codecContext; timeBase_ = avStream->time_base; } @@ -344,4 +346,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); } +std::string CpuDeviceInterface::getDetails() { + return std::string("CPU 
Device Interface."); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 399b0c6be..3f6f7c962 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -25,7 +25,8 @@ class CpuDeviceInterface : public DeviceInterface { virtual void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) override; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override; virtual void initializeVideo( const VideoStreamOptions& videoStreamOptions, @@ -38,6 +39,8 @@ class CpuDeviceInterface : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; + std::string getDetails() override; + private: int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index aea2b2d9a..be45050e6 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -32,9 +32,6 @@ static bool g_cuda = registerDeviceInterface( // from // the cache. If the cache is empty we create a new cuda context. -// Pytorch can only handle up to 128 GPUs. -// https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 -const int MAX_CUDA_GPUS = 128; // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. // Set to a positive number to have a cache of that size. const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; @@ -54,7 +51,7 @@ int getFlagsAVHardwareDeviceContextCreate() { UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - torch::DeviceIndex nonNegativeDeviceIndex = getNonNegativeDeviceIndex(device); + int deviceIndex = getDeviceIndex(device); UniqueAVBufferRef hardwareDeviceCtx = g_cached_hw_device_ctxs.get(device); if (hardwareDeviceCtx) { @@ -63,14 +60,12 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { // Create hardware device context c10::cuda::CUDAGuard deviceGuard(device); - // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1: - // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb - // So we ensure the deviceIndex is not negative. // We set the device because we may be called from a different thread than // the one that initialized the cuda context. 
- cudaSetDevice(nonNegativeDeviceIndex); + TORCH_CHECK( + cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device"); AVBufferRef* hardwareDeviceCtxRaw = nullptr; - std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex); + std::string deviceOrdinal = std::to_string(deviceIndex); int err = av_hwdevice_ctx_create( &hardwareDeviceCtxRaw, @@ -117,15 +112,17 @@ CudaDeviceInterface::~CudaDeviceInterface() { void CudaDeviceInterface::initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) { + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) { TORCH_CHECK(avStream != nullptr, "avStream is null"); + codecContext_ = codecContext; timeBase_ = avStream->time_base; // TODO: Ideally, we should keep all interface implementations independent. cpuInterface_ = createDeviceInterface(torch::kCPU); TORCH_CHECK( cpuInterface_ != nullptr, "Failed to create CPU device interface"); - cpuInterface_->initialize(avStream, avFormatCtx); + cpuInterface_->initialize(avStream, avFormatCtx, codecContext); cpuInterface_->initializeVideo( VideoStreamOptions(), {}, @@ -287,9 +284,12 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( frameOutput.data = cpuFrameOutput.data.to(device_); } + usingCPUFallback_ = true; return; } + usingCPUFallback_ = false; + // Above we checked that the AVFrame was on GPU, but that's not enough, we // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), // because this is what the NPP color conversion routines expect. This SHOULD @@ -354,4 +354,12 @@ std::optional CudaDeviceInterface::findCodec( return std::nullopt; } +std::string CudaDeviceInterface::getDetails() { + // Note: for this interface specifically the fallback is only known after a + // frame has been decoded, not before: that's when FFmpeg decides to fallback, + // so we can't know earlier. + return std::string("FFmpeg CUDA Device Interface. Using ") + + (usingCPUFallback_ ? "CPU fallback." : "NVDEC."); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 1a8f184ec..9f171ee3c 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -22,7 +22,8 @@ class CudaDeviceInterface : public DeviceInterface { void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) override; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) override; void initializeVideo( const VideoStreamOptions& videoStreamOptions, @@ -39,6 +40,8 @@ class CudaDeviceInterface : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; + std::string getDetails() override; + private: // Our CUDA decoding code assumes NV12 format. In order to handle other // kinds of input, we need to convert them to NV12. Our current implementation @@ -59,6 +62,8 @@ class CudaDeviceInterface : public DeviceInterface { // maybeConvertAVFrameToNV12(). 
std::unique_ptr nv12ConversionContext_; std::unique_ptr nv12Conversion_; + + bool usingCPUFallback_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index cac29e838..773317e83 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -21,7 +21,7 @@ namespace facebook::torchcodec { // Key for device interface registration with device type + variant support struct DeviceInterfaceKey { torch::DeviceType deviceType; - std::string_view variant = "default"; // e.g., "default", "beta", etc. + std::string_view variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc. bool operator<(const DeviceInterfaceKey& other) const { if (deviceType != other.deviceType) { @@ -54,7 +54,8 @@ class DeviceInterface { // Initialize the device with parameters generic to all kinds of decoding. virtual void initialize( const AVStream* avStream, - const UniqueDecodingAVFormatContext& avFormatCtx) = 0; + const UniqueDecodingAVFormatContext& avFormatCtx, + const SharedAVCodecContext& codecContext) = 0; // Initialize the device with parameters specific to video decoding. There is // a default empty implementation. @@ -80,52 +81,51 @@ class DeviceInterface { // Extension points for custom decoding paths // ------------------------------------------ - // Override to return true if this device interface can decode packets - // directly. This means that the following two member functions can both - // be called: - // - // 1. sendPacket() - // 2. receiveFrame() - virtual bool canDecodePacketDirectly() const { - return false; - } - - // Moral equivalent of avcodec_send_packet() // Returns AVSUCCESS on success, AVERROR(EAGAIN) if decoder queue full, or // other AVERROR on failure - virtual int sendPacket([[maybe_unused]] ReferenceAVPacket& avPacket) { + // Default implementation uses FFmpeg directly + virtual int sendPacket(ReferenceAVPacket& avPacket) { TORCH_CHECK( - false, - "Send/receive packet decoding not implemented for this device interface"); - return AVERROR(ENOSYS); + codecContext_ != nullptr, + "Codec context not available for default packet sending"); + return avcodec_send_packet(codecContext_.get(), avPacket.get()); } // Send an EOF packet to flush the decoder // Returns AVSUCCESS on success, or other AVERROR on failure + // Default implementation uses FFmpeg directly virtual int sendEOFPacket() { TORCH_CHECK( - false, "Send EOF packet not implemented for this device interface"); - return AVERROR(ENOSYS); + codecContext_ != nullptr, + "Codec context not available for default EOF packet sending"); + return avcodec_send_packet(codecContext_.get(), nullptr); } - // Moral equivalent of avcodec_receive_frame() // Returns AVSUCCESS on success, AVERROR(EAGAIN) if no frame ready, // AVERROR_EOF if end of stream, or other AVERROR on failure - virtual int receiveFrame([[maybe_unused]] UniqueAVFrame& avFrame) { + // Default implementation uses FFmpeg directly + virtual int receiveFrame(UniqueAVFrame& avFrame) { TORCH_CHECK( - false, - "Send/receive packet decoding not implemented for this device interface"); - return AVERROR(ENOSYS); + codecContext_ != nullptr, + "Codec context not available for default frame receiving"); + return avcodec_receive_frame(codecContext_.get(), avFrame.get()); } // Flush remaining frames from decoder virtual void flush() { - // Default implementation is no-op for standard decoders - // Custom decoders can override this method + TORCH_CHECK( + codecContext_ != nullptr, + 
"Codec context not available for default flushing"); + avcodec_flush_buffers(codecContext_.get()); + } + + virtual std::string getDetails() { + return ""; } protected: torch::Device device_; + SharedAVCodecContext codecContext_; }; using CreateDeviceInterfaceFn = @@ -141,7 +141,7 @@ void validateDeviceInterface( std::unique_ptr createDeviceInterface( const torch::Device& device, - const std::string_view variant = "default"); + const std::string_view variant = "ffmpeg"); torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame); diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 14ef1cb94..1d9c2c089 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -4,10 +4,6 @@ #include "src/torchcodec/_core/Encoder.h" #include "torch/types.h" -extern "C" { -#include -} - namespace facebook::torchcodec { namespace { @@ -542,10 +538,17 @@ torch::Tensor validateFrames(const torch::Tensor& frames) { } // namespace VideoEncoder::~VideoEncoder() { + // TODO-VideoEncoder: Unify destructor with ~AudioEncoder() if (avFormatContext_ && avFormatContext_->pb) { - avio_flush(avFormatContext_->pb); - avio_close(avFormatContext_->pb); - avFormatContext_->pb = nullptr; + if (avFormatContext_->pb->error == 0) { + avio_flush(avFormatContext_->pb); + } + if (!avioContextHolder_) { + if (avFormatContext_->pb->error == 0) { + avio_close(avFormatContext_->pb); + } + avFormatContext_->pb = nullptr; + } } } @@ -581,6 +584,36 @@ VideoEncoder::VideoEncoder( initializeEncoder(videoStreamOptions); } +VideoEncoder::VideoEncoder( + const torch::Tensor& frames, + int frameRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + const VideoStreamOptions& videoStreamOptions) + : frames_(validateFrames(frames)), + inFrameRate_(frameRate), + avioContextHolder_(std::move(avioContextHolder)) { + setFFmpegLogLevel(); + // Map mkv -> matroska when used as format name + formatName = (formatName == "mkv") ? "matroska" : formatName; + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, formatName.data(), nullptr); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "Check the desired format? Got format=", + formatName, + ". 
", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + avFormatContext_->pb = avioContextHolder_->getAVIOContext(); + + initializeEncoder(videoStreamOptions); +} + void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { const AVCodec* avCodec = @@ -751,6 +784,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame( return avFrame; } +torch::Tensor VideoEncoder::encodeToTensor() { + TORCH_CHECK( + avioContextHolder_ != nullptr, + "Cannot encode to tensor, avio tensor context doesn't exist."); + encode(); + auto avioToTensorContext = + dynamic_cast(avioContextHolder_.get()); + TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder."); + return avioToTensorContext->getOutputTensor(); +} + void VideoEncoder::encodeFrame( AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame) { diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 62d30a624..168591616 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -141,8 +141,17 @@ class VideoEncoder { std::string_view fileName, const VideoStreamOptions& videoStreamOptions); + VideoEncoder( + const torch::Tensor& frames, + int frameRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + const VideoStreamOptions& videoStreamOptions); + void encode(); + torch::Tensor encodeToTensor(); + private: void initializeEncoder(const VideoStreamOptions& videoStreamOptions); UniqueAVFrame convertTensorToAVFrame( @@ -153,7 +162,7 @@ class VideoEncoder { UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; - AVStream* avStream_; + AVStream* avStream_ = nullptr; UniqueSwsContext swsContext_; const torch::Tensor frames_; @@ -167,6 +176,8 @@ class VideoEncoder { int outHeight_ = -1; AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; + std::unique_ptr avioContextHolder_; + bool encodeWasCalled_ = false; }; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 0570f06cf..97ff082e1 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -149,7 +149,7 @@ int getNumChannels(const UniqueAVFrame& avFrame) { #endif } -int getNumChannels(const UniqueAVCodecContext& avCodecContext) { +int getNumChannels(const SharedAVCodecContext& avCodecContext) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) return avCodecContext->ch_layout.nb_channels; diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 19cddcc37..337616ddc 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -71,6 +71,14 @@ using UniqueEncodingAVFormatContext = std::unique_ptr< using UniqueAVCodecContext = std::unique_ptr< AVCodecContext, Deleterp>; +using SharedAVCodecContext = std::shared_ptr; + +// create SharedAVCodecContext with custom deleter +inline SharedAVCodecContext makeSharedAVCodecContext(AVCodecContext* ctx) { + return SharedAVCodecContext( + ctx, Deleterp{}); +} + using UniqueAVFrame = std::unique_ptr>; using UniqueAVFilterGraph = std::unique_ptr< @@ -171,7 +179,7 @@ const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec); const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec); int getNumChannels(const UniqueAVFrame& avFrame); -int getNumChannels(const UniqueAVCodecContext& avCodecContext); +int getNumChannels(const SharedAVCodecContext& avCodecContext); 
void setDefaultChannelLayout( UniqueAVCodecContext& avCodecContext, diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index afc44d96d..605b814a8 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -130,7 +130,8 @@ FilterGraph::FilterGraph( TORCH_CHECK( status >= 0, "Failed to configure filter graph: ", - getFFMPEGErrorStringFromErrorCode(status)); + getFFMPEGErrorStringFromErrorCode(status), + ", provided filters: " + filtersContext.filtergraphStr); } UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) { diff --git a/src/torchcodec/_core/Frame.cpp b/src/torchcodec/_core/Frame.cpp index 9fa87a1cb..62fb46c65 100644 --- a/src/torchcodec/_core/Frame.cpp +++ b/src/torchcodec/_core/Frame.cpp @@ -8,6 +8,11 @@ namespace facebook::torchcodec { +FrameDims::FrameDims(int height, int width) : height(height), width(width) { + TORCH_CHECK(height > 0, "FrameDims.height must be > 0, got: ", height); + TORCH_CHECK(width > 0, "FrameDims.width must be > 0, got: ", width); +} + FrameBatchOutput::FrameBatchOutput( int64_t numFrames, const FrameDims& outputDims, diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 4b27d5bdd..67e4d2b79 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -19,7 +19,7 @@ struct FrameDims { FrameDims() = default; - FrameDims(int h, int w) : height(h), width(w) {} + FrameDims(int h, int w); }; // All public video decoding entry points return either a FrameOutput or a diff --git a/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp b/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp new file mode 100644 index 000000000..2bb501fc2 --- /dev/null +++ b/src/torchcodec/_core/NVCUVIDRuntimeLoader.cpp @@ -0,0 +1,320 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#ifdef FBCODE_CAFFE2 +// No need to do anything on fbcode. NVCUVID is available there, we can take a +// hard dependency on it. +// The FBCODE_CAFFE2 macro is defined in the upstream fbcode build of torch, so +// we can rely on it, that's what torch does too. + +namespace facebook::torchcodec { +bool loadNVCUVIDLibrary() { + return true; +} +} // namespace facebook::torchcodec +#else + +#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h" + +#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" +#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" + +#include +#include +#include + +#if defined(WIN64) || defined(_WIN64) +#include +typedef HMODULE tHandle; +#else +#include +typedef void* tHandle; +#endif + +namespace facebook::torchcodec { + +/* clang-format off */ +// This file defines the logic to load the NVCUVID library **at runtime**, +// along with the corresponding NVCUVID functions that we'll need. +// +// We do this because we *do not want* to link (statically or dynamically) +// against libnvcuvid.so: it is not always available on the users machine! If we +// were to link against libnvcuvid.so, that would mean that our +// libtorchcodec_coreN.so would try to look for it when loaded at import time. +// And if it's not on the users machine, that causes `import torchcodec` to +// fail. Source: that's what we did, and we got user reports. +// +// So, we don't link against libnvcuvid.so. But we still want to call its +// functions. 
So here's how it's done; we'll use cuvidCreateVideoParser as an +// example, but it works the same for all. We are largely following the +// instructions from the NVCUVID docs: +// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#dynamic-loading-nvidia-components +// +// This: +// typedef CUresult CUDAAPI tcuvidCreateVideoParser(CUvideoparser*, CUVIDPARSERPARAMS*); +// defines tcuvidCreateVideoParser, which is the *type* of a *function*. +// We define a pointer to a function of that type just below with: +// static tcuvidCreateVideoParser* dl_cuvidCreateVideoParser = nullptr; +// "dl" is for "dynamically loaded". For now dl_cuvidCreateVideoParser is +// nullptr, but later it will be a proper function [pointer] that can be called +// with dl_cuvidCreateVideoParser(...); +// +// For that to happen we need to call loadNVCUVIDLibrary(): in there, we first +// dlopen(libnvcuvid.so) which loads the .so somewhere in memory. Then we call +// dlsym(...), which binds dl_cuvidCreateVideoParser to its actual address: it +// literally sets the value of the dl_cuvidCreateVideoParser pointer to the +// address of the actual code section. If all went well, by now, we can safely +// call dl_cuvidCreateVideoParser(...); +// All of that happens at runtime *after* import time, when the first instance +// of the Beta CUDA interface is created, i.e. only when the user explicitly +// requests it. +// +// At the bottom of this file we have an `extern "C"` section with function +// definitions like: +// +// CUresult CUDAAPI cuvidCreateVideoParser( +// CUvideoparser* videoParser, +// CUVIDPARSERPARAMS* parserParams) {...} +// +// These are the actual functions that are compiled against and called by the +// Beta CUDA interface code. Crucially, these functions' signatures exactly match +// the NVCUVID functions (as defined in cuviddec.h). Inside of +// cuvidCreateVideoParser(...) we simply call the dl_cuvidCreateVideoParser +// function [pointer] that we dynamically loaded earlier. +// +// At runtime, within the Beta CUDA interface code we have a fallback mechanism +// to switch back to the CPU backend if any of the NVCUVID functions are not +// available, or if libnvcuvid.so itself couldn't be found. This is what FFmpeg +// does too.
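Aside: the load-then-bind mechanism described in the comment above (dlopen, then dlsym, then call through the bound pointer) can be illustrated in a few lines of Python with ctypes. This is purely an illustration of the pattern and not part of the implementation; it assumes libnvcuvid.so.1 may or may not be present:

    import ctypes

    lib = None
    try:
        # Rough equivalent of dlopen(): raises OSError if NVCUVID is absent.
        lib = ctypes.CDLL("libnvcuvid.so.1")
    except OSError:
        pass  # no hard failure: the decoder falls back to the CPU backend

    # Rough equivalent of dlsym(): resolve a symbol to a callable, or None.
    cuvid_get_decoder_caps = getattr(lib, "cuvidGetDecoderCaps", None) if lib else None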
+ + +// Function pointer types +typedef CUresult CUDAAPI tcuvidCreateVideoParser(CUvideoparser*, CUVIDPARSERPARAMS*); +typedef CUresult CUDAAPI tcuvidParseVideoData(CUvideoparser, CUVIDSOURCEDATAPACKET*); +typedef CUresult CUDAAPI tcuvidDestroyVideoParser(CUvideoparser); +typedef CUresult CUDAAPI tcuvidGetDecoderCaps(CUVIDDECODECAPS*); +typedef CUresult CUDAAPI tcuvidCreateDecoder(CUvideodecoder*, CUVIDDECODECREATEINFO*); +typedef CUresult CUDAAPI tcuvidDestroyDecoder(CUvideodecoder); +typedef CUresult CUDAAPI tcuvidDecodePicture(CUvideodecoder, CUVIDPICPARAMS*); +typedef CUresult CUDAAPI tcuvidMapVideoFrame(CUvideodecoder, int, unsigned int*, unsigned int*, CUVIDPROCPARAMS*); +typedef CUresult CUDAAPI tcuvidUnmapVideoFrame(CUvideodecoder, unsigned int); +typedef CUresult CUDAAPI tcuvidMapVideoFrame64(CUvideodecoder, int, unsigned long long*, unsigned int*, CUVIDPROCPARAMS*); +typedef CUresult CUDAAPI tcuvidUnmapVideoFrame64(CUvideodecoder, unsigned long long); +/* clang-format on */ + +// Global function pointers - will be dynamically loaded +static tcuvidCreateVideoParser* dl_cuvidCreateVideoParser = nullptr; +static tcuvidParseVideoData* dl_cuvidParseVideoData = nullptr; +static tcuvidDestroyVideoParser* dl_cuvidDestroyVideoParser = nullptr; +static tcuvidGetDecoderCaps* dl_cuvidGetDecoderCaps = nullptr; +static tcuvidCreateDecoder* dl_cuvidCreateDecoder = nullptr; +static tcuvidDestroyDecoder* dl_cuvidDestroyDecoder = nullptr; +static tcuvidDecodePicture* dl_cuvidDecodePicture = nullptr; +static tcuvidMapVideoFrame* dl_cuvidMapVideoFrame = nullptr; +static tcuvidUnmapVideoFrame* dl_cuvidUnmapVideoFrame = nullptr; +static tcuvidMapVideoFrame64* dl_cuvidMapVideoFrame64 = nullptr; +static tcuvidUnmapVideoFrame64* dl_cuvidUnmapVideoFrame64 = nullptr; + +static tHandle g_nvcuvid_handle = nullptr; +static std::mutex g_nvcuvid_mutex; + +bool isLoaded() { + return ( + g_nvcuvid_handle && dl_cuvidCreateVideoParser && dl_cuvidParseVideoData && + dl_cuvidDestroyVideoParser && dl_cuvidGetDecoderCaps && + dl_cuvidCreateDecoder && dl_cuvidDestroyDecoder && + dl_cuvidDecodePicture && dl_cuvidMapVideoFrame && + dl_cuvidUnmapVideoFrame && dl_cuvidMapVideoFrame64 && + dl_cuvidUnmapVideoFrame64); +} + +template <typename T> +T* bindFunction(const char* functionName) { +#if defined(WIN64) || defined(_WIN64) + return reinterpret_cast<T*>(GetProcAddress(g_nvcuvid_handle, functionName)); +#else + return reinterpret_cast<T*>(dlsym(g_nvcuvid_handle, functionName)); +#endif +} + +bool _loadLibrary() { + // Helper that just calls dlopen or equivalent on Windows. In a separate + // function because of the #ifdef ugliness. +#if defined(WIN64) || defined(_WIN64) +#ifdef UNICODE + static LPCWSTR nvcuvidDll = L"nvcuvid.dll"; +#else + static LPCSTR nvcuvidDll = "nvcuvid.dll"; +#endif + g_nvcuvid_handle = LoadLibrary(nvcuvidDll); + if (g_nvcuvid_handle == nullptr) { + return false; + } +#else + g_nvcuvid_handle = dlopen("libnvcuvid.so", RTLD_NOW); + if (g_nvcuvid_handle == nullptr) { + g_nvcuvid_handle = dlopen("libnvcuvid.so.1", RTLD_NOW); + } + if (g_nvcuvid_handle == nullptr) { + return false; + } +#endif + + return true; +} + +bool loadNVCUVIDLibrary() { + // Loads NVCUVID library and all required function pointers. + // Returns true on success, false on failure. + std::lock_guard<std::mutex> lock(g_nvcuvid_mutex); + + if (isLoaded()) { + return true; + } + + if (!_loadLibrary()) { + return false; + } + + // Load all function pointers. They'll be set to nullptr if not found.
+ dl_cuvidCreateVideoParser = + bindFunction<tcuvidCreateVideoParser>("cuvidCreateVideoParser"); + dl_cuvidParseVideoData = + bindFunction<tcuvidParseVideoData>("cuvidParseVideoData"); + dl_cuvidDestroyVideoParser = + bindFunction<tcuvidDestroyVideoParser>("cuvidDestroyVideoParser"); + dl_cuvidGetDecoderCaps = + bindFunction<tcuvidGetDecoderCaps>("cuvidGetDecoderCaps"); + dl_cuvidCreateDecoder = + bindFunction<tcuvidCreateDecoder>("cuvidCreateDecoder"); + dl_cuvidDestroyDecoder = + bindFunction<tcuvidDestroyDecoder>("cuvidDestroyDecoder"); + dl_cuvidDecodePicture = + bindFunction<tcuvidDecodePicture>("cuvidDecodePicture"); + dl_cuvidMapVideoFrame = + bindFunction<tcuvidMapVideoFrame>("cuvidMapVideoFrame"); + dl_cuvidUnmapVideoFrame = + bindFunction<tcuvidUnmapVideoFrame>("cuvidUnmapVideoFrame"); + dl_cuvidMapVideoFrame64 = + bindFunction<tcuvidMapVideoFrame64>("cuvidMapVideoFrame64"); + dl_cuvidUnmapVideoFrame64 = + bindFunction<tcuvidUnmapVideoFrame64>("cuvidUnmapVideoFrame64"); + + return isLoaded(); +} + +} // namespace facebook::torchcodec + +extern "C" { + +CUresult CUDAAPI cuvidCreateVideoParser( + CUvideoparser* videoParser, + CUVIDPARSERPARAMS* parserParams) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidCreateVideoParser, + "cuvidCreateVideoParser called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidCreateVideoParser( + videoParser, parserParams); +} + +CUresult CUDAAPI cuvidParseVideoData( + CUvideoparser videoParser, + CUVIDSOURCEDATAPACKET* cuvidPacket) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidParseVideoData, + "cuvidParseVideoData called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidParseVideoData(videoParser, cuvidPacket); +} + +CUresult CUDAAPI cuvidDestroyVideoParser(CUvideoparser videoParser) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidDestroyVideoParser, + "cuvidDestroyVideoParser called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidDestroyVideoParser(videoParser); +} + +CUresult CUDAAPI cuvidGetDecoderCaps(CUVIDDECODECAPS* caps) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidGetDecoderCaps, + "cuvidGetDecoderCaps called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidGetDecoderCaps(caps); } + +CUresult CUDAAPI cuvidCreateDecoder( + CUvideodecoder* decoder, + CUVIDDECODECREATEINFO* decoderParams) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidCreateDecoder, + "cuvidCreateDecoder called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidCreateDecoder(decoder, decoderParams); +} + +CUresult CUDAAPI cuvidDestroyDecoder(CUvideodecoder decoder) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidDestroyDecoder, + "cuvidDestroyDecoder called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidDestroyDecoder(decoder); } + +CUresult CUDAAPI +cuvidDecodePicture(CUvideodecoder decoder, CUVIDPICPARAMS* picParams) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidDecodePicture, + "cuvidDecodePicture called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidDecodePicture(decoder, picParams); } + +#if !defined(__CUVID_DEVPTR64) || defined(__CUVID_INTERNAL) +// We need to protect the definition of the 32bit versions under the above +// conditions (see cuviddec.h). Defining them unconditionally would cause +// conflict compilation errors when cuviddec.h redefines those to the 64bit +// versions.
+CUresult CUDAAPI cuvidMapVideoFrame( + CUvideodecoder decoder, + int pixIndex, + unsigned int* framePtr, + unsigned int* pitch, + CUVIDPROCPARAMS* procParams) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidMapVideoFrame, + "cuvidMapVideoFrame called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidMapVideoFrame( + decoder, pixIndex, framePtr, pitch, procParams); +} + +CUresult CUDAAPI +cuvidUnmapVideoFrame(CUvideodecoder decoder, unsigned int framePtr) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidUnmapVideoFrame, + "cuvidUnmapVideoFrame called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidUnmapVideoFrame(decoder, framePtr); +} +#endif + +CUresult CUDAAPI cuvidMapVideoFrame64( + CUvideodecoder decoder, + int pixIndex, + unsigned long long* framePtr, + unsigned int* pitch, + CUVIDPROCPARAMS* procParams) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidMapVideoFrame64, + "cuvidMapVideoFrame64 called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidMapVideoFrame64( + decoder, pixIndex, framePtr, pitch, procParams); +} + +CUresult CUDAAPI +cuvidUnmapVideoFrame64(CUvideodecoder decoder, unsigned long long framePtr) { + TORCH_CHECK( + facebook::torchcodec::dl_cuvidUnmapVideoFrame64, + "cuvidUnmapVideoFrame64 called but NVCUVID not loaded!"); + return facebook::torchcodec::dl_cuvidUnmapVideoFrame64(decoder, framePtr); +} + +} // extern "C" + +#endif // FBCODE_CAFFE2 diff --git a/src/torchcodec/_core/NVCUVIDRuntimeLoader.h b/src/torchcodec/_core/NVCUVIDRuntimeLoader.h new file mode 100644 index 000000000..e6ee40a05 --- /dev/null +++ b/src/torchcodec/_core/NVCUVIDRuntimeLoader.h @@ -0,0 +1,14 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +namespace facebook::torchcodec { + +// See note in the corresponding .cpp file. +bool loadNVCUVIDLibrary(); + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/NVDECCache.cpp b/src/torchcodec/_core/NVDECCache.cpp index 87ab5b0dc..302433cd4 100644 --- a/src/torchcodec/_core/NVDECCache.cpp +++ b/src/torchcodec/_core/NVDECCache.cpp @@ -7,6 +7,7 @@ #include #include +#include "src/torchcodec/_core/CUDACommon.h" #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/NVDECCache.h" @@ -19,20 +20,9 @@ extern "C" { namespace facebook::torchcodec { -NVDECCache& NVDECCache::getCache(int deviceIndex) { - const int MAX_CUDA_GPUS = 128; - TORCH_CHECK( - deviceIndex >= -1 && deviceIndex < MAX_CUDA_GPUS, - "Invalid device index = ", - deviceIndex); +NVDECCache& NVDECCache::getCache(const torch::Device& device) { static NVDECCache cacheInstances[MAX_CUDA_GPUS]; - if (deviceIndex == -1) { - // TODO NVDEC P3: Unify with existing getNonNegativeDeviceIndex() - TORCH_CHECK( - cudaGetDevice(&deviceIndex) == cudaSuccess, - "Failed to get current CUDA device."); - } - return cacheInstances[deviceIndex]; + return cacheInstances[getDeviceIndex(device)]; } UniqueCUvideodecoder NVDECCache::getDecoder(CUVIDEOFORMAT* videoFormat) { diff --git a/src/torchcodec/_core/NVDECCache.h b/src/torchcodec/_core/NVDECCache.h index 17fc99902..a0f2fb862 100644 --- a/src/torchcodec/_core/NVDECCache.h +++ b/src/torchcodec/_core/NVDECCache.h @@ -11,6 +11,9 @@ #include #include +#include <torch/types.h> + +#include "src/torchcodec/_core/NVCUVIDRuntimeLoader.h" #include "src/torchcodec/_core/nvcuvid_include/cuviddec.h" #include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h" @@ -36,7 +39,7 @@ using UniqueCUvideodecoder = // per GPU device, and it is accessed through the static getCache() method. class NVDECCache { public: - static NVDECCache& getCache(int deviceIndex); + static NVDECCache& getCache(const torch::Device& device); // Get decoder from cache - returns nullptr if none available UniqueCUvideodecoder getDecoder(CUVIDEOFORMAT* videoFormat); @@ -68,11 +71,6 @@ class NVDECCache { CacheKey(const CacheKey&) = default; CacheKey& operator=(const CacheKey&) = default; - // TODONVDEC P2: we only implement operator< which is enough for std::map, - // but: - // - we should consider using std::unordered_map - // - we should consider a more sophisticated and potentially less strict - // cache key comparison logic bool operator<(const CacheKey& other) const { return std::tie( codecType, diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index d06c47922..8d9e9f651 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "Metadata.h" #include "torch/types.h" namespace facebook::torchcodec { @@ -429,7 +430,6 @@ void SingleStreamDecoder::addStream( TORCH_CHECK( deviceInterface_ != nullptr, "Failed to create device interface.
This should never happen, please report."); - deviceInterface_->initialize(streamInfo.stream, formatContext_); // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic @@ -441,7 +441,7 @@ void SingleStreamDecoder::addStream( AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); TORCH_CHECK(codecContext != nullptr); - streamInfo.codecContext.reset(codecContext); + streamInfo.codecContext = makeSharedAVCodecContext(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); @@ -453,14 +453,19 @@ void SingleStreamDecoder::addStream( // Note that we must make sure to register the harware device context // with the codec context before calling avcodec_open2(). Otherwise, decoding // will happen on the CPU and not the hardware device. - deviceInterface_->registerHardwareDeviceWithCodec(codecContext); + deviceInterface_->registerHardwareDeviceWithCodec( + streamInfo.codecContext.get()); retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); - codecContext->time_base = streamInfo.stream->time_base; + streamInfo.codecContext->time_base = streamInfo.stream->time_base; + + // Initialize the device interface with the codec context + deviceInterface_->initialize( + streamInfo.stream, formatContext_, streamInfo.codecContext); containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = - std::string(avcodec_get_name(codecContext->codec_id)); + std::string(avcodec_get_name(streamInfo.codecContext->codec_id)); // We will only need packets from the active stream, so we tell FFmpeg to // discard packets from the other streams. Note that av_read_frame() may still @@ -523,6 +528,7 @@ void SingleStreamDecoder::addVideoStream( if (transform->getOutputFrameDims().has_value()) { resizedOutputDims_ = transform->getOutputFrameDims().value(); } + transform->validate(streamMetadata); // Note that we are claiming ownership of the transform objects passed in to // us. @@ -1149,8 +1155,6 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { getFFMPEGErrorStringFromErrorCode(status)); decodeStats_.numFlushes++; - avcodec_flush_buffers(streamInfo.codecContext.get()); - deviceInterface_->flush(); } @@ -1169,24 +1173,16 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( cursorWasJustSet_ = false; } - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; UniqueAVFrame avFrame(av_frame_alloc()); AutoAVPacket autoAVPacket; int status = AVSUCCESS; bool reachedEOF = false; - // TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on - // if/else blocks to dispatch to the interface or to FFmpeg, consider *always* - // dispatching to the interface. The default implementation of the interface's - // receiveFrame and sendPacket could just be calling avcodec_receive_frame and - // avcodec_send_packet. This would make the decoding loop even more generic. + // The default implementation uses avcodec_receive_frame and + // avcodec_send_packet, while specialized interfaces can override for + // hardware-specific optimizations. 
while (true) { - if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->receiveFrame(avFrame); - } else { - status = - avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); - } + status = deviceInterface_->receiveFrame(avFrame); if (status != AVSUCCESS && status != AVERROR(EAGAIN)) { // Non-retriable error @@ -1222,13 +1218,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( if (status == AVERROR_EOF) { // End of file reached. We must drain the decoder - if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->sendEOFPacket(); - } else { - status = avcodec_send_packet( - streamInfo.codecContext.get(), - /*avpkt=*/nullptr); - } + status = deviceInterface_->sendEOFPacket(); TORCH_CHECK( status >= AVSUCCESS, "Could not flush decoder: ", @@ -1253,11 +1243,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We got a valid packet. Send it to the decoder, and we'll receive it in // the next iteration. - if (deviceInterface_->canDecodePacketDirectly()) { - status = deviceInterface_->sendPacket(packet); - } else { - status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); - } + status = deviceInterface_->sendPacket(packet); TORCH_CHECK( status >= AVSUCCESS, "Could not push packet to decoder: ", @@ -1716,4 +1702,9 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) { streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); } +std::string SingleStreamDecoder::getDeviceInterfaceDetails() const { + TORCH_CHECK(deviceInterface_ != nullptr, "Device interface doesn't exist."); + return deviceInterface_->getDetails(); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 48821ff09..4d4c11aa2 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -186,6 +186,8 @@ class SingleStreamDecoder { DecodeStats getDecodeStats() const; void resetDecodeStats(); + std::string getDeviceInterfaceDetails() const; + private: // -------------------------------------------------------------------------- // STREAMINFO AND ASSOCIATED STRUCTS @@ -221,7 +223,7 @@ class SingleStreamDecoder { AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN; AVRational timeBase = {}; - UniqueAVCodecContext codecContext; + SharedAVCodecContext codecContext; // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was // called. @@ -311,7 +313,7 @@ class SingleStreamDecoder { int streamIndex, AVMediaType mediaType, const torch::Device& device = torch::kCPU, - const std::string_view deviceVariant = "default", + const std::string_view deviceVariant = "ffmpeg", std::optional<int> ffmpegThreadCount = std::nullopt); // Returns the "best" stream index for a given media type. The "best" is diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 7728a676e..e5ab256e1 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -41,8 +41,8 @@ struct VideoStreamOptions { // By default we use CPU for decoding for both C++ and python users. torch::Device device = torch::kCPU; - // Device variant (e.g., "default", "beta", etc.) - std::string_view deviceVariant = "default"; + // Device variant (e.g., "ffmpeg", "beta", etc.)
+ std::string_view deviceVariant = "ffmpeg"; // Encoding options // TODO-VideoEncoder: Consider adding other optional fields here diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index d0a5104f3..6083986e1 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -57,4 +57,31 @@ int ResizeTransform::getSwsFlags() const { return toSwsInterpolation(interpolationMode_); } +CropTransform::CropTransform(const FrameDims& dims, int x, int y) + : outputDims_(dims), x_(x), y_(y) { + TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_); + TORCH_CHECK(y_ >= 0, "Crop y position must be >= 0, got: ", y_); +} + +std::string CropTransform::getFilterGraphCpu() const { + return "crop=" + std::to_string(outputDims_.width) + ":" + + std::to_string(outputDims_.height) + ":" + std::to_string(x_) + ":" + + std::to_string(y_) + ":exact=1"; +} + +std::optional<FrameDims> CropTransform::getOutputFrameDims() const { + return outputDims_; +} + +void CropTransform::validate(const StreamMetadata& streamMetadata) const { + TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds"); + TORCH_CHECK( + x_ + outputDims_.width <= streamMetadata.width, + "Crop x position + width out of bounds"); + TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds"); + TORCH_CHECK( + y_ + outputDims_.height <= streamMetadata.height, + "Crop y position + height out of bounds"); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 6aea255ab..28d8c28a2 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -9,6 +9,7 @@ #include <optional> #include <string> #include "src/torchcodec/_core/Frame.h" +#include "src/torchcodec/_core/Metadata.h" namespace facebook::torchcodec { @@ -33,6 +34,16 @@ class Transform { virtual bool isResize() const { return false; } + + // The validity of some transforms depends on the characteristics of the + // AVStream they're being applied to. For example, some transforms specify + // coordinates inside a frame, and we need to validate that those are + // within the frame's bounds. + // + // Note that the validation function does not return anything. We expect + // invalid configurations to throw an exception.
+ virtual void validate( + [[maybe_unused]] const StreamMetadata& streamMetadata) const {} }; class ResizeTransform : public Transform { @@ -56,4 +67,18 @@ class ResizeTransform : public Transform { InterpolationMode interpolationMode_; }; +class CropTransform : public Transform { + public: + CropTransform(const FrameDims& dims, int x, int y); + + std::string getFilterGraphCpu() const override; + std::optional<FrameDims> getOutputFrameDims() const override; + void validate(const StreamMetadata& streamMetadata) const override; + + private: + FrameDims outputDims_; + int x_; + int y_; +}; + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 24e54af0e..55ff697b3 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -14,6 +14,7 @@ ) from .ops import ( _add_video_stream, + _get_backend_details, _get_key_frame_indices, _test_frame_pts_equality, add_audio_stream, @@ -26,6 +27,8 @@ encode_audio_to_file_like, encode_audio_to_tensor, encode_video_to_file, + encode_video_to_file_like, + encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 5ba98e2c1..c6204de8c 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -32,20 +32,24 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( "encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); - m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); m.def( "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor"); m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); + m.def( + "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); + m.def( + "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor"); + m.def( + "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int?
num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"default\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -70,6 +74,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str"); m.def("_get_json_ffmpeg_library_versions() -> str"); + m.def("_get_backend_details(Tensor(a!) decoder) -> str"); m.def( "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool"); m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); @@ -212,6 +217,26 @@ Transform* makeResizeTransform( return new ResizeTransform(FrameDims(height, width)); } +// Crop transform specs take the form: +// +// "crop, <height>, <width>, <x>, <y>" +// +// Where "crop" is the string literal and <height>, <width>, <x> and <y> are +// positive integers. <x> and <y> are the x and y coordinates of the top left +// corner of the crop. Note that we follow the PyTorch convention of (height, +// width) for specifying image dimensions; FFmpeg uses (width, height). +Transform* makeCropTransform( + const std::vector<std::string>& cropTransformSpec) { + TORCH_CHECK( + cropTransformSpec.size() == 5, + "cropTransformSpec must have 5 elements including its name"); + int height = checkedToPositiveInt(cropTransformSpec[1]); + int width = checkedToPositiveInt(cropTransformSpec[2]); + int x = checkedToPositiveInt(cropTransformSpec[3]); + int y = checkedToPositiveInt(cropTransformSpec[4]); + return new CropTransform(FrameDims(height, width), x, y); +} + std::vector<std::string> split(const std::string& str, char delimiter) { std::vector<std::string> tokens; std::string token; @@ -239,6 +264,8 @@ std::vector<Transform*> makeTransforms(const std::string& transformSpecsRaw) { auto name = transformSpec[0]; if (name == "resize") { transforms.push_back(makeResizeTransform(transformSpec)); + } else if (name == "crop") { + transforms.push_back(makeCropTransform(transformSpec)); } else { TORCH_CHECK(false, "Invalid transform name: " + name); } @@ -319,7 +346,7 @@ void _add_video_stream( std::optional<std::string_view> dimension_order = std::nullopt, std::optional<int64_t> stream_index = std::nullopt, std::string_view device = "cpu", - std::string_view device_variant = "default", + std::string_view device_variant = "ffmpeg", std::string_view transform_specs = "", std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>> custom_frame_mappings = std::nullopt, std::optional @@ -376,7 +403,7 @@ void add_video_stream( std::optional<std::string_view> dimension_order = std::nullopt, std::optional<int64_t> stream_index = std::nullopt, std::string_view device = "cpu", - std::string_view device_variant = "default", + std::string_view device_variant = "ffmpeg", std::string_view transform_specs = "", const std::optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>& custom_frame_mappings = std::nullopt) { @@ -498,21 +525,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio( return makeOpsAudioFramesOutput(result); } -void encode_video_to_file( - const at::Tensor& frames, - int64_t frame_rate, - std::string_view file_name, - std::optional<int64_t> crf = std::nullopt) { - VideoStreamOptions videoStreamOptions; - videoStreamOptions.crf = crf; - VideoEncoder( - frames, - validateInt64ToInt(frame_rate, "frame_rate"), - file_name, - videoStreamOptions) - .encode(); -} - void encode_audio_to_file( const at::Tensor& samples, int64_t sample_rate, @@ -587,6 +599,62 @@ void _encode_audio_to_file_like( encoder.encode(); } +void
encode_video_to_file( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view file_name, + std::optional<int64_t> crf = std::nullopt) { + VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; + VideoEncoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + file_name, + videoStreamOptions) + .encode(); +} + +at::Tensor encode_video_to_tensor( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view format, + std::optional<int64_t> crf = std::nullopt) { + auto avioContextHolder = std::make_unique<AVIOToTensorContext>(); + VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; + return VideoEncoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + format, + std::move(avioContextHolder), + videoStreamOptions) + .encodeToTensor(); +} + +void _encode_video_to_file_like( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view format, + int64_t file_like_context, + std::optional<int64_t> crf = std::nullopt) { + auto fileLikeContext = + reinterpret_cast<AVIOFileLikeContext*>(file_like_context); + TORCH_CHECK( + fileLikeContext != nullptr, "file_like_context must be a valid pointer"); + std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext); + + VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; + + VideoEncoder encoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + format, + std::move(avioContextHolder), + videoStreamOptions); + encoder.encode(); +} + // For testing only. We need to implement this operation as a core library // function because what we're testing is round-tripping pts values as // double-precision floating point numbers from C++ to Python and back to C++. @@ -828,6 +896,11 @@ std::string _get_json_ffmpeg_library_versions() { return ss.str(); } +std::string get_backend_details(at::Tensor& decoder) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + return videoDecoder->getDeviceInterfaceDetails(); +} + // Scans video packets to get more accurate metadata like frame count, exact // keyframe positions, etc. Exact keyframe positions are useful for efficient // accurate seeking. Note that this function reads the entire video but it does @@ -847,9 +920,11 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("encode_audio_to_file", &encode_audio_to_file); - m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_audio_to_tensor", &encode_audio_to_tensor); m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like); + m.impl("encode_video_to_file", &encode_video_to_file); + m.impl("encode_video_to_tensor", &encode_video_to_tensor); + m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); @@ -870,6 +945,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl( "scan_all_streams_to_update_metadata", &scan_all_streams_to_update_metadata); + + m.impl("_get_backend_details", &get_backend_details); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 44dc89e2b..32995c964 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -69,7 +69,7 @@ def load_torchcodec_shared_libraries(): raise RuntimeError( f"""Could not load libtorchcodec. Likely causes: 1. FFmpeg is not properly installed in your environment. We support - versions 4, 5, 6 and 7. + versions 4, 5, 6, and 7 on all platforms, and 8 on Mac and Linux. 2.
The PyTorch version ({torch.__version__}) is not compatible with this version of TorchCodec. Refer to the version compatibility table: @@ -92,15 +92,21 @@ def load_torchcodec_shared_libraries(): encode_audio_to_file = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_audio_to_file.default ) -encode_video_to_file = torch._dynamo.disallow_in_graph( - torch.ops.torchcodec_ns.encode_video_to_file.default -) encode_audio_to_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_audio_to_tensor.default ) _encode_audio_to_file_like = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns._encode_audio_to_file_like.default ) +encode_video_to_file = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_video_to_file.default +) +encode_video_to_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns.encode_video_to_tensor.default +) +_encode_video_to_file_like = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns._encode_video_to_file_like.default +) create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) @@ -136,6 +142,7 @@ def load_torchcodec_shared_libraries(): _get_json_ffmpeg_library_versions = ( torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default ) +_get_backend_details = torch.ops.torchcodec_ns._get_backend_details.default # ============================= @@ -200,6 +207,33 @@ def encode_audio_to_file_like( ) +def encode_video_to_file_like( + frames: torch.Tensor, + frame_rate: int, + format: str, + file_like: Union[io.RawIOBase, io.BufferedIOBase], + crf: Optional[int] = None, +) -> None: + """Encode video frames to a file-like object. + + Args: + frames: Video frames tensor + frame_rate: Frame rate in frames per second + format: Video format (e.g., "mp4", "mov", "mkv") + file_like: File-like object that supports write() and seek() methods + crf: Optional constant rate factor for encoding quality + """ + assert _pybind_ops is not None + + _encode_video_to_file_like( + frames, + frame_rate, + format, + _pybind_ops.create_file_like_context(file_like, True), # True means for writing + crf, + ) + + def get_frames_at_indices( decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -254,16 +288,6 @@ def encode_audio_to_file_abstract( return -@register_fake("torchcodec_ns::encode_video_to_file") -def encode_video_to_file_abstract( - frames: torch.Tensor, - frame_rate: int, - filename: str, - crf: Optional[int] = None, -) -> None: - return - - @register_fake("torchcodec_ns::encode_audio_to_tensor") def encode_audio_to_tensor_abstract( samples: torch.Tensor, @@ -289,6 +313,37 @@ def _encode_audio_to_file_like_abstract( return +@register_fake("torchcodec_ns::encode_video_to_file") +def encode_video_to_file_abstract( + frames: torch.Tensor, + frame_rate: int, + filename: str, + crf: Optional[int], +) -> None: + return + + +@register_fake("torchcodec_ns::encode_video_to_tensor") +def encode_video_to_tensor_abstract( + frames: torch.Tensor, + frame_rate: int, + format: str, + crf: Optional[int], +) -> torch.Tensor: + return torch.empty([], dtype=torch.long) + + +@register_fake("torchcodec_ns::_encode_video_to_file_like") +def _encode_video_to_file_like_abstract( + frames: torch.Tensor, + frame_rate: int, + format: str, + file_like_context: int, + crf: Optional[int] = None, +) -> None: + return + + @register_fake("torchcodec_ns::create_from_tensor") def create_from_tensor_abstract( video_tensor: 
torch.Tensor, seek_mode: Optional[str] @@ -304,7 +359,7 @@ def _add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: str = "cpu", - device_variant: str = "default", + device_variant: str = "ffmpeg", transform_specs: str = "", custom_frame_mappings: Optional[ tuple[torch.Tensor, torch.Tensor, torch.Tensor] @@ -322,7 +377,7 @@ def add_video_stream_abstract( dimension_order: Optional[str] = None, stream_index: Optional[int] = None, device: str = "cpu", - device_variant: str = "default", + device_variant: str = "ffmpeg", transform_specs: str = "", custom_frame_mappings: Optional[ tuple[torch.Tensor, torch.Tensor, torch.Tensor] @@ -496,3 +551,8 @@ def scan_all_streams_to_update_metadata_abstract(decoder: torch.Tensor) -> None: def get_ffmpeg_library_versions(): versions_json = _get_json_ffmpeg_library_versions() return json.loads(versions_json) + + +@register_fake("torchcodec_ns::_get_backend_details") +def _get_backend_details_abstract(decoder: torch.Tensor) -> str: + return "" diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 331c7ba79..130927c2e 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -147,16 +147,6 @@ def __init__( device = str(device) device_variant = _get_cuda_backend() - if device_variant == "ffmpeg": - # TODONVDEC P2 rename 'default' into 'ffmpeg' everywhere. - device_variant = "default" - - # Legacy support for device="cuda:0:beta" syntax - # TODONVDEC P2: remove support for this everywhere. This will require - # updating our tests. - if device == "cuda:0:beta": - device = "cuda:0" - device_variant = "beta" core.add_video_stream( self._decoder, diff --git a/test/generate_reference_resources.py b/test/generate_reference_resources.py index 5ae062111..953fb996e 100644 --- a/test/generate_reference_resources.py +++ b/test/generate_reference_resources.py @@ -6,17 +6,20 @@ import subprocess from pathlib import Path +from typing import Optional import numpy as np import torch from PIL import Image + +from .utils import AV1_VIDEO, H265_VIDEO, NASA_VIDEO, TestVideo + # Run this script to update the resources used in unit tests. The resources are all derived # from source media already checked into the repo. -def convert_image_to_tensor(image_path): +def convert_image_to_tensor(image_path: str) -> None: image_path = Path(image_path) if not image_path.exists(): return @@ -31,26 +34,56 @@ def convert_image_to_tensor(image_path): image_path.unlink() -def get_frame_by_index(video_path, frame, output_path, stream): +def generate_frame_by_index( + video: TestVideo, + *, + frame_index: int, + stream_index: int, + filters: Optional[str] = None, +) -> None: + # Note that we are using 0-based index naming. As a result, we are + # generating files one-by-one, giving the actual file name that we want. + # ffmpeg does have an option to generate multiple files for us, but it uses + # 1-based indexing. We can't use 1-based indexing because we want to match + # the 0-based indexing in our tests. + base_path = video.get_base_path_by_index( + frame_index, stream_index=stream_index, filters=filters + ) + output_bmp = f"{base_path}.bmp" + + # Note that we have an explicit format conversion to rgb24 in our filtergraph specification, + # which always happens BEFORE any of the filters that we receive as input.
We do this to + # ensure that the color conversion happens BEFORE the filters, matching the behavior of the + # torchcodec filtergraph implementation. + # + # Not doing this would result in the color conversion happening AFTER the filters, which + # would result in different color values for the same frame. + filtergraph = f"select='eq(n\\,{frame_index})',format=rgb24" + if filters is not None: + filtergraph = filtergraph + f",{filters}" + cmd = [ "ffmpeg", "-y", "-i", - video_path, + video.path, "-map", - f"0:{stream}", + f"0:{stream_index}", "-vf", - f"select=eq(n\\,{frame})", - "-vsync", - "vfr", - "-q:v", - "2", - output_path, + filtergraph, + "-fps_mode", + "passthrough", + "-update", + "1", + output_bmp, ] subprocess.run(cmd, check=True) + convert_image_to_tensor(output_bmp) -def get_frame_by_timestamp(video_path, timestamp, output_path): +def generate_frame_by_timestamp( + video_path: str, timestamp: float, output_path: str +) -> None: cmd = [ "ffmpeg", "-y", @@ -63,60 +96,58 @@ def get_frame_by_timestamp(video_path, timestamp, output_path): output_path, ] subprocess.run(cmd, check=True) + convert_image_to_tensor(output_path) -def main(): - SCRIPT_DIR = Path(__file__).resolve().parent - TORCHCODEC_PATH = SCRIPT_DIR.parent - RESOURCES_DIR = TORCHCODEC_PATH / "test" / "resources" - VIDEO_PATH = RESOURCES_DIR / "nasa_13013.mp4" - - # Last generated with ffmpeg version 4.3 - # +def generate_nasa_13013_references(): # Note: The naming scheme used here must match the naming scheme used to load # tensors in ./utils.py. - STREAMS = [0, 3] - FRAMES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 20, 25, 30, 35, 386, 387, 388, 389] - for stream in STREAMS: - for frame in FRAMES: - # Note that we are using 0-based index naming. Asking ffmpeg to number output - # frames would result in 1-based index naming. We enforce 0-based index naming - # so that the name of reference frames matches the index when accessing that - # frame in the Python decoder. - output_bmp = f"{VIDEO_PATH}.stream{stream}.frame{frame:06d}.bmp" - get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=stream) - convert_image_to_tensor(output_bmp) + streams = [0, 3] + frames = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 20, 25, 30, 35, 386, 387, 388, 389] + for stream in streams: + for frame in frames: + generate_frame_by_index(NASA_VIDEO, frame_index=frame, stream_index=stream) # Extract individual frames at specific timestamps, including the last frame of the video. seek_timestamp = [6.0, 6.1, 10.0, 12.979633] timestamp_name = [f"{seek_timestamp:06f}" for seek_timestamp in seek_timestamp] for timestamp, name in zip(seek_timestamp, timestamp_name): - output_bmp = f"{VIDEO_PATH}.time{name}.bmp" - get_frame_by_timestamp(VIDEO_PATH, timestamp, output_bmp) - convert_image_to_tensor(output_bmp) + output_bmp = f"{NASA_VIDEO.path}.time{name}.bmp" + generate_frame_by_timestamp(NASA_VIDEO.path, timestamp, output_bmp) + + # Extract frames with specific filters. We have tests that assume these exact filters. 
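Aside: a quick check of the filtergraph string that generate_frame_by_index assembles when a crop filter is passed, using the exact crop parameters from the reference frames generated below:

    frame_index = 15
    filters = "crop=300:200:50:35:exact=1"
    filtergraph = f"select='eq(n\\,{frame_index})',format=rgb24,{filters}"
    assert filtergraph == "select='eq(n\\,15)',format=rgb24,crop=300:200:50:35:exact=1"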
+ frames = [0, 15, 200, 389] + crop_filter = "crop=300:200:50:35:exact=1" + for frame in frames: + generate_frame_by_index( + NASA_VIDEO, frame_index=frame, stream_index=3, filters=crop_filter + ) + +def generate_h265_video_references(): # This video was generated by running the following: # conda install -c conda-forge x265 # ./configure --enable-nonfree --enable-gpl --prefix=$(readlink -f ../bin) --enable-libx265 --enable-rpath --extra-ldflags=-Wl,-rpath=$CONDA_PREFIX/lib --enable-filter=drawtext --enable-libfontconfig --enable-libfreetype --enable-libharfbuzz # ffmpeg -f lavfi -i color=size=128x128:duration=1:rate=10:color=blue -vf "drawtext=fontsize=30:fontcolor=white:x=(w-text_w)/2:y=(h-text_h)/2:text='Frame %{frame_num}'" -vcodec libx265 -pix_fmt yuv420p -g 2 -crf 10 h265_video.mp4 -y # Note that this video only has 1 stream, at index 0. - VIDEO_PATH = RESOURCES_DIR / "h265_video.mp4" - FRAMES = [5] - for frame in FRAMES: - output_bmp = f"{VIDEO_PATH}.stream0.frame{frame:06d}.bmp" - get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=0) - convert_image_to_tensor(output_bmp) + frames = [5] + for frame in frames: + generate_frame_by_index(H265_VIDEO, frame_index=frame, stream_index=0) + +def generate_av1_video_references(): # This video was generated by running the following: # ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv # Note that this video only has 1 stream, at index 0. - VIDEO_PATH = RESOURCES_DIR / "av1_video.mkv" - FRAMES = [10] + frames = [10] + for frame in frames: + generate_frame_by_index(AV1_VIDEO, frame_index=frame, stream_index=0) - for frame in FRAMES: - output_bmp = f"{VIDEO_PATH}.stream0.frame{frame:06d}.bmp" - get_frame_by_index(VIDEO_PATH, frame, output_bmp, stream=0) - convert_image_to_tensor(output_bmp) + +def main(): + generate_nasa_13013_references() + generate_h265_video_references() + generate_av1_video_references() if __name__ == "__main__": diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000000.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000000.pt new file mode 100644 index 000000000..c69af7cee Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000000.pt differ diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000015.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000015.pt new file mode 100644 index 000000000..c4f0bdc95 Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000015.pt differ diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000200.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000200.pt new file mode 100644 index 000000000..9849cfa99 Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000200.pt differ diff --git a/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000389.pt b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000389.pt new file mode 100644 index 000000000..3451f2a98 Binary files /dev/null and b/test/resources/nasa_13013.mp4.crop_300_200_50_35_exact_1.stream3.frame000389.pt differ diff --git a/test/test_decoders.py b/test/test_decoders.py index 300c953bf..6e08e05a4 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -35,6 +35,7 
@@ H265_10BITS, H265_VIDEO, in_fbcode, + make_video_decoder, NASA_AUDIO, NASA_AUDIO_MP3, NASA_AUDIO_MP3_44100, @@ -51,7 +52,6 @@ TEST_SRC_2_720P_MPEG4, TEST_SRC_2_720P_VP8, TEST_SRC_2_720P_VP9, - unsplit_device_str, ) @@ -179,13 +179,12 @@ def test_create_fails(self): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode): - decoder = VideoDecoder( + decoder, device = make_video_decoder( NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads, device=device, seek_mode=seek_mode, ) - device, _ = unsplit_device_str(device) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) @@ -230,8 +229,9 @@ def test_getitem_numpy_int(self): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_slice(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) # ensure that the degenerate case of a range of size 1 works @@ -391,7 +391,9 @@ def test_device_instance(self): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_getitem_fails(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) with pytest.raises(IndexError, match="Invalid frame index"): frame = decoder[1000] # noqa @@ -408,8 +410,9 @@ def test_getitem_fails(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_iteration(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) @@ -456,8 +459,9 @@ def test_iteration_slow(self): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) frame9 = decoder.get_frame_at(9) @@ -494,7 +498,7 @@ def test_get_frame_at(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frame_at_tuple_unpacking(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + decoder, _ = make_video_decoder(NASA_VIDEO.path, device=device) frame = decoder.get_frame_at(50) data, pts, duration = decoder.get_frame_at(50) @@ -506,7 +510,9 @@ def test_get_frame_at_tuple_unpacking(self, device): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_at_fails(self, device, seek_mode): - decoder = 
VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) with pytest.raises( IndexError, @@ -520,8 +526,9 @@ def test_get_frame_at_fails(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) # test positive and negative frame index frames = decoder.get_frames_at([35, 25, -1, -2]) @@ -572,7 +579,9 @@ def test_get_frames_at(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at_fails(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) with pytest.raises( IndexError, @@ -596,8 +605,7 @@ def test_get_frame_at_av1(self, device): if "cuda" in device and in_fbcode(): pytest.skip("decoding on CUDA is not supported internally") - decoder = VideoDecoder(AV1_VIDEO.path, device=device) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder(AV1_VIDEO.path, device=device) ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10) ref_frame_info10 = AV1_VIDEO.get_frame_info(10) decoded_frame10 = decoder.get_frame_at(10) @@ -608,8 +616,9 @@ def test_get_frame_at_av1(self, device): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device) assert_frames_equal( @@ -638,7 +647,9 @@ def test_get_frame_played_at_h265(self): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frame_played_at_fails(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(-1.0) # noqa @@ -650,8 +661,9 @@ def test_get_frame_played_at_fails(self, device, seek_mode): @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) @pytest.mark.parametrize("input_type", ("list", "tensor")) def test_get_frames_played_at(self, device, seek_mode, input_type): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has # index 35. We use those indices as reference to test against. 
@@ -693,7 +705,9 @@ def test_get_frames_played_at(self, device, seek_mode, input_type): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_played_at_fails(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) with pytest.raises(RuntimeError, match="must be greater than or equal to"): decoder.get_frames_played_at([-1]) @@ -710,13 +724,12 @@ def test_get_frames_played_at_fails(self, device, seek_mode): @pytest.mark.parametrize("stream_index", [0, 3, None]) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_in_range(self, stream_index, device, seek_mode): - decoder = VideoDecoder( + decoder, device = make_video_decoder( NASA_VIDEO.path, stream_index=stream_index, device=device, seek_mode=seek_mode, ) - device, _ = unsplit_device_str(device) # test degenerate case where we only actually get 1 frame ref_frames9 = NASA_VIDEO.get_frame_data_by_range( @@ -815,13 +828,12 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_in_range_slice_indices_syntax(self, device, seek_mode): - decoder = VideoDecoder( + decoder, device = make_video_decoder( NASA_VIDEO.path, stream_index=3, device=device, seek_mode=seek_mode, ) - device, _ = unsplit_device_str(device) # high range ends get capped to num_frames frames387_389 = decoder.get_frames_in_range(start=387, stop=1000) @@ -891,13 +903,12 @@ def test_get_frames_with_missing_num_frames_metadata( # Set the return value of the mock to be the mock_stream_dict mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict) - decoder = VideoDecoder( + decoder, device = make_video_decoder( NASA_VIDEO.path, stream_index=3, device=device, seek_mode=seek_mode, ) - device, _ = unsplit_device_str(device) assert decoder.metadata.num_frames_from_header is None assert decoder.metadata.num_frames_from_content is None @@ -932,7 +943,7 @@ def test_get_frames_with_missing_num_frames_metadata( @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode): - decoder = VideoDecoder( + decoder, _ = make_video_decoder( NASA_VIDEO.path, dimension_order=dimension_order, device=device, @@ -960,13 +971,12 @@ def test_dimension_order_fails(self): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): - decoder = VideoDecoder( + decoder, device = make_video_decoder( NASA_VIDEO.path, stream_index=stream_index, device=device, seek_mode=seek_mode, ) - device, _ = unsplit_device_str(device) # Note that we are comparing the results of VideoDecoder's method: # get_frames_played_in_range() @@ -1100,7 +1110,9 @@ def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + decoder, _ = 
make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode=seek_mode + ) with pytest.raises(ValueError, match="Invalid start seconds"): frame = decoder.get_frames_played_in_range(100.0, 1.0) # noqa @@ -1113,7 +1125,9 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_key_frame_indices(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") + decoder, _ = make_video_decoder( + NASA_VIDEO.path, device=device, seek_mode="exact" + ) key_frame_indices = decoder._get_key_frame_indices() # The key frame indices were generated from the following command: @@ -1134,7 +1148,9 @@ def test_get_key_frame_indices(self, device): key_frame_indices, nasa_reference_key_frame_indices, atol=0, rtol=0 ) - decoder = VideoDecoder(AV1_VIDEO.path, device=device, seek_mode="exact") + decoder, _ = make_video_decoder( + AV1_VIDEO.path, device=device, seek_mode="exact" + ) key_frame_indices = decoder._get_key_frame_indices() # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt @@ -1144,7 +1160,9 @@ def test_get_key_frame_indices(self, device): key_frame_indices, av1_reference_key_frame_indices, atol=0, rtol=0 ) - decoder = VideoDecoder(H265_VIDEO.path, device=device, seek_mode="exact") + decoder, _ = make_video_decoder( + H265_VIDEO.path, device=device, seek_mode="exact" + ) key_frame_indices = decoder._get_key_frame_indices() # ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt @@ -1158,8 +1176,7 @@ def test_get_key_frame_indices(self, device): @pytest.mark.skipif(in_fbcode(), reason="Compile test fails internally.") @pytest.mark.parametrize("device", all_supported_devices()) def test_compile(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) - device, _ = unsplit_device_str(device) + decoder, device = make_video_decoder(NASA_VIDEO.path, device=device) @contextlib.contextmanager def restore_capture_scalar_outputs(): @@ -1297,17 +1314,17 @@ def test_10bit_videos(self, device, asset): # This just validates that we can decode 10-bit videos. # TODO validate against the ref that the decoded frames are correct - if device == "cuda:0:beta" and asset is H264_10BITS: + if device == "cuda:beta" and asset is H264_10BITS: # This fails on the BETA interface with: # # RuntimeError: Codec configuration not supported on this GPU. # Codec: 4, chroma format: 1, bit depth: 10 # - # It works on the default interface because FFmpeg fallsback to the + # It works on the ffmpeg interface because FFmpeg falls back to the # CPU, while the BETA interface doesn't. 
pytest.skip("Asset not supported by NVDEC") - decoder = VideoDecoder(asset.path, device=device) + decoder, _ = make_video_decoder(asset.path, device=device) decoder.get_frame_at(10) def setup_frame_mappings(tmp_path, file, stream_index): @@ -1346,13 +1363,12 @@ def test_custom_frame_mappings_json_and_bytes( if hasattr(custom_frame_mappings, "read") else contextlib.nullcontext() ) as custom_frame_mappings: - decoder = VideoDecoder( + decoder, device = make_video_decoder( NASA_VIDEO.path, stream_index=stream_index, device=device, custom_frame_mappings=custom_frame_mappings, ) - device, _ = unsplit_device_str(device) frame_0 = decoder.get_frame_at(0) frame_5 = decoder.get_frame_at(5) assert_frames_equal( @@ -1483,9 +1499,8 @@ def test_beta_cuda_interface_get_frame_at( pytest.skip("AV1 CUDA not supported internally") ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) - beta_decoder = VideoDecoder( - asset.path, device="cuda:0:beta", seek_mode=seek_mode - ) + with set_cuda_backend("beta"): + beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) assert ref_decoder.metadata == beta_decoder.metadata @@ -1531,9 +1546,8 @@ def test_beta_cuda_interface_get_frames_at( pytest.skip("AV1 CUDA not supported internally") ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) - beta_decoder = VideoDecoder( - asset.path, device="cuda:0:beta", seek_mode=seek_mode - ) + with set_cuda_backend("beta"): + beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) assert ref_decoder.metadata == beta_decoder.metadata @@ -1577,9 +1591,8 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): pytest.skip("AV1 CUDA not supported internally") ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) - beta_decoder = VideoDecoder( - asset.path, device="cuda:0:beta", seek_mode=seek_mode - ) + with set_cuda_backend("beta"): + beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) assert ref_decoder.metadata == beta_decoder.metadata @@ -1620,9 +1633,8 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): pytest.skip("AV1 CUDA not supported internally") ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) - beta_decoder = VideoDecoder( - asset.path, device="cuda:0:beta", seek_mode=seek_mode - ) + with set_cuda_backend("beta"): + beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) assert ref_decoder.metadata == beta_decoder.metadata @@ -1664,9 +1676,8 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode): pytest.skip("AV1 CUDA not supported internally") ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) - beta_decoder = VideoDecoder( - asset.path, device="cuda:0:beta", seek_mode=seek_mode - ) + with set_cuda_backend("beta"): + beta_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode) assert ref_decoder.metadata == beta_decoder.metadata @@ -1690,17 +1701,19 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode): assert beta_frame.duration_seconds == ref_frame.duration_seconds @needs_cuda - def test_beta_cuda_interface_small_h265(self): - # Test to illustrate current difference in behavior between the BETA and - # the default interface: this video isn't supported by NVDEC, but in the - # default interface, FFMPEG fallsback to the CPU while we don't. 
+ def test_beta_cuda_interface_cpu_fallback(self): + # Non-regression test for the CPU fallback behavior of the BETA CUDA + # interface. + # We know that the H265_VIDEO asset isn't supported by NVDEC: its + # dimensions are too small. We also know that the FFmpeg CUDA interface + # falls back to the CPU path in such cases. We assert that we fall back + # to the CPU path, too. + + ffmpeg = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0) + with set_cuda_backend("beta"): + beta = VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0) - VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0) - with pytest.raises( - RuntimeError, - match="Video is too small in at least one dimension. Provided: 128x128 vs supported:144x144", - ): - VideoDecoder(H265_VIDEO.path, device="cuda:0:beta").get_frame_at(0) + torch.testing.assert_close(ffmpeg.data, beta.data, rtol=0, atol=0) @needs_cuda def test_beta_cuda_interface_error(self): @@ -1725,21 +1738,10 @@ def test_set_cuda_backend(self): with set_cuda_backend("BETA"): assert _get_cuda_backend() == "beta" - def assert_decoder_uses(decoder, *, expected_backend): - # Assert that a decoder instance is using a given backend. - # - # We know H265_VIDEO fails on the BETA backend while it works on the - # ffmpeg one. - if expected_backend == "ffmpeg": - decoder.get_frame_at(0) # this would fail if this was BETA - else: - with pytest.raises(RuntimeError, match="Video is too small"): - decoder.get_frame_at(0) - # Check that the default is the ffmpeg backend assert _get_cuda_backend() == "ffmpeg" dec = VideoDecoder(H265_VIDEO.path, device="cuda") - assert_decoder_uses(dec, expected_backend="ffmpeg") + assert _core._get_backend_details(dec._decoder).startswith("FFmpeg CUDA") # Check the setting "beta" effectively uses the BETA backend. # We also show that this affects decoder creation only. When the decoder @@ -1748,9 +1750,9 @@ def assert_decoder_uses(decoder, *, expected_backend): with set_cuda_backend("beta"): dec = VideoDecoder(H265_VIDEO.path, device="cuda") assert _get_cuda_backend() == "ffmpeg" - assert_decoder_uses(dec, expected_backend="beta") + assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA") with set_cuda_backend("ffmpeg"): - assert_decoder_uses(dec, expected_backend="beta") + assert _core._get_backend_details(dec._decoder).startswith("Beta CUDA") # Hacky way to ensure passing "cuda:1" is supported by both backends. 
We # just check that there's an error when passing cuda:N where N is too diff --git a/test/test_encoders.py b/test/test_encoders.py index f8b5b3519..c5946654d 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -16,6 +16,7 @@ from .utils import ( assert_tensor_close_on_at_least, get_ffmpeg_major_version, + get_ffmpeg_minor_version, in_fbcode, IS_WINDOWS, NASA_AUDIO_MP3, @@ -23,6 +24,11 @@ TestContainerFile, ) +IS_WINDOWS_WITH_FFMPEG_LE_70 = IS_WINDOWS and ( + get_ffmpeg_major_version() < 7 + or (get_ffmpeg_major_version() == 7 and get_ffmpeg_minor_version() == 0) +) + @pytest.fixture def with_ffmpeg_debug_logs(): @@ -155,7 +161,11 @@ def test_bad_input_parametrized(self, method, tmp_path): avcodec_open2_failed_msg = "avcodec_open2 failed: Invalid argument" with pytest.raises( RuntimeError, - match=avcodec_open2_failed_msg if IS_WINDOWS else "invalid sample rate=10", + match=( + avcodec_open2_failed_msg + if IS_WINDOWS_WITH_FFMPEG_LE_70 + else "invalid sample rate=10" + ), ): getattr(decoder, method)(**valid_params) @@ -164,14 +174,18 @@ def test_bad_input_parametrized(self, method, tmp_path): ) with pytest.raises( RuntimeError, - match=avcodec_open2_failed_msg if IS_WINDOWS else "invalid sample rate=10", + match=( + avcodec_open2_failed_msg + if IS_WINDOWS_WITH_FFMPEG_LE_70 + else "invalid sample rate=10" + ), ): getattr(decoder, method)(sample_rate=10, **valid_params) with pytest.raises( RuntimeError, match=( avcodec_open2_failed_msg - if IS_WINDOWS + if IS_WINDOWS_WITH_FFMPEG_LE_70 else "invalid sample rate=99999999" ), ): @@ -192,7 +206,7 @@ def test_bad_input_parametrized(self, method, tmp_path): for num_channels in (0, 3): match = ( avcodec_open2_failed_msg - if IS_WINDOWS + if IS_WINDOWS_WITH_FFMPEG_LE_70 else re.escape( f"Desired number of channels ({num_channels}) is not supported" ) @@ -316,7 +330,7 @@ def test_against_cli( else: rtol, atol = None, None - if IS_WINDOWS and format == "mp3": + if IS_WINDOWS_WITH_FFMPEG_LE_70 and format == "mp3": # We're getting a "Could not open input file" on Windows mp3 files when decoding. # TODO: https://github.com/pytorch/torchcodec/issues/837 return @@ -370,7 +384,7 @@ def test_against_to_file( else: raise ValueError(f"Unknown method: {method}") - if not (IS_WINDOWS and format == "mp3"): + if not (IS_WINDOWS_WITH_FFMPEG_LE_70 and format == "mp3"): # We're getting a "Could not open input file" on Windows mp3 files when decoding. # TODO: https://github.com/pytorch/torchcodec/issues/837 torch.testing.assert_close( diff --git a/test/test_ops.py b/test/test_ops.py index fddd4043c..627829689 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import io import os from functools import partial @@ -29,6 +28,8 @@ create_from_tensor, encode_audio_to_file, encode_video_to_file, + encode_video_to_file_like, + encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, get_frame_at_pts, @@ -41,6 +42,7 @@ get_next_frame, seek_to_pts, ) +from torchcodec.decoders import VideoDecoder from .utils import ( all_supported_devices, @@ -607,103 +609,6 @@ def test_color_conversion_library(self, color_conversion_library): ) assert_frames_equal(frame_time6, reference_frame_time6) - # We choose arbitrary values for width and height scaling to get better - # test coverage. Some pairs upscale the image while others downscale it. 
- @pytest.mark.parametrize( - "width_scaling_factor,height_scaling_factor", - ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), - ) - @pytest.mark.parametrize("input_video", [NASA_VIDEO]) - def test_color_conversion_library_with_scaling( - self, input_video, width_scaling_factor, height_scaling_factor - ): - decoder = create_from_file(str(input_video.path)) - add_video_stream(decoder) - metadata = get_json_metadata(decoder) - metadata_dict = json.loads(metadata) - assert metadata_dict["width"] == input_video.width - assert metadata_dict["height"] == input_video.height - - target_height = int(input_video.height * height_scaling_factor) - target_width = int(input_video.width * width_scaling_factor) - if width_scaling_factor != 1.0: - assert target_width != input_video.width - if height_scaling_factor != 1.0: - assert target_height != input_video.height - - filtergraph_decoder = create_from_file(str(input_video.path)) - _add_video_stream( - filtergraph_decoder, - transform_specs=f"resize, {target_height}, {target_width}", - color_conversion_library="filtergraph", - ) - filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - - swscale_decoder = create_from_file(str(input_video.path)) - _add_video_stream( - swscale_decoder, - transform_specs=f"resize, {target_height}, {target_width}", - color_conversion_library="swscale", - ) - swscale_frame0, _, _ = get_next_frame(swscale_decoder) - assert_frames_equal(filtergraph_frame0, swscale_frame0) - assert filtergraph_frame0.shape == (3, target_height, target_width) - - @needs_cuda - def test_scaling_on_cuda_fails(self): - decoder = create_from_file(str(NASA_VIDEO.path)) - with pytest.raises( - RuntimeError, - match="Transforms are only supported for CPU devices.", - ): - add_video_stream(decoder, device="cuda", transform_specs="resize, 100, 100") - - def test_transform_fails(self): - decoder = create_from_file(str(NASA_VIDEO.path)) - with pytest.raises( - RuntimeError, - match="Invalid transform spec", - ): - add_video_stream(decoder, transform_specs=";") - - with pytest.raises( - RuntimeError, - match="Invalid transform name", - ): - add_video_stream(decoder, transform_specs="invalid, 1, 2") - - def test_resize_transform_fails(self): - decoder = create_from_file(str(NASA_VIDEO.path)) - with pytest.raises( - RuntimeError, - match="must have 3 elements", - ): - add_video_stream(decoder, transform_specs="resize, 100, 100, 100") - - with pytest.raises( - RuntimeError, - match="must be a positive integer", - ): - add_video_stream(decoder, transform_specs="resize, -10, 100") - - with pytest.raises( - RuntimeError, - match="must be a positive integer", - ): - add_video_stream(decoder, transform_specs="resize, 100, 0") - - with pytest.raises( - RuntimeError, - match="cannot be converted to an int", - ): - add_video_stream(decoder, transform_specs="resize, blah, 100") - - with pytest.raises( - RuntimeError, - match="out of range", - ): - add_video_stream(decoder, transform_specs="resize, 100, 1000000000000") - @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW")) @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library_with_dimension_order( @@ -743,86 +648,6 @@ def test_color_conversion_library_with_dimension_order( assert frames.shape[1:] == expected_shape assert_frames_equal(frames[0], frame0_ref) - @pytest.mark.parametrize( - "width_scaling_factor,height_scaling_factor", - ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), - ) - 
@pytest.mark.parametrize("width", [30, 32, 300]) - @pytest.mark.parametrize("height", [128]) - def test_color_conversion_library_with_generated_videos( - self, tmp_path, width, height, width_scaling_factor, height_scaling_factor - ): - - # We consider filtergraph to be the reference color conversion library. - # However the video decoder sometimes uses swscale as that is faster. - # The exact color conversion library used is an implementation detail - # of the video decoder and depends on the video's width. - # - # In this test we compare the output of filtergraph (which is the - # reference) with the output of the video decoder (which may use - # swscale if it chooses for certain video widths) to make sure they are - # always the same. - video_path = f"{tmp_path}/frame_numbers_{width}x{height}.mp4" - # We don't specify a particular encoder because the ffmpeg binary could - # be configured with different encoders. For the purposes of this test, - # the actual encoder is irrelevant. - with contextlib.ExitStack() as stack: - ffmpeg_cli = "ffmpeg" - - if os.environ.get("IN_FBCODE_TORCHCODEC") == "1": - import importlib.resources - - ffmpeg_cli = stack.enter_context( - importlib.resources.path(__package__, "ffmpeg") - ) - - command = [ - ffmpeg_cli, - "-y", - "-f", - "lavfi", - "-i", - "color=blue", - "-pix_fmt", - "yuv420p", - "-s", - f"{width}x{height}", - "-frames:v", - "1", - video_path, - ] - subprocess.check_call(command) - - decoder = create_from_file(str(video_path)) - add_video_stream(decoder) - metadata = get_json_metadata(decoder) - metadata_dict = json.loads(metadata) - assert metadata_dict["width"] == width - assert metadata_dict["height"] == height - - target_height = int(height * height_scaling_factor) - target_width = int(width * width_scaling_factor) - if width_scaling_factor != 1.0: - assert target_width != width - if height_scaling_factor != 1.0: - assert target_height != height - - filtergraph_decoder = create_from_file(str(video_path)) - _add_video_stream( - filtergraph_decoder, - transform_specs=f"resize, {target_height}, {target_width}", - color_conversion_library="filtergraph", - ) - filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) - - auto_decoder = create_from_file(str(video_path)) - add_video_stream( - auto_decoder, - transform_specs=f"resize, {target_height}, {target_width}", - ) - auto_frame0, _, _ = get_next_frame(auto_decoder) - assert_frames_equal(filtergraph_frame0, auto_frame0) - @needs_cuda def test_cuda_decoder(self): decoder = create_from_file(str(NASA_VIDEO.path)) @@ -1327,7 +1152,8 @@ def test_bad_input(self, tmp_path): class TestVideoEncoderOps: - + # TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity) + # TODO-VideoEncoder: Parametrize test after moving to test_encoders def test_bad_input(self, tmp_path): output_file = str(tmp_path / ".mp4") @@ -1378,15 +1204,25 @@ def test_bad_input(self, tmp_path): filename="./bad/path.mp3", ) - def decode(self, file_path) -> torch.Tensor: - decoder = create_from_file(str(file_path), seek_mode="approximate") - add_video_stream(decoder) - frames, *_ = get_frames_in_range(decoder, start=0, stop=60) - return frames + with pytest.raises( + RuntimeError, + match=r"Couldn't allocate AVFormatContext. Check the desired format\? 
Got format=bad_format", + ): + encode_video_to_tensor( + frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8), + frame_rate=10, + format="bad_format", + ) + + def decode(self, source=None) -> torch.Tensor: + return VideoDecoder(source).get_frames_in_range(start=0, stop=60) - @pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm")) - def test_video_encoder_round_trip(self, tmp_path, format): - # Test that decode(encode(decode(asset))) == decode(asset) + @pytest.mark.parametrize( + "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow)) + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_video_encoder_round_trip(self, tmp_path, format, method): + # Test that decode(encode(decode(frames))) == decode(frames) ffmpeg_version = get_ffmpeg_major_version() # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. # As a result, we skip the round trip test. @@ -1398,15 +1234,36 @@ def test_video_encoder_round_trip(self, tmp_path, format): ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) ): pytest.skip("Codec for webm is not available in this FFmpeg installation.") - asset = TEST_SRC_2_720P - source_frames = self.decode(str(asset.path)).data + source_frames = self.decode(TEST_SRC_2_720P.path).data + + params = dict( + frame_rate=30, crf=0 + ) # Frame rate is fixed with num frames decoded + if method == "to_file": + encoded_path = str(tmp_path / f"encoder_output.{format}") + encode_video_to_file( + frames=source_frames, + filename=encoded_path, + **params, + ) + round_trip_frames = self.decode(encoded_path).data + elif method == "to_tensor": + encoded_tensor = encode_video_to_tensor( + source_frames, format=format, **params + ) + round_trip_frames = self.decode(encoded_tensor).data + elif method == "to_file_like": + file_like = io.BytesIO() + encode_video_to_file_like( + frames=source_frames, + format=format, + file_like=file_like, + **params, + ) + round_trip_frames = self.decode(file_like.getvalue()).data + else: + raise ValueError(f"Unknown method: {method}") - encoded_path = str(tmp_path / f"encoder_output.{format}") - frame_rate = 30 # Frame rate is fixed with num frames decoded - encode_video_to_file( - frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0 - ) - round_trip_frames = self.decode(encoded_path).data assert source_frames.shape == round_trip_frames.shape assert source_frames.dtype == round_trip_frames.dtype @@ -1422,24 +1279,75 @@ def test_video_encoder_round_trip(self, tmp_path, format): assert psnr(s_frame, rt_frame) > 30 assert_close(s_frame, rt_frame, atol=atol, rtol=0) + @pytest.mark.parametrize( + "format", + ( + "mov", + "mp4", + "avi", + "mkv", + "flv", + "gif", + pytest.param("webm", marks=pytest.mark.slow), + ), + ) + @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) + def test_against_to_file(self, tmp_path, format, method): + # Test that to_file, to_tensor, and to_file_like produce the same results + ffmpeg_version = get_ffmpeg_major_version() + if format == "webm" and ( + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) + ): + pytest.skip("Codec for webm is not available in this FFmpeg installation.") + + source_frames = self.decode(TEST_SRC_2_720P.path).data + params = dict(frame_rate=30, crf=0) + + encoded_file = tmp_path / f"output.{format}" + encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params) + + if method == "to_tensor": + encoded_output = 
encode_video_to_tensor( + source_frames, format=format, **params + ) + else: # to_file_like + file_like = io.BytesIO() + encode_video_to_file_like( + frames=source_frames, + file_like=file_like, + format=format, + **params, + ) + encoded_output = file_like.getvalue() + + torch.testing.assert_close( + self.decode(encoded_file).data, + self.decode(encoded_output).data, + atol=0, + rtol=0, + ) + @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") @pytest.mark.parametrize( - "format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif") + "format", + ( + "mov", + "mp4", + "avi", + "mkv", + "flv", + "gif", + pytest.param("webm", marks=pytest.mark.slow), + ), ) def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): ffmpeg_version = get_ffmpeg_major_version() - if format == "webm": - if ffmpeg_version == 4: - pytest.skip( - "Codec for webm is not available in the FFmpeg4 installation." - ) - if IS_WINDOWS and ffmpeg_version in (6, 7): - pytest.skip( - "Codec for webm is not available in the FFmpeg6/7 installation on Windows." - ) - asset = TEST_SRC_2_720P - source_frames = self.decode(str(asset.path)).data - frame_rate = 30 + if format == "webm" and ( + ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) + ): + pytest.skip("Codec for webm is not available in this FFmpeg installation.") + + source_frames = self.decode(TEST_SRC_2_720P.path).data # Encode with FFmpeg CLI temp_raw_path = str(tmp_path / "temp_input.raw") @@ -1447,8 +1355,8 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") + frame_rate = 30 crf = 0 - quality_params = ["-crf", str(crf)] # Some codecs (ex. MPEG4) do not support CRF. # Flags not supported by the selected codec will be ignored. 
ffmpeg_cmd = [ @@ -1464,7 +1372,8 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): str(frame_rate), "-i", temp_raw_path, - *quality_params, + "-crf", + str(crf), ffmpeg_encoded_path, ] subprocess.run(ffmpeg_cmd, check=True) @@ -1496,6 +1405,82 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): ff_frame, enc_frame, percentage=percentage, atol=2 ) + def test_to_file_like_custom_file_object(self): + """Test with a custom file-like object that implements write and seek.""" + + class CustomFileObject: + def __init__(self): + self._file = io.BytesIO() + + def write(self, data): + return self._file.write(data) + + def seek(self, offset, whence=0): + return self._file.seek(offset, whence) + + def get_encoded_data(self): + return self._file.getvalue() + + source_frames = self.decode(TEST_SRC_2_720P.path).data + file_like = CustomFileObject() + encode_video_to_file_like( + source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + ) + decoded_samples = self.decode(file_like.get_encoded_data()) + + torch.testing.assert_close( + decoded_samples.data, + source_frames, + atol=2, + rtol=0, + ) + + def test_to_file_like_real_file(self, tmp_path): + """Test to_file_like with a real file opened in binary write mode.""" + source_frames = self.decode(TEST_SRC_2_720P.path).data + file_path = tmp_path / "test_file_like.mp4" + + with open(file_path, "wb") as file_like: + encode_video_to_file_like( + source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + ) + decoded_samples = self.decode(str(file_path)) + + torch.testing.assert_close( + decoded_samples.data, + source_frames, + atol=2, + rtol=0, + ) + + def test_to_file_like_bad_methods(self): + source_frames = self.decode(TEST_SRC_2_720P.path).data + + class NoWriteMethod: + def seek(self, offset, whence=0): + return 0 + + with pytest.raises( + RuntimeError, match="File like object must implement a write method" + ): + encode_video_to_file_like( + source_frames, + frame_rate=30, + format="mp4", + file_like=NoWriteMethod(), + ) + + class NoSeekMethod: + def write(self, data): + return len(data) + + with pytest.raises( + RuntimeError, match="File like object must implement a seek method" + ): + encode_video_to_file_like( + source_frames, frame_rate=30, format="mp4", file_like=NoSeekMethod() + ) + if __name__ == "__main__": pytest.main() diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py new file mode 100644 index 000000000..8d1ba5e53 --- /dev/null +++ b/test/test_transform_ops.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib + +import json +import os +import subprocess + +import pytest + +import torch + +from torchcodec._core import ( + _add_video_stream, + add_video_stream, + create_from_file, + get_frame_at_index, + get_json_metadata, + get_next_frame, +) + +from torchvision.transforms import v2 + +from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda + +torch._dynamo.config.capture_dynamic_output_shape_ops = True + + +class TestVideoDecoderTransformOps: + # We choose arbitrary values for width and height scaling to get better + # test coverage. Some pairs upscale the image while others downscale it. 
+ @pytest.mark.parametrize( + "width_scaling_factor,height_scaling_factor", + ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), + ) + @pytest.mark.parametrize("input_video", [NASA_VIDEO]) + def test_color_conversion_library_with_scaling( + self, input_video, width_scaling_factor, height_scaling_factor + ): + decoder = create_from_file(str(input_video.path)) + add_video_stream(decoder) + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + assert metadata_dict["width"] == input_video.width + assert metadata_dict["height"] == input_video.height + + target_height = int(input_video.height * height_scaling_factor) + target_width = int(input_video.width * width_scaling_factor) + if width_scaling_factor != 1.0: + assert target_width != input_video.width + if height_scaling_factor != 1.0: + assert target_height != input_video.height + + filtergraph_decoder = create_from_file(str(input_video.path)) + _add_video_stream( + filtergraph_decoder, + transform_specs=f"resize, {target_height}, {target_width}", + color_conversion_library="filtergraph", + ) + filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) + + swscale_decoder = create_from_file(str(input_video.path)) + _add_video_stream( + swscale_decoder, + transform_specs=f"resize, {target_height}, {target_width}", + color_conversion_library="swscale", + ) + swscale_frame0, _, _ = get_next_frame(swscale_decoder) + assert_frames_equal(filtergraph_frame0, swscale_frame0) + assert filtergraph_frame0.shape == (3, target_height, target_width) + + @pytest.mark.parametrize( + "width_scaling_factor,height_scaling_factor", + ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)), + ) + @pytest.mark.parametrize("width", [30, 32, 300]) + @pytest.mark.parametrize("height", [128]) + def test_color_conversion_library_with_generated_videos( + self, tmp_path, width, height, width_scaling_factor, height_scaling_factor + ): + # We consider filtergraph to be the reference color conversion library. + # However the video decoder sometimes uses swscale as that is faster. + # The exact color conversion library used is an implementation detail + # of the video decoder and depends on the video's width. + # + # In this test we compare the output of filtergraph (which is the + # reference) with the output of the video decoder (which may use + # swscale if it chooses for certain video widths) to make sure they are + # always the same. + video_path = f"{tmp_path}/frame_numbers_{width}x{height}.mp4" + # We don't specify a particular encoder because the ffmpeg binary could + # be configured with different encoders. For the purposes of this test, + # the actual encoder is irrelevant. 
+ with contextlib.ExitStack() as stack: + ffmpeg_cli = "ffmpeg" + + if os.environ.get("IN_FBCODE_TORCHCODEC") == "1": + import importlib.resources + + ffmpeg_cli = stack.enter_context( + importlib.resources.path(__package__, "ffmpeg") + ) + + command = [ + ffmpeg_cli, + "-y", + "-f", + "lavfi", + "-i", + "color=blue", + "-pix_fmt", + "yuv420p", + "-s", + f"{width}x{height}", + "-frames:v", + "1", + video_path, + ] + subprocess.check_call(command) + + decoder = create_from_file(str(video_path)) + add_video_stream(decoder) + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + assert metadata_dict["width"] == width + assert metadata_dict["height"] == height + + target_height = int(height * height_scaling_factor) + target_width = int(width * width_scaling_factor) + if width_scaling_factor != 1.0: + assert target_width != width + if height_scaling_factor != 1.0: + assert target_height != height + + filtergraph_decoder = create_from_file(str(video_path)) + _add_video_stream( + filtergraph_decoder, + transform_specs=f"resize, {target_height}, {target_width}", + color_conversion_library="filtergraph", + ) + filtergraph_frame0, _, _ = get_next_frame(filtergraph_decoder) + + auto_decoder = create_from_file(str(video_path)) + add_video_stream( + auto_decoder, + transform_specs=f"resize, {target_height}, {target_width}", + ) + auto_frame0, _, _ = get_next_frame(auto_decoder) + assert_frames_equal(filtergraph_frame0, auto_frame0) + + @needs_cuda + def test_scaling_on_cuda_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + with pytest.raises( + RuntimeError, + match="Transforms are only supported for CPU devices.", + ): + add_video_stream(decoder, device="cuda", transform_specs="resize, 100, 100") + + def test_transform_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + with pytest.raises( + RuntimeError, + match="Invalid transform spec", + ): + add_video_stream(decoder, transform_specs=";") + + with pytest.raises( + RuntimeError, + match="Invalid transform name", + ): + add_video_stream(decoder, transform_specs="invalid, 1, 2") + + def test_resize_transform_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + with pytest.raises( + RuntimeError, + match="must have 3 elements", + ): + add_video_stream(decoder, transform_specs="resize, 100, 100, 100") + + with pytest.raises( + RuntimeError, + match="must be a positive integer", + ): + add_video_stream(decoder, transform_specs="resize, -10, 100") + + with pytest.raises( + RuntimeError, + match="must be a positive integer", + ): + add_video_stream(decoder, transform_specs="resize, 100, 0") + + with pytest.raises( + RuntimeError, + match="cannot be converted to an int", + ): + add_video_stream(decoder, transform_specs="resize, blah, 100") + + with pytest.raises( + RuntimeError, + match="out of range", + ): + add_video_stream(decoder, transform_specs="resize, 100, 1000000000000") + + def test_crop_transform(self): + # Note that filtergraph accepts dimensions as (w, h) and we accept them as (h, w). 
+ width = 300 + height = 200 + x = 50 + y = 35 + crop_spec = f"crop, {height}, {width}, {x}, {y}" + crop_filtergraph = f"crop={width}:{height}:{x}:{y}:exact=1" + expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width) + + decoder_crop = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder_crop, transform_specs=crop_spec) + + decoder_full = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder_full) + + for frame_index in [0, 15, 200, 389]: + frame, *_ = get_frame_at_index(decoder_crop, frame_index=frame_index) + frame_ref = NASA_VIDEO.get_frame_data_by_index( + frame_index, filters=crop_filtergraph + ) + + frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index) + frame_tv = v2.functional.crop( + frame_full, top=y, left=x, height=height, width=width + ) + + assert frame.shape == expected_shape + assert frame_ref.shape == expected_shape + assert frame_tv.shape == expected_shape + + assert_frames_equal(frame, frame_tv) + assert_frames_equal(frame, frame_ref) + + def test_crop_transform_fails(self): + + with pytest.raises( + RuntimeError, + match="must have 5 elements", + ): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, transform_specs="crop, 100, 100") + + with pytest.raises( + RuntimeError, + match="must be a positive integer", + ): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, transform_specs="crop, -10, 100, 100, 100") + + with pytest.raises( + RuntimeError, + match="cannot be converted to an int", + ): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, transform_specs="crop, 100, 100, blah, 100") + + with pytest.raises( + RuntimeError, + match="x position out of bounds", + ): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100") + + with pytest.raises( + RuntimeError, + match="y position out of bounds", + ): + decoder = create_from_file(str(NASA_VIDEO.path)) + add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100") diff --git a/test/utils.py b/test/utils.py index 7c91f307c..cbd6a5bf4 100644 --- a/test/utils.py +++ b/test/utils.py @@ -14,6 +14,7 @@ import torch from torchcodec._core import get_ffmpeg_library_versions +from torchcodec.decoders import set_cuda_backend, VideoDecoder from torchcodec.decoders._video_decoder import _read_custom_frame_mappings IS_WINDOWS = sys.platform in ("win32", "cygwin") @@ -26,40 +27,76 @@ def needs_cuda(test_item): return pytest.mark.needs_cuda(test_item) +# This is a special device string that we use to test the "beta" CUDA backend. +# It only exists here, in this test utils file. Public and core APIs have no +# idea that this is how we're testing them. That is, it's not a supported +# `device` parameter for the VideoDecoder or for the _core APIs. +# Tests using all_supported_devices() will get this device string, and the test +# needs to clean it up by calling either make_video_decoder for VideoDecoder, or +# unsplit_device_str for core APIs. +_CUDA_BETA_DEVICE_STR = "cuda:beta" + + def all_supported_devices(): return ( "cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda), - pytest.param("cuda:0:beta", marks=pytest.mark.needs_cuda), + pytest.param(_CUDA_BETA_DEVICE_STR, marks=pytest.mark.needs_cuda), ) def unsplit_device_str(device_str: str) -> str: # helper meant to be used as # device, device_variant = unsplit_device_str(device) - # when `device` comes from all_supported_devices() and may be "cuda:0:beta". 
+ # when `device` comes from all_supported_devices() and may be _CUDA_BETA_DEVICE_STR. # It is used: - # - before calling `.to(device)` where device can't be "cuda:0:beta" + # - before calling `.to(device)` where device can't be _CUDA_BETA_DEVICE_STR. # - before calling add_video_stream(device=device, device_variant=device_variant) - # - # TODONVDEC P2: Find a less clunky way to test the BETA CUDA interface. It - # will ultimately depend on how we want to publicly expose it. - if device_str == "cuda:0:beta": + if device_str == _CUDA_BETA_DEVICE_STR: return "cuda", "beta" else: - return device_str, "default" + return device_str, "ffmpeg" -def get_ffmpeg_major_version(): +def make_video_decoder(*args, **kwargs) -> tuple[VideoDecoder, str]: + # Helper to create a VideoDecoder with the right cuda backend if needed. + # kwargs is expected to have a "device" key which comes from + # all_supported_devices(), and can be _CUDA_BETA_DEVICE_STR. + device = kwargs.pop("device", "cpu") + if device == _CUDA_BETA_DEVICE_STR: + clean_device, backend = "cuda", "beta" + else: + clean_device, backend = device, "ffmpeg" + + # set_cuda_backend is a no-op if the device is "cpu", so we can use it + # unconditionally. + with set_cuda_backend(backend): + dec = VideoDecoder(*args, **kwargs, device=clean_device) + + return dec, clean_device + + +def _get_ffmpeg_version_string(): ffmpeg_version = get_ffmpeg_library_versions()["ffmpeg_version"] # When building FFmpeg from source there can be a `n` prefix in the version # string. This is quite brittle as we're using av_version_info(), which has # no stable format. See https://github.com/pytorch/torchcodec/issues/100 if ffmpeg_version.startswith("n"): ffmpeg_version = ffmpeg_version.removeprefix("n") + + return ffmpeg_version + + +def get_ffmpeg_major_version(): + ffmpeg_version = _get_ffmpeg_version_string() return int(ffmpeg_version.split(".")[0]) +def get_ffmpeg_minor_version(): + ffmpeg_version = _get_ffmpeg_version_string() + return int(ffmpeg_version.split(".")[1]) + + def cuda_version_used_for_building_torch() -> Optional[tuple[int, int]]: # Return the CUDA version that was used to build PyTorch. That's not always # the same as the CUDA version that is currently installed on the running @@ -130,6 +167,12 @@ def assert_tensor_close_on_at_least( ) +# We embed filtergraph expressions in filenames, but they contain characters that +# some filesystems don't like. We turn all special characters into underscores. 
+def sanitize_filtergraph_expression(expression: str) -> str: + return "".join(c if c.isalnum() else "_" for c in expression) + + def in_fbcode() -> bool: return os.environ.get("IN_FBCODE_TORCHCODEC") == "1" @@ -328,16 +371,32 @@ def empty_duration_seconds(self) -> torch.Tensor: class TestVideo(TestContainerFile): """Base class for the *video* streams of a video container""" + def get_base_path_by_index( + self, idx: int, *, stream_index: int, filters: Optional[str] = None + ) -> pathlib.Path: + stream_and_frame = f"stream{stream_index}.frame{idx:06d}" + if filters is not None: + full_name = f"{self.filename}.{sanitize_filtergraph_expression(filters)}.{stream_and_frame}" + else: + full_name = f"{self.filename}.{stream_and_frame}" + + return _get_file_path(full_name) + def get_frame_data_by_index( - self, idx: int, *, stream_index: Optional[int] = None + self, + idx: int, + *, + stream_index: Optional[int] = None, + filters: Optional[str] = None, ) -> torch.Tensor: if stream_index is None: stream_index = self.default_stream_index - file_path = _get_file_path( - f"{self.filename}.stream{stream_index}.frame{idx:06d}.pt" + base_path = self.get_base_path_by_index( + idx, stream_index=stream_index, filters=filters ) - return torch.load(file_path, weights_only=True).permute(2, 0, 1) + tensor_file_path = f"{base_path}.pt" + return torch.load(tensor_file_path, weights_only=True).permute(2, 0, 1) def get_frame_data_by_range( self, @@ -434,6 +493,114 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: ) +H265_VIDEO = TestVideo( + filename="h265_video.mp4", + default_stream_index=0, + # This metadata is extracted manually. + # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/h265_video.mp4 > out.json + stream_infos={ + 0: TestVideoStreamInfo(width=128, height=128, num_color_channels=3), + }, + frames={ + 0: { + 6: TestFrameInfo(pts_seconds=0.6, duration_seconds=0.1), + }, + }, +) + +AV1_VIDEO = TestVideo( + filename="av1_video.mkv", + default_stream_index=0, + # This metadata is extracted manually. 
+ # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json + stream_infos={ + 0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3), + }, + frames={ + 0: { + 10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000), + }, + }, +) + + +# This is a BT.709 full range video, generated with: +# ffmpeg -f lavfi -i testsrc2=duration=1:size=1920x720:rate=30 \ +# -c:v libx264 -pix_fmt yuv420p -color_primaries bt709 -color_trc bt709 \ +# -colorspace bt709 -color_range pc bt709_full_range.mp4 +# +# We can confirm the color space and color range with: +# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt709_full_range.mp4 +# color_range=pc +# color_space=bt709 +# color_transfer=bt709 +# color_primaries=bt709 +BT709_FULL_RANGE = TestVideo( + filename="bt709_full_range.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + +# ffmpeg -f lavfi -i testsrc2=duration=2:size=1280x720:rate=30 -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p -b:v 2500k -r 30 -movflags +faststart output_720p_2s.mp4 +TEST_SRC_2_720P = TestVideo( + filename="testsrc2.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) +# ffmpeg -f lavfi -i testsrc2=duration=10:size=1280x720:rate=30 -c:v libx265 -crf 23 -preset medium output.mp4 +TEST_SRC_2_720P_H265 = TestVideo( + filename="testsrc2_h265.mp4", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + +# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx-vp9 -b:v 1M output_vp9.webm +TEST_SRC_2_720P_VP9 = TestVideo( + filename="testsrc2_vp9.webm", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + +# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx -b:v 1M output_vp8.webm +TEST_SRC_2_720P_VP8 = TestVideo( + filename="testsrc2_vp8.webm", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + +# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v mpeg4 -q:v 5 output_mpeg4.avi +TEST_SRC_2_720P_MPEG4 = TestVideo( + filename="testsrc2_mpeg4.avi", + default_stream_index=0, + stream_infos={ + 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), + }, + frames={0: {}}, # Not needed for now +) + + +def supports_approximate_mode(asset: TestVideo) -> bool: + # Those are missing the `duration` field so they fail in approximate mode (on all devices). + # TODO: we should address this, see + # https://github.com/meta-pytorch/torchcodec/issues/945 + return asset not in (AV1_VIDEO, TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8) + + @dataclass class TestAudio(TestContainerFile): """Base class for the *audio* streams of a container (potentially a video), @@ -647,110 +814,3 @@ def sample_format(self) -> str: ) }, ) - -H265_VIDEO = TestVideo( - filename="h265_video.mp4", - default_stream_index=0, - # This metadata is extracted manually. 
- # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/h265_video.mp4 > out.json - stream_infos={ - 0: TestVideoStreamInfo(width=128, height=128, num_color_channels=3), - }, - frames={ - 0: { - 6: TestFrameInfo(pts_seconds=0.6, duration_seconds=0.1), - }, - }, -) - -AV1_VIDEO = TestVideo( - filename="av1_video.mkv", - default_stream_index=0, - # This metadata is extracted manually. - # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json - stream_infos={ - 0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3), - }, - frames={ - 0: { - 10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000), - }, - }, -) - - -# This is a BT.709 full range video, generated with: -# ffmpeg -f lavfi -i testsrc2=duration=1:size=1920x720:rate=30 \ -# -c:v libx264 -pix_fmt yuv420p -color_primaries bt709 -color_trc bt709 \ -# -colorspace bt709 -color_range pc bt709_full_range.mp4 -# -# We can confirm the color space and color range with: -# ffprobe -v quiet -select_streams v:0 -show_entries stream=color_space,color_transfer,color_primaries,color_range -of default=noprint_wrappers=1 test/resources/bt709_full_range.mp4 -# color_range=pc -# color_space=bt709 -# color_transfer=bt709 -# color_primaries=bt709 -BT709_FULL_RANGE = TestVideo( - filename="bt709_full_range.mp4", - default_stream_index=0, - stream_infos={ - 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), - }, - frames={0: {}}, # Not needed for now -) - -# ffmpeg -f lavfi -i testsrc2=duration=2:size=1280x720:rate=30 -c:v libx264 -profile:v baseline -level 3.1 -pix_fmt yuv420p -b:v 2500k -r 30 -movflags +faststart output_720p_2s.mp4 -TEST_SRC_2_720P = TestVideo( - filename="testsrc2.mp4", - default_stream_index=0, - stream_infos={ - 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), - }, - frames={0: {}}, # Not needed for now -) -# ffmpeg -f lavfi -i testsrc2=duration=10:size=1280x720:rate=30 -c:v libx265 -crf 23 -preset medium output.mp4 -TEST_SRC_2_720P_H265 = TestVideo( - filename="testsrc2_h265.mp4", - default_stream_index=0, - stream_infos={ - 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), - }, - frames={0: {}}, # Not needed for now -) - -# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx-vp9 -b:v 1M output_vp9.webm -TEST_SRC_2_720P_VP9 = TestVideo( - filename="testsrc2_vp9.webm", - default_stream_index=0, - stream_infos={ - 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), - }, - frames={0: {}}, # Not needed for now -) - -# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v libvpx -b:v 1M output_vp8.webm -TEST_SRC_2_720P_VP8 = TestVideo( - filename="testsrc2_vp8.webm", - default_stream_index=0, - stream_infos={ - 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), - }, - frames={0: {}}, # Not needed for now -) - -# ffmpeg -f lavfi -i testsrc2=size=1280x720:rate=30:duration=1 -c:v mpeg4 -q:v 5 output_mpeg4.avi -TEST_SRC_2_720P_MPEG4 = TestVideo( - filename="testsrc2_mpeg4.avi", - default_stream_index=0, - stream_infos={ - 0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3), - }, - frames={0: {}}, # Not needed for now -) - - -def supports_approximate_mode(asset: TestVideo) -> bool: - # TODONVDEC P2: open an issue about his. That's actually not related to - # NVDEC at all, those don't support approximate mode because they don't set - # a duration. CPU decoder fails too! 
- return asset not in (AV1_VIDEO, TEST_SRC_2_720P_VP9, TEST_SRC_2_720P_VP8) diff --git a/version.txt b/version.txt index a3df0a695..c18d72be3 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.8.0 +0.8.1 \ No newline at end of file
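Aside: a minimal sketch of how the new test/utils.py helpers are meant to be consumed by a device-parametrized test (the test name here is hypothetical; the helpers and NASA_VIDEO come from the diff above):

    import pytest
    from .utils import all_supported_devices, make_video_decoder, NASA_VIDEO

    @pytest.mark.parametrize("device", all_supported_devices())
    def test_first_frame(device):  # hypothetical example, not part of the patch
        # make_video_decoder strips the special "cuda:beta" string, selects the
        # backend via set_cuda_backend, and returns a clean device string that
        # is safe to use with .to(device) and tensor assertions.
        decoder, device = make_video_decoder(NASA_VIDEO.path, device=device)
        frame = decoder.get_frame_at(0)
        assert frame.data.device.type == ("cuda" if device.startswith("cuda") else "cpu")

    # Relatedly, sanitize_filtergraph_expression maps every non-alphanumeric
    # character to "_", so reference files for filtered frames get stable names:
    # "crop=300:200:50:35:exact=1" -> "crop_300_200_50_35_exact_1"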