diff --git a/.github/workflows/test_linux_cuda.yaml b/.github/workflows/test_linux_cuda.yaml deleted file mode 100644 index 527448bec..000000000 --- a/.github/workflows/test_linux_cuda.yaml +++ /dev/null @@ -1,64 +0,0 @@ - -name: Unit-tests on Linux GPU - -on: - pull_request: - push: - branches: - - nightly - - main - - release/* - workflow_dispatch: - -jobs: - tests: - strategy: - matrix: - python_version: ["3.9"] - # TODO: Add more cuda versions. - cuda_arch_version: ["12.4"] - fail-fast: false - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - with: - runner: linux.g5.4xlarge.nvidia.gpu - repository: pytorch/torchcodec - gpu-arch-type: cuda - gpu-arch-version: ${{ matrix.cuda_arch_version }} - timeout: 120 - - script: | - nvidia-smi - conda create --yes --name test - conda activate test - conda install --quiet --yes pip cmake pkg-config nasm - - pip install --quiet --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu124 - conda install --quiet --yes nvidia::libnpp - - # Build and install FFMPEG from source with CUDA enabled. - # The one on conda doesn't have CUDA enabled. - # Sub-step: install nvidia headers. Reference this link for details: - # https://docs.nvidia.com/video-technologies/video-codec-sdk/12.1/ffmpeg-with-nvidia-gpu/index.html - git clone --quiet https://git.videolan.org/git/ffmpeg/nv-codec-headers.git - - pushd nv-codec-headers - make --silent PREFIX=$CONDA_PREFIX -j install - popd - - # Now build FFMPEG from source with CUDA enabled. - git clone --quiet https://git.ffmpeg.org/ffmpeg.git ffmpeg/ - pushd ffmpeg - git checkout origin/release/6.1 - which pkg-config - pkg-config --list-all - ./configure --prefix=$CONDA_PREFIX --enable-nonfree --enable-cuda-nvcc --disable-static --enable-shared --optflags=-fno-omit-frame-pointer --disable-stripping --enable-cuvid - make --silent -j install - popd - - CMAKE_BUILD_PARALLEL_LEVEL=8 CXXFLAGS="" LDFLAGS="-Wl,--allow-shlib-undefined -Wl,-rpath,$CONDA_PREFIX/lib -Wl,-rpath-link,$CONDA_PREFIX/lib -L$CONDA_PREFIX/lib" CMAKE_BUILD_TYPE=Release ENABLE_CUDA=1 ENABLE_NVTX=1 pip install -e ".[dev]" --no-build-isolation -vv --debug - - # We skip certain tests because they are not relevant to GPU decoding and they always fail with - # a custom FFMPEG build. - pytest -k "not (test_get_metadata or get_ffmpeg_version)" - python benchmarks/decoders/gpu_benchmark.py - conda deactivate diff --git a/CMakeLists.txt b/CMakeLists.txt index fc8d17c26..4dfeb0609 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,9 +1,6 @@ cmake_minimum_required(VERSION 3.18) project(TorchCodec) -option(ENABLE_CUDA "Enable CUDA decoding using NVDEC" OFF) -option(ENABLE_NVTX "Enable NVTX annotations for profiling" OFF) - add_subdirectory(src/torchcodec/decoders/_core) diff --git a/README.md b/README.md index 4640ad3e9..183120598 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,3 @@ guide](CONTRIBUTING.md) for more details. ## License TorchCodec is released under the [BSD 3 license](./LICENSE). - - -If you are building with ENABLE_CUDA and/or ENABLE_NVTX please review -[Nvidia licenses](https://docs.nvidia.com/cuda/eula/index.html). diff --git a/benchmarks/decoders/BenchmarkDecodersMain.cpp b/benchmarks/decoders/BenchmarkDecodersMain.cpp index c1b15bafb..a9762a0b5 100644 --- a/benchmarks/decoders/BenchmarkDecodersMain.cpp +++ b/benchmarks/decoders/BenchmarkDecodersMain.cpp @@ -63,7 +63,7 @@ void runNDecodeIterations( decoder->addVideoStreamDecoder(-1); for (double pts : ptsList) { decoder->setCursorPtsInSeconds(pts); - torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame; + torch::Tensor tensor = decoder->getNextDecodedOutput().frame; } if (i + 1 == warmupIterations) { start = std::chrono::high_resolution_clock::now(); @@ -95,7 +95,7 @@ void runNdecodeIterationsGrabbingConsecutiveFrames( VideoDecoder::createFromFilePath(videoPath); decoder->addVideoStreamDecoder(-1); for (int j = 0; j < consecutiveFrameCount; ++j) { - torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame; + torch::Tensor tensor = decoder->getNextDecodedOutput().frame; } if (i + 1 == warmupIterations) { start = std::chrono::high_resolution_clock::now(); @@ -145,8 +145,7 @@ void runNDecodeIterationsWithCustomOps( /*height=*/std::nullopt, /*thread_count=*/std::nullopt, /*dimension_order=*/std::nullopt, - /*stream_index=*/std::nullopt, - /*device_string=*/std::nullopt); + /*stream_index=*/std::nullopt); for (double pts : ptsList) { seekFrameOp.call(decoderTensor, pts); diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py deleted file mode 100644 index 4fd7c8ad6..000000000 --- a/benchmarks/decoders/gpu_benchmark.py +++ /dev/null @@ -1,144 +0,0 @@ -import argparse -import os -import time - -import torch.utils.benchmark as benchmark - -import torchcodec -import torchvision.transforms.v2.functional as F - -RESIZED_WIDTH = 256 -RESIZED_HEIGHT = 256 - - -def transfer_and_resize_frame(frame, resize_device_string): - # This should be a no-op if the frame is already on the target device. - frame = frame.to(resize_device_string) - frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH)) - return frame - - -def decode_full_video(video_path, decode_device_string, resize_device_string): - # We use the core API instead of SimpleVideoDecoder because the core API - # allows us to natively resize as part of the decode step. - print(f"{decode_device_string=} {resize_device_string=}") - decoder = torchcodec.decoders._core.create_from_file(video_path) - num_threads = None - if "cuda" in decode_device_string: - num_threads = 1 - width = None - height = None - if "native" in resize_device_string: - width = RESIZED_WIDTH - height = RESIZED_HEIGHT - torchcodec.decoders._core.add_video_stream( - decoder, - stream_index=-1, - device_string=decode_device_string, - num_threads=num_threads, - width=width, - height=height, - ) - - start_time = time.time() - frame_count = 0 - while True: - try: - frame, *_ = torchcodec.decoders._core.get_next_frame(decoder) - if resize_device_string != "none" and "native" not in resize_device_string: - frame = transfer_and_resize_frame(frame, resize_device_string) - - frame_count += 1 - except Exception as e: - print("EXCEPTION", e) - break - - end_time = time.time() - elapsed = end_time - start_time - fps = frame_count / (end_time - start_time) - print( - f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}" - ) - return frame_count, end_time - start_time - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--devices", - default="cuda:0,cpu", - type=str, - help="Comma-separated devices to test decoding on.", - ) - parser.add_argument( - "--resize_devices", - default="cuda:0,cpu,native,none", - type=str, - help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.", - ) - parser.add_argument( - "--video", - type=str, - default=os.path.dirname(__file__) + "/../../test/resources/nasa_13013.mp4", - ) - parser.add_argument( - "--use_torch_benchmark", - action=argparse.BooleanOptionalAction, - default=True, - help=( - "Use pytorch benchmark to measure decode time with warmup and " - "autorange. Without this we just run one iteration without warmup " - "to measure the cold start time." - ), - ) - args = parser.parse_args() - video_path = args.video - - if not args.use_torch_benchmark: - for device in args.devices.split(","): - print("Testing on", device) - decode_full_video(video_path, device) - return - - resize_devices = args.resize_devices.split(",") - resize_devices = [d for d in resize_devices if d != ""] - if len(resize_devices) == 0: - resize_devices.append("none") - - label = "Decode+Resize Time" - - results = [] - for decode_device_string in args.devices.split(","): - for resize_device_string in resize_devices: - decode_label = decode_device_string - if "cuda" in decode_label: - # Shorten "cuda:0" to "cuda" - decode_label = "cuda" - resize_label = resize_device_string - if "cuda" in resize_device_string: - # Shorten "cuda:0" to "cuda" - resize_label = "cuda" - print("decode_device", decode_device_string) - print("resize_device", resize_device_string) - t = benchmark.Timer( - stmt="decode_full_video(video_path, decode_device_string, resize_device_string)", - globals={ - "decode_device_string": decode_device_string, - "video_path": video_path, - "decode_full_video": decode_full_video, - "resize_device_string": resize_device_string, - }, - label=label, - description=f"video={os.path.basename(video_path)}", - sub_label=f"D={decode_label} R={resize_label}", - ).blocked_autorange() - results.append(t) - compare = benchmark.Compare(results) - compare.print() - print("Key: D=Decode, R=Resize") - print("Native resize is done as part of the decode step") - print("none resize means there is no resize step -- native or otherwise") - - -if __name__ == "__main__": - main() diff --git a/examples/basic_example.py b/examples/basic_example.py index 693c8c47d..abbc1b469 100644 --- a/examples/basic_example.py +++ b/examples/basic_example.py @@ -171,14 +171,3 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): # %% plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds") plot(first_two_seconds.data, "Frames displayed during [0, 2) seconds") - -# %% -# Using a CUDA GPU to accelerate decoding -# --------------------------------------- -# -# If you have a CUDA GPU that has NVDEC, you can decode on the GPU. -if torch.cuda.is_available(): - cuda_decoder = SimpleVideoDecoder(raw_video_bytes, device="cuda:0") - cuda_frame = cuda_decoder.get_frame_displayed_at(seconds=2) - print(cuda_frame.data.device) # should be cuda:0 - plot(cuda_frame.data.to("cpu"), "Frame displayed at 2 seconds on CUDA") diff --git a/setup.py b/setup.py index 75310ceab..fb5d0278c 100644 --- a/setup.py +++ b/setup.py @@ -112,16 +112,12 @@ def _build_all_extensions_with_cmake(self): torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch" cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release") python_version = sys.version_info - enable_cuda = os.environ.get("ENABLE_CUDA", "") - enable_nvtx = os.environ.get("ENABLE_NVTX", "") cmake_args = [ f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}", f"-DTorch_DIR={torch_dir}", "-DCMAKE_VERBOSE_MAKEFILE=ON", f"-DCMAKE_BUILD_TYPE={cmake_build_type}", f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", - f"-DENABLE_CUDA={enable_cuda}", - f"-DENABLE_NVTX={enable_nvtx}", ] Path(self.build_temp).mkdir(parents=True, exist_ok=True) diff --git a/src/torchcodec/_samplers/video_clip_sampler.py b/src/torchcodec/_samplers/video_clip_sampler.py index 384dd135e..1440edaeb 100644 --- a/src/torchcodec/_samplers/video_clip_sampler.py +++ b/src/torchcodec/_samplers/video_clip_sampler.py @@ -31,7 +31,6 @@ class VideoTooShortException(Exception): @dataclass class DecoderArgs: num_threads: int = 0 - device: torch.device = torch.device("cpu") @dataclass @@ -164,7 +163,6 @@ def forward(self, video_data: Tensor) -> Union[List[Any]]: width=target_width, height=target_height, num_threads=self.decoder_args.num_threads, - device_string=str(self.decoder_args.device), ) clips: List[Any] = [] diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index ed8e8ef36..0fe8f2433 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -4,28 +4,6 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Torch REQUIRED) - -if(ENABLE_CUDA) - find_package(CUDA REQUIRED) - - if(ENABLE_NVTX) - # We only need CPM for NVTX: - # https://github.com/NVIDIA/NVTX#cmake - file( - DOWNLOAD - https://github.com/cpm-cmake/CPM.cmake/releases/download/v0.38.3/CPM.cmake - ${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake - EXPECTED_HASH SHA256=cc155ce02e7945e7b8967ddfaff0b050e958a723ef7aad3766d368940cb15494 - ) - include(${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake) - CPMAddPackage( - NAME NVTX - GITHUB_REPOSITORY NVIDIA/NVTX - GIT_TAG v3.1.0-c-cpp - GIT_SHALLOW TRUE) - endif() -endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) @@ -41,12 +19,6 @@ function(make_torchcodec_library library_name ffmpeg_target) ) add_library(${library_name} SHARED ${sources}) set_property(TARGET ${library_name} PROPERTY CXX_STANDARD 17) - if(ENABLE_CUDA) - target_compile_definitions(${library_name} PRIVATE ENABLE_CUDA=1) - endif() - if(ENABLE_NVTX) - target_compile_definitions(${library_name} PRIVATE ENABLE_NVTX=1) - endif() target_include_directories( ${library_name} @@ -56,17 +28,12 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_INCLUDE_DIRS} ) - set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES}) - if(ENABLE_CUDA) - list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) - endif() - if(ENABLE_NVTX) - list(APPEND NEEDED_LIBRARIES nvtx3-cpp) - endif() target_link_libraries( ${library_name} PUBLIC - ${NEEDED_LIBRARIES} + ${ffmpeg_target} + ${TORCH_LIBRARIES} + ${Python3_LIBRARIES} ) # We already set the library_name to be libtorchcodecN, so we don't want diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index b5ad4e039..7bb61cefc 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -57,8 +57,6 @@ using UniqueAVFilterInOut = std::unique_ptr< Deleterp>; using UniqueAVIOContext = std:: unique_ptr>; -using UniqueAVBufferRef = - std::unique_ptr>; // av_find_best_stream is not const-correct before commit: // https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index a5c0fddfb..c4bb1520f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -5,22 +5,11 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/decoders/_core/VideoDecoder.h" -#include -#include #include #include -#include #include #include - -#ifdef ENABLE_CUDA -#include -#include -#include -#ifdef ENABLE_NVTX -#include -#endif -#endif +#include "torch/types.h" extern "C" { #include @@ -29,9 +18,6 @@ extern "C" { #include #include #include -#ifdef ENABLE_CUDA -#include -#endif } namespace facebook::torchcodec { @@ -107,87 +93,6 @@ std::vector splitStringWithDelimiters( return result; } -#ifdef ENABLE_CUDA - -AVBufferRef* getCudaContext() { - enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); - TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - int err = 0; - AVBufferRef* hw_device_ctx; - err = av_hwdevice_ctx_create( - &hw_device_ctx, - type, - nullptr, - nullptr, - // Introduced in 58.26.100: - // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 -#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) - AV_CUDA_USE_CURRENT_CONTEXT -#else - 0 -#endif - ); - if (err < 0) { - TORCH_CHECK( - false, - "Failed to create specified HW device", - getFFMPEGErrorStringFromErrorCode(err)); - } - return hw_device_ctx; -} - -torch::Tensor allocateDeviceTensor( - at::IntArrayRef shape, - torch::Device device, - const torch::Dtype dtype = torch::kUInt8) { - return torch::empty( - shape, - torch::TensorOptions() - .dtype(dtype) - .layout(torch::kStrided) - .device(device)); -} - -torch::Tensor convertFrameToTensorUsingCUDA( - const AVCodecContext* codecContext, - const VideoDecoder::VideoStreamDecoderOptions& options, - const AVFrame* src) { - TORCH_CHECK( - src->format == AV_PIX_FMT_CUDA, - "Expected format to be AV_PIX_FMT_CUDA, got " + - std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); - int width = options.width.value_or(codecContext->width); - int height = options.height.value_or(codecContext->height); - NppStatus status; - NppiSize oSizeROI; - oSizeROI.width = width; - oSizeROI.height = height; - Npp8u* input[2]; - input[0] = (Npp8u*)src->data[0]; - input[1] = (Npp8u*)src->data[1]; - torch::Tensor dst = allocateDeviceTensor({height, width, 3}, options.device); - auto start = std::chrono::high_resolution_clock::now(); - status = nppiNV12ToRGB_8u_P2C3R( - input, - src->linesize[0], - static_cast(dst.data_ptr()), - dst.stride(0), - oSizeROI); - TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration duration = end - start; - VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width - << " took: " << duration.count() << "us" << std::endl; - if (options.dimensionOrder == "NCHW") { - // The docs guaranty this to return a view: - // https://pytorch.org/docs/stable/generated/torch.permute.html - dst = dst.permute({2, 0, 1}); - } - return dst; -} - -#endif - } // namespace VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions( @@ -433,13 +338,13 @@ void VideoDecoder::initializeFilterGraphForStream( inputs.reset(inputsTmp); if (ffmpegStatus < 0) { throw std::runtime_error( - "Failed to parse filter description: " + std::string(description) + - "; " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); + "Failed to parse filter description: " + + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr); if (ffmpegStatus < 0) { throw std::runtime_error( - "Failed to configure filter graph: " + std::string(description) + "; " + + "Failed to configure filter graph: " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } } @@ -488,37 +393,15 @@ void VideoDecoder::addVideoStreamDecoder( int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); TORCH_CHECK_EQ(retVal, AVSUCCESS); - - if (options.device.type() == torch::DeviceType::CUDA) { -#ifdef ENABLE_CUDA - // We create a small tensor using pytorch to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::zeros( - {1}, - torch::TensorOptions().dtype(torch::kUInt8).device(options.device)); - codecContext->hw_device_ctx = av_buffer_ref(getCudaContext()); - - TORCH_INTERNAL_ASSERT( - codecContext->hw_device_ctx, - "Failed to create/reference the CUDA HW device context for index=" + - std::to_string(options.device.index()) + "."); -#else - throw std::runtime_error( - "CUDA support is not enabled in this build of TorchCodec."); -#endif - } - retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); if (retVal < AVSUCCESS) { throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal)); } - codecContext->time_base = streamInfo.stream->time_base; activeStreamIndices_.insert(streamNumber); updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); streamInfo.options = options; - if (options.device.is_cpu()) { - initializeFilterGraphForStream(streamNumber, options); - } + initializeFilterGraphForStream(streamNumber, options); } void VideoDecoder::updateMetadataWithCodecContext( @@ -765,13 +648,10 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter( std::function filterFunction) { -#ifdef ENABLE_NVTX - nvtx3::scoped_range loop{"decodeOneFrame"}; -#endif if (activeStreamIndices_.size() == 0) { throw std::runtime_error("No active streams configured."); } - VLOG(9) << "Starting getNextDecodedOutputNoDemux()"; + VLOG(9) << "Starting getNextDecodedOutput()"; resetDecodeStats(); if (maybeDesiredPts_.has_value()) { VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_; @@ -855,13 +735,8 @@ VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter( // This packet is not for any of the active streams. continue; } - { -#ifdef ENABLE_NVTX - nvtx3::scoped_range loop{"avcodec_send_packet"}; -#endif - ffmpegStatus = avcodec_send_packet( - streams_[packet->stream_index].codecContext.get(), packet.get()); - } + ffmpegStatus = avcodec_send_packet( + streams_[packet->stream_index].codecContext.get(), packet.get()); decodeStats_.numPacketsSentToDecoder++; if (ffmpegStatus < AVSUCCESS) { throw std::runtime_error( @@ -898,9 +773,8 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( UniqueAVFrame frame) { // Convert the frame to tensor. DecodedOutput output; - auto& streamInfo = streams_[streamIndex]; output.streamIndex = streamIndex; - output.streamType = streamInfo.stream->codecpar->codec_type; + output.streamType = streams_[streamIndex].stream->codecpar->codec_type; output.pts = frame->pts; output.ptsSeconds = ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base); @@ -908,22 +782,8 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); if (output.streamType == AVMEDIA_TYPE_VIDEO) { - if (streamInfo.options.device.is_cpu()) { - output.frame = - convertFrameToTensorUsingFilterGraph(streamIndex, frame.get()); - } else if (streamInfo.options.device.is_cuda()) { -#ifdef ENABLE_CUDA - { -#ifdef ENABLE_NVTX - nvtx3::scoped_range loop{"convertFrameUsingCuda"}; -#endif - output.frame = convertFrameToTensorUsingCUDA( - streamInfo.codecContext.get(), streamInfo.options, frame.get()); - } -#else - throw std::runtime_error("CUDA is not enabled in this build."); -#endif // ENABLE_CUDA - } + output.frame = + convertFrameToTensorUsingFilterGraph(streamIndex, frame.get()); } else if (output.streamType == AVMEDIA_TYPE_AUDIO) { // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement // audio decoding. @@ -932,7 +792,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( return output; } -VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( +VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp( double seconds) { for (auto& [streamIndex, stream] : streams_) { double frameStartTime = ptsToSeconds(stream.currentPts, stream.timeBase); @@ -1007,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( int64_t pts = stream.allFrames[frameIndex].pts; setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); - return getNextDecodedOutputNoDemux(); + return getNextDecodedOutput(); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes( @@ -1160,7 +1020,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange( return output; } -VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() { +VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutput() { return getDecodedOutputWithFilter( [this](int frameStreamIndex, AVFrame* frame) { StreamInfo& activeStream = streams_[frameStreamIndex]; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a84d7da56..e54731554 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -12,7 +12,6 @@ #include #include -#include "c10/core/Device.h" #include "src/torchcodec/decoders/_core/FFMPEGCommon.h" namespace facebook::torchcodec { @@ -140,8 +139,6 @@ class VideoDecoder { // is the same as the original video. std::optional width; std::optional height; - // Set the device to torch::kGPU for GPU decoding. - torch::Device device = torch::kCPU; }; struct AudioStreamDecoderOptions {}; void addVideoStreamDecoder( @@ -153,8 +150,8 @@ class VideoDecoder { // ---- SINGLE FRAME SEEK AND DECODING API ---- // Places the cursor at the first frame on or after the position in seconds. - // Calling getNextDecodedOutputNoDemux() will return the first frame at or - // after this position. + // Calling getNextFrameAsTensor() will return the first frame at or after this + // position. void setCursorPtsInSeconds(double seconds); struct DecodedOutput { // The actual decoded output as a Tensor. @@ -180,14 +177,13 @@ class VideoDecoder { }; // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. - DecodedOutput getNextDecodedOutputNoDemux(); - // Decodes the first frame in any added stream that is visible at a given - // timestamp. Frames in the video have a presentation timestamp and a - // duration. For example, if a frame has presentation timestamp of 5.0s and a - // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0). - // i.e. it will be returned when this function is called with seconds=5.0 or - // seconds=5.999, etc. - DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds); + DecodedOutput getNextDecodedOutput(); + // Decodes the frame that is visible at a given timestamp. Frames in the video + // have a presentation timestamp and a duration. For example, if a frame has + // presentation timestamp of 5.0s and a duration of 1.0s, it will be visible + // in the timestamp range [5.0, 6.0). i.e. it will be returned when this + // function is called with seconds=5.0 or seconds=5.999, etc. + DecodedOutput getFrameDisplayedAtTimestamp(double seconds); DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex); struct BatchDecodedOutput { torch::Tensor frames; @@ -281,8 +277,6 @@ class VideoDecoder { FilterState filterState; std::vector keyFrames; std::vector allFrames; - AVPixelFormat hwPixelFormat = AV_PIX_FMT_NONE; - UniqueAVBufferRef hwDeviceContext; }; VideoDecoder(); // Returns the key frame index of the presentation timestamp using FFMPEG's diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 073ef658c..4b271f2eb 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("create_from_file(str filename) -> Tensor"); m.def("create_from_tensor(Tensor video_tensor) -> Tensor"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device_string=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)"); m.def( @@ -117,8 +117,7 @@ void add_video_stream( std::optional height, std::optional num_threads, std::optional dimension_order, - std::optional stream_index, - std::optional device_string) { + std::optional stream_index) { VideoDecoder::VideoStreamDecoderOptions options; options.width = width; options.height = height; @@ -129,10 +128,6 @@ void add_video_stream( TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW"); options.dimensionOrder = stdDimensionOrder; } - if (device_string.has_value()) { - std::string deviceString{device_string.value()}; - options.device = torch::Device(deviceString); - } auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options); @@ -147,7 +142,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::DecodedOutput result; try { - result = videoDecoder->getNextDecodedOutputNoDemux(); + result = videoDecoder->getNextDecodedOutput(); } catch (const VideoDecoder::EndOfFileException& e) { throw pybind11::stop_iteration(e.what()); } @@ -161,7 +156,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) { OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto result = videoDecoder->getFrameDisplayedAtTimestampNoDemux(seconds); + auto result = videoDecoder->getFrameDisplayedAtTimestamp(seconds); return makeOpsDecodedOutput(result); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index b391f710a..f3df2ada7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -35,8 +35,7 @@ void add_video_stream( std::optional height = std::nullopt, std::optional num_threads = std::nullopt, std::optional dimension_order = std::nullopt, - std::optional stream_index = std::nullopt, - std::optional device_string = std::nullopt); + std::optional stream_index = std::nullopt); // Seek to a particular presentation timestamp in the video in seconds. void seek_to_pts(at::Tensor& decoder, double seconds); diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 518803a70..1d6460ba8 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -116,7 +116,6 @@ def add_video_stream_abstract( num_threads: Optional[int] = None, dimension_order: Optional[str] = None, stream_index: Optional[int] = None, - device_string: Optional[str] = None, ) -> None: return diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index f189f36fa..b04411b8d 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Iterable, Iterator, Literal, Tuple, Union -from torch import device as torch_device, Tensor +from torch import Tensor from torchcodec.decoders import _core as core @@ -89,14 +89,6 @@ class SimpleVideoDecoder: This can be either "NCHW" (default) or "NHWC", where N is the batch size, C is the number of channels, H is the height, and W is the width of the frames. - device (torch.device, optional): The device to use for decoding. - Currently we only support CPU and CUDA devices. If CUDA is used, - we use NVDEC and CUDA to do decoding and color-conversion - respectively. The resulting frame is left on the GPU for further - processing. - You can either pass in a string like "cpu" or "cuda:0" or a - torch.device like torch.device("cuda:0"). - Default: ``torch.device("cpu")``. .. note:: @@ -114,7 +106,6 @@ def __init__( self, source: Union[str, Path, bytes, Tensor], dimension_order: Literal["NCHW", "NHWC"] = "NCHW", - device: Union[str, torch_device] = torch_device("cpu"), ): if isinstance(source, str): self._decoder = core.create_from_file(source) @@ -138,20 +129,7 @@ def __init__( ) core.scan_all_streams_to_update_metadata(self._decoder) - num_threads = None - if isinstance(device, str): - device = torch_device(device) - if device.type == "cuda": - # Using multiple CPU threads seems to slow down decoding on CUDA. - # CUDA internally uses dedicated hardware to do decoding so we - # don't need CPU software threads here. - num_threads = 1 - core.add_video_stream( - self._decoder, - dimension_order=dimension_order, - device_string=str(device), - num_threads=num_threads, - ) + core.add_video_stream(self._decoder, dimension_order=dimension_order) self.metadata, self._stream_index = _get_and_validate_stream_metadata( self._decoder diff --git a/test/decoders/CMakeLists.txt b/test/decoders/CMakeLists.txt index a5a26704d..21791dde3 100644 --- a/test/decoders/CMakeLists.txt +++ b/test/decoders/CMakeLists.txt @@ -26,10 +26,6 @@ add_executable( VideoDecoderOpsTest.cpp ) -if(ENABLE_CUDA) - target_compile_definitions(VideoDecoderTest PRIVATE ENABLE_CUDA=1) -endif() - target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) target_include_directories(VideoDecoderTest PRIVATE ../../) target_include_directories(VideoDecoderOpsTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 04cbed0a6..057148b36 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -17,10 +17,6 @@ #include "tools/cxx/Resources.h" #endif -#ifdef ENABLE_CUDA -#include -#endif - using namespace ::testing; C10_DEFINE_bool( @@ -152,7 +148,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) { streamOptions.width = 100; streamOptions.height = 120; decoder->addVideoStreamDecoder(-1, streamOptions); - torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame; + torch::Tensor tensor = decoder->getNextDecodedOutput().frame; EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); } @@ -163,7 +159,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { VideoDecoder::VideoStreamDecoderOptions streamOptions; streamOptions.dimensionOrder = "NHWC"; decoder->addVideoStreamDecoder(-1, streamOptions); - torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame; + torch::Tensor tensor = decoder->getNextDecodedOutput().frame; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -172,12 +168,12 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); - auto output = ourDecoder->getNextDecodedOutputNoDemux(); + auto output = ourDecoder->getNextDecodedOutput(); torch::Tensor tensor0FromOurDecoder = output.frame; EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 0.0); EXPECT_EQ(output.pts, 0); - output = ourDecoder->getNextDecodedOutputNoDemux(); + output = ourDecoder->getNextDecodedOutput(); torch::Tensor tensor1FromOurDecoder = output.frame; EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); @@ -205,54 +201,6 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { } } -#ifdef ENABLE_CUDA -TEST(GPUVideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { - if (!torch::cuda::is_available()) { - return; - } - at::cuda::getDefaultCUDAStream(); - std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr ourDecoder = - VideoDecoder::createFromFilePath(path); - VideoDecoder::VideoStreamDecoderOptions streamOptions; - streamOptions.device = torch::Device("cuda"); - ASSERT_TRUE(streamOptions.device.is_cuda()); - ASSERT_EQ(streamOptions.device.type(), torch::DeviceType::CUDA); - ourDecoder->addVideoStreamDecoder(-1, streamOptions); - auto output = ourDecoder->getNextDecodedOutputNoDemux(); - torch::Tensor tensor1FromOurDecoder = output.frame; - EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({3, 270, 480})); - EXPECT_EQ(output.ptsSeconds, 0.0); - EXPECT_EQ(output.pts, 0); - output = ourDecoder->getNextDecodedOutputNoDemux(); - torch::Tensor tensor2FromOurDecoder = output.frame; - EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector({3, 270, 480})); - EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); - EXPECT_EQ(output.pts, 1001); - - torch::Tensor tensor1FromFFMPEG = - readTensorFromDisk("nasa_13013.mp4.frame000001.cuda.pt"); - torch::Tensor tensor2FromFFMPEG = - readTensorFromDisk("nasa_13013.mp4.frame000002.cuda.pt"); - - EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector({3, 270, 480})); - EXPECT_EQ(tensor2FromFFMPEG.sizes(), std::vector({3, 270, 480})); - EXPECT_EQ(tensor1FromOurDecoder.device().type(), torch::DeviceType::CUDA); - EXPECT_EQ(tensor2FromOurDecoder.device().type(), torch::DeviceType::CUDA); - torch::Tensor tensor1FromOurDecoderCPU = tensor1FromOurDecoder.cpu(); - torch::Tensor tensor2FromOurDecoderCPU = tensor1FromOurDecoder.cpu(); - EXPECT_TRUE(torch::equal(tensor1FromOurDecoderCPU, tensor1FromFFMPEG)); - EXPECT_TRUE(torch::equal(tensor2FromOurDecoderCPU, tensor2FromFFMPEG)); - - if (FLAGS_dump_frames_for_debugging) { - dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt"); - dumpTensorToDisk(tensor2FromFFMPEG, "tensor2FromFFMPEG.pt"); - dumpTensorToDisk(tensor1FromOurDecoderCPU, "tensor1FromOurDecoder.pt"); - dumpTensorToDisk(tensor2FromOurDecoderCPU, "tensor2FromOurDecoder.pt"); - } -} -#endif - TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = @@ -306,11 +254,11 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) { createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); ourDecoder->setCursorPtsInSeconds(388388. / 30'000); - auto output = ourDecoder->getNextDecodedOutputNoDemux(); + auto output = ourDecoder->getNextDecodedOutput(); EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000); - output = ourDecoder->getNextDecodedOutputNoDemux(); + output = ourDecoder->getNextDecodedOutput(); EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000); - EXPECT_THROW(ourDecoder->getNextDecodedOutputNoDemux(), std::exception); + EXPECT_THROW(ourDecoder->getNextDecodedOutput(), std::exception); } TEST_P(VideoDecoderTest, GetsFrameDisplayedAtTimestamp) { @@ -318,19 +266,18 @@ TEST_P(VideoDecoderTest, GetsFrameDisplayedAtTimestamp) { std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); - auto output = ourDecoder->getFrameDisplayedAtTimestampNoDemux(6.006); + auto output = ourDecoder->getFrameDisplayedAtTimestamp(6.006); EXPECT_EQ(output.ptsSeconds, 6.006); // The frame's duration is 0.033367 according to ffprobe, // so the next frame is displayed at timestamp=6.039367. const double kNextFramePts = 6.039366666666667; // The frame that is displayed a microsecond before the next frame is still // the previous frame. - output = - ourDecoder->getFrameDisplayedAtTimestampNoDemux(kNextFramePts - 1e-6); + output = ourDecoder->getFrameDisplayedAtTimestamp(kNextFramePts - 1e-6); EXPECT_EQ(output.ptsSeconds, 6.006); // The frame that is displayed at the exact pts of the frame is the next // frame. - output = ourDecoder->getFrameDisplayedAtTimestampNoDemux(kNextFramePts); + output = ourDecoder->getFrameDisplayedAtTimestamp(kNextFramePts); EXPECT_EQ(output.ptsSeconds, kNextFramePts); // This is the timestamp of the last frame in this video. @@ -340,7 +287,7 @@ TEST_P(VideoDecoderTest, GetsFrameDisplayedAtTimestamp) { kPtsOfLastFrameInVideoStream + kDurationOfLastFrameInVideoStream; // Sanity check: make sure duration is strictly positive. EXPECT_GT(kPtsPlusDurationOfLastFrame, kPtsOfLastFrameInVideoStream); - output = ourDecoder->getFrameDisplayedAtTimestampNoDemux( + output = ourDecoder->getFrameDisplayedAtTimestamp( kPtsPlusDurationOfLastFrame - 1e-6); EXPECT_EQ(output.ptsSeconds, kPtsOfLastFrameInVideoStream); } @@ -351,7 +298,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); ourDecoder->setCursorPtsInSeconds(6.0); - auto output = ourDecoder->getNextDecodedOutputNoDemux(); + auto output = ourDecoder->getNextDecodedOutput(); torch::Tensor tensor6FromOurDecoder = output.frame; EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); torch::Tensor tensor6FromFFMPEG = @@ -367,7 +314,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 180); ourDecoder->setCursorPtsInSeconds(6.1); - output = ourDecoder->getNextDecodedOutputNoDemux(); + output = ourDecoder->getNextDecodedOutput(); torch::Tensor tensor61FromOurDecoder = output.frame; EXPECT_EQ(output.ptsSeconds, 183'183. / 30'000); torch::Tensor tensor61FromFFMPEG = @@ -387,7 +334,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { EXPECT_LT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 10); ourDecoder->setCursorPtsInSeconds(10.0); - output = ourDecoder->getNextDecodedOutputNoDemux(); + output = ourDecoder->getNextDecodedOutput(); torch::Tensor tensor10FromOurDecoder = output.frame; EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000); torch::Tensor tensor10FromFFMPEG = @@ -404,7 +351,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 60); ourDecoder->setCursorPtsInSeconds(6.0); - output = ourDecoder->getNextDecodedOutputNoDemux(); + output = ourDecoder->getNextDecodedOutput(); tensor6FromOurDecoder = output.frame; EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG)); @@ -419,7 +366,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9 ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream); - output = ourDecoder->getNextDecodedOutputNoDemux(); + output = ourDecoder->getNextDecodedOutput(); torch::Tensor tensor7FromOurDecoder = output.frame; EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000); torch::Tensor tensor7FromFFMPEG = diff --git a/test/decoders/manual_smoke_test.py b/test/decoders/manual_smoke_test.py index 389aa5f4e..7351155c1 100644 --- a/test/decoders/manual_smoke_test.py +++ b/test/decoders/manual_smoke_test.py @@ -4,36 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse import os import torchcodec from torchvision.io.image import write_png - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--device", default="cpu", type=str, help="Specify 'cuda:0' for CUDA decoding" - ) - args = parser.parse_args() - - decoder = torchcodec.decoders._core.create_from_file( - os.path.dirname(__file__) + "/../resources/nasa_13013.mp4" - ) - torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder) - torchcodec.decoders._core.add_video_stream( - decoder, stream_index=3, device_string=args.device - ) - frame, _, _ = torchcodec.decoders._core.get_frame_at_index( - decoder, stream_index=3, frame_index=180 - ) - if "cuda" in args.device: - output_name = "frame180.cuda.png" - else: - output_name = "frame180.cpu.png" - write_png(frame.cpu(), output_name) - - -if __name__ == "__main__": - main() +decoder = torchcodec.decoders._core.create_from_file( + os.path.dirname(__file__) + "/../resources/nasa_13013.mp4" +) +torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder) +torchcodec.decoders._core.add_video_stream(decoder, stream_index=3) +frame, _, _ = torchcodec.decoders._core.get_frame_at_index( + decoder, stream_index=3, frame_index=180 +) +write_png(frame, "frame180.png") diff --git a/test/decoders/test_simple_video_decoder.py b/test/decoders/test_simple_video_decoder.py index 1187efbf1..e064b904d 100644 --- a/test/decoders/test_simple_video_decoder.py +++ b/test/decoders/test_simple_video_decoder.py @@ -45,34 +45,6 @@ def test_create_fails(self): with pytest.raises(TypeError, match="Unknown source type"): decoder = SimpleVideoDecoder(123) # noqa - def test_can_accept_devices(self): - # You can pass a CPU device as a string... - decoder = SimpleVideoDecoder(NASA_VIDEO.path, device="cpu") - assert_tensor_equal(decoder[0], NASA_VIDEO.get_frame_data_by_index(0)) - - # ...or as a torch.device. - decoder = SimpleVideoDecoder(NASA_VIDEO.path, device=torch.device("cpu")) - assert_tensor_equal(decoder[0], NASA_VIDEO.get_frame_data_by_index(0)) - - if torch.cuda.is_available(): - # You can pass a CUDA device as a string... - decoder = SimpleVideoDecoder(NASA_VIDEO.path, device="cuda") - frame = decoder[0] - assert frame.device.type == "cuda" - assert frame.shape == torch.Size( - [NASA_VIDEO.num_color_channels, NASA_VIDEO.height, NASA_VIDEO.width] - ) - - # ...or as a torch.device. - decoder = SimpleVideoDecoder(NASA_VIDEO.path, device=torch.device("cuda")) - frame = decoder[0] - assert frame.device.type == "cuda" - assert frame.shape == torch.Size( - [NASA_VIDEO.num_color_channels, NASA_VIDEO.height, NASA_VIDEO.width] - ) - # TODO: compare tensor values too. We don't compare values because - # the exact values are hardware-dependent. - def test_getitem_int(self): decoder = SimpleVideoDecoder(NASA_VIDEO.path) diff --git a/test/generate_reference_resources.sh b/test/generate_reference_resources.sh index ccc7262e1..e90e26ae8 100755 --- a/test/generate_reference_resources.sh +++ b/test/generate_reference_resources.sh @@ -40,8 +40,6 @@ ffmpeg -y -ss 12.979633 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time12.979633. # Audio generation in the form of an mp3. ffmpeg -y -i "$VIDEO_PATH" -b:a 192K -vn "$VIDEO_PATH.audio.mp3" -# TODO: Add frames decoded by Nvidia's NVDEC. - # This video was generated by running the following: # conda install -c conda-forge x265 # ./configure --enable-nonfree --enable-gpl --prefix=$(readlink -f ../bin) --enable-libx265 --enable-rpath --extra-ldflags=-Wl,-rpath=$CONDA_PREFIX/lib --enable-filter=drawtext --enable-libfontconfig --enable-libfreetype --enable-libharfbuzz diff --git a/test/resources/nasa_13013.mp4.frame000001.cuda.pt b/test/resources/nasa_13013.mp4.frame000001.cuda.pt deleted file mode 100644 index 17c59fd4f..000000000 Binary files a/test/resources/nasa_13013.mp4.frame000001.cuda.pt and /dev/null differ diff --git a/test/resources/nasa_13013.mp4.frame000002.cuda.pt b/test/resources/nasa_13013.mp4.frame000002.cuda.pt deleted file mode 100644 index 17c59fd4f..000000000 Binary files a/test/resources/nasa_13013.mp4.frame000002.cuda.pt and /dev/null differ diff --git a/test/samplers/test_video_clip_sampler.py b/test/samplers/test_video_clip_sampler.py index 4768f37cd..963a1c9ad 100644 --- a/test/samplers/test_video_clip_sampler.py +++ b/test/samplers/test_video_clip_sampler.py @@ -4,7 +4,6 @@ import pytest import torch from torchcodec._samplers import ( - DecoderArgs, IndexBasedSamplerArgs, TimeBasedSamplerArgs, VideoArgs, @@ -31,16 +30,11 @@ ), ], ) -@pytest.mark.parametrize(("device"), [torch.device("cpu"), torch.device("cuda:0")]) -def test_sampler(sampler_args, device): - if device.type == "cuda" and not torch.cuda.is_available(): - pytest.skip("GPU not available") - +def test_sampler(sampler_args): torch.manual_seed(0) desired_width, desired_height = 320, 240 video_args = VideoArgs(desired_width=desired_width, desired_height=desired_height) - decoder_args = DecoderArgs(device=device) - sampler = VideoClipSampler(video_args, sampler_args, decoder_args) + sampler = VideoClipSampler(video_args, sampler_args) clips = sampler(NASA_VIDEO.to_tensor()) assert_tensor_equal(len(clips), sampler_args.clips_per_video) clip = clips[0]