From 30fe7349028289492a349226b32c937b81b8cafa Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 7 Mar 2025 11:56:26 -0800 Subject: [PATCH 01/54] Remove unused C++ decoder creation --- .../decoders/_core/VideoDecoderOps.cpp | 14 ----- .../decoders/_core/VideoDecoderOps.h | 7 --- test/decoders/CMakeLists.txt | 14 ----- test/decoders/VideoDecoderOpsTest.cpp | 51 ------------------- 4 files changed, 86 deletions(-) delete mode 100644 test/decoders/VideoDecoderOpsTest.cpp diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index ad8c6f258..bb13e113d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -142,20 +142,6 @@ at::Tensor create_from_tensor( return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -at::Tensor create_from_buffer( - const void* buffer, - size_t length, - std::optional seek_mode) { - VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; - if (seek_mode.has_value()) { - realSeek = seekModeFromString(seek_mode.value()); - } - - std::unique_ptr uniqueDecoder = - std::make_unique(buffer, length, realSeek); - return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); -} - void add_video_stream( at::Tensor& decoder, std::optional width, diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index a3cc821ad..034a8842a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -28,13 +28,6 @@ at::Tensor create_from_tensor( at::Tensor video_tensor, std::optional seek_mode = std::nullopt); -// This API is C++ only and will not be exposed via custom ops, use -// videodecoder_create_from_bytes in Python -at::Tensor create_from_buffer( - const void* buffer, - size_t length, - std::optional seek_mode = std::nullopt); - // Add a new video stream at `stream_index` using the provided options. 
void add_video_stream( at::Tensor& decoder, diff --git a/test/decoders/CMakeLists.txt b/test/decoders/CMakeLists.txt index 21791dde3..3350c92c5 100644 --- a/test/decoders/CMakeLists.txt +++ b/test/decoders/CMakeLists.txt @@ -21,15 +21,8 @@ add_executable( VideoDecoderTest.cpp ) -add_executable( - VideoDecoderOpsTest - VideoDecoderOpsTest.cpp -) - target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) target_include_directories(VideoDecoderTest PRIVATE ../../) -target_include_directories(VideoDecoderOpsTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) -target_include_directories(VideoDecoderOpsTest PRIVATE ../../) target_link_libraries( VideoDecoderTest @@ -37,12 +30,5 @@ target_link_libraries( GTest::gtest_main ) -target_link_libraries( - VideoDecoderOpsTest - ${libtorchcodec_target_name} - GTest::gtest_main -) - include(GoogleTest) gtest_discover_tests(VideoDecoderTest) -gtest_discover_tests(VideoDecoderOpsTest) diff --git a/test/decoders/VideoDecoderOpsTest.cpp b/test/decoders/VideoDecoderOpsTest.cpp deleted file mode 100644 index f2414797d..000000000 --- a/test/decoders/VideoDecoderOpsTest.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include "src/torchcodec/decoders/_core/VideoDecoderOps.h" - -#include -#include -#include -#include - -#ifdef FBCODE_BUILD -#include "tools/cxx/Resources.h" -#endif - -using namespace ::testing; - -namespace facebook::torchcodec { - -std::string getResourcePath(const std::string& filename) { -#ifdef FBCODE_BUILD - std::string filepath = "pytorch/torchcodec/test/resources/" + filename; - filepath = build::getResourcePath(filepath).string(); -#else - std::filesystem::path dirPath = std::filesystem::path(__FILE__); - std::string filepath = - dirPath.parent_path().string() + "/../resources/" + filename; -#endif - return filepath; -} - -TEST(VideoDecoderOpsTest, TestCreateDecoderFromBuffer) { - std::string filepath = getResourcePath("nasa_13013.mp4"); - std::ostringstream outputStringStream; - std::ifstream input(filepath, std::ios::binary); - outputStringStream << input.rdbuf(); - std::string content = outputStringStream.str(); - void* buffer = content.data(); - size_t length = outputStringStream.str().length(); - at::Tensor decoder = create_from_buffer(buffer, length); - add_video_stream(decoder); - auto result = get_next_frame(decoder); - at::Tensor tensor1 = std::get<0>(result); - EXPECT_EQ(tensor1.sizes(), std::vector({3, 270, 480})); - EXPECT_EQ(std::get<1>(result).item(), 0); - EXPECT_NEAR(std::get<2>(result).item(), 0.033367, 1e-6); -} - -} // namespace facebook::torchcodec From a0930032c33ca35466930e493857065a40ceff15 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 12:44:00 -0700 Subject: [PATCH 02/54] Add support for decoding from Python file-like objects --- src/torchcodec/decoders/_core/CMakeLists.txt | 135 ++++++++++++------ .../decoders/_core/FFMPEGCommon.cpp | 43 +++--- src/torchcodec/decoders/_core/FFMPEGCommon.h | 31 ++-- .../decoders/_core/VideoDecoder.cpp | 24 ++-- src/torchcodec/decoders/_core/VideoDecoder.h | 10 +- .../decoders/_core/VideoDecoderOps.cpp | 30 ++-- src/torchcodec/decoders/_core/__init__.py | 1 + 
.../decoders/_core/video_decoder_ops.py | 34 ++++- test/decoders/CMakeLists.txt | 3 +- test/decoders/VideoDecoderTest.cpp | 9 +- test/decoders/test_ops.py | 9 +- 11 files changed, 222 insertions(+), 107 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 688a249d5..2fcd5f55a 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -7,50 +7,99 @@ find_package(Torch REQUIRED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) -function(make_torchcodec_library library_name ffmpeg_target) - set( - sources - FFMPEGCommon.h - FFMPEGCommon.cpp - VideoDecoder.h - VideoDecoder.cpp - VideoDecoderOps.h - VideoDecoderOps.cpp - DeviceInterface.h - ) - if(ENABLE_CUDA) - list(APPEND sources CudaDevice.cpp) - else() - list(APPEND sources CPUOnlyDevice.cpp) - endif() - add_library(${library_name} SHARED ${sources}) - set_property(TARGET ${library_name} PROPERTY CXX_STANDARD 17) +function(make_torchcodec_sublibrary + library_name + sources + dependent_libraries + ffmpeg_include_dirs) - target_include_directories( - ${library_name} + add_library(${library_name} SHARED ${sources}) + set_target_properties(${library_name} PROPERTIES CXX_STANDARD 17) + target_include_directories(${library_name} PRIVATE ./../../../../ "${TORCH_INSTALL_PREFIX}/include" ${Python3_INCLUDE_DIRS} + ${ffmpeg_include_dirs} ) - set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} - ${Python3_LIBRARIES}) - if(ENABLE_CUDA) - list(APPEND NEEDED_LIBRARIES - ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) - endif() + # Avoid adding the "lib" prefix which we already add explicitly. 
+ set_target_properties(${library_name} PROPERTIES PREFIX "") + target_link_libraries( ${library_name} PUBLIC - ${NEEDED_LIBRARIES} + ${dependent_libraries} ) +endfunction() - # We already set the library_name to be libtorchcodecN, so we don't want - # cmake to add another "lib" prefix. We do it this way because it makes it - # easier to find references to libtorchcodec in the code (e.g. via `git - # grep`) - set_target_properties(${library_name} PROPERTIES PREFIX "") +function(make_torchcodec_libraries + ffmpeg_major_version + ffmpeg_target + ffmpeg_include_dirs) + + # Create libtorchcodec_decoderN.so + set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") + set(decoder_sources FFMPEGCommon.cpp VideoDecoder.cpp) + + if(ENABLE_CUDA) + list(APPEND decoder_sources CudaDevice.cpp) + else() + list(APPEND decoder_sources CPUOnlyDevice.cpp) + endif() + + set(decoder_dependent_libraries + ${ffmpeg_target} + ${TORCH_LIBRARIES} + ${Python3_LIBRARIES} + ) + + if(ENABLE_CUDA) + list(APPEND decoder_dependent_libraries + ${CUDA_nppi_LIBRARY} + ${CUDA_nppicc_LIBRARY} + ) + endif() + + make_torchcodec_sublibrary( + "${decoder_library_name}" + "${decoder_sources}" + "${decoder_dependent_libraries}" + "${ffmpeg_include_dirs}" + ) + + # Create libtorchcodec_custom_opsN.so + set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") + set(custom_ops_sources VideoDecoderOps.cpp) + make_torchcodec_sublibrary( + "${custom_ops_library_name}" + "${custom_ops_sources}" + "${decoder_library_name}" + "${ffmpeg_include_dirs}" + ) + + # Create libtorchcodec_pybind_opsN.so + set(pybind_ops_library_name "libtorchcodec_pybind_ops${ffmpeg_major_version}") + set(pybind_ops_sources PyBindOps.cpp) + make_torchcodec_sublibrary( + "${pybind_ops_library_name}" + "${pybind_ops_sources}" + "${decoder_library_name}" + "${ffmpeg_include_dirs}" + ) + target_compile_definitions( + ${pybind_ops_library_name} + PUBLIC + 
TORCHCODEC_PYBIND=_torchcodec_pybind_ops${ffmpeg_major_version} + ) + + # Install all libraries. + set( + all_libraries + ${decoder_library_name} + ${custom_ops_library_name} + ${pybind_ops_library_name} + ) # The install step is invoked within CMakeBuild.build_library() in # setup.py and just copies the built .so files from the temp @@ -58,7 +107,7 @@ function(make_torchcodec_library library_name ffmpeg_target) # still need to manually pass "DESTINATION ..." for cmake to copy those # files in CMAKE_INSTALL_PREFIX instead of CMAKE_INSTALL_PREFIX/lib. install( - TARGETS ${library_name} + TARGETS ${all_libraries} LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} ) endfunction() @@ -76,11 +125,10 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3}) ${CMAKE_CURRENT_SOURCE_DIR}/fetch_and_expose_non_gpl_ffmpeg_libs.cmake ) - - make_torchcodec_library(libtorchcodec4 ffmpeg4) - make_torchcodec_library(libtorchcodec7 ffmpeg7) - make_torchcodec_library(libtorchcodec6 ffmpeg6) - make_torchcodec_library(libtorchcodec5 ffmpeg5) + make_torchcodec_libraries(4 ffmpeg4 $ffmpeg4_INCLUDE_DIRS) + make_torchcodec_libraries(7 ffmpeg7 $ffmpeg7_INCLUDE_DIRs) + make_torchcodec_libraries(6 ffmpeg6 $ffmpeg6_INCLUDE_DIRS) + make_torchcodec_libraries(5 ffmpeg5 $ffmpeg5_INCLUDE_DIRS) else() message( @@ -120,10 +168,11 @@ else() ) endif() - set(libtorchcodec_target_name libtorchcodec${ffmpeg_major_version}) - # Make libtorchcodec_target_name available in the parent's scope, for the - # test's CMakeLists.txt - set(libtorchcodec_target_name ${libtorchcodec_target_name} PARENT_SCOPE) + make_torchcodec_libraries(${ffmpeg_major_version} PkgConfig::LIBAV ${LIBAV_INCLUDE_DIRS}) - make_torchcodec_library(${libtorchcodec_target_name} PkgConfig::LIBAV) + # Expose these values updwards so that the test compilation does not need + # to re-figure it out. FIXME: it's not great that we just copy-paste the + # library name. 
+ set(libtorchcodec_library_name "libtorchcodec_decoder${ffmpeg_major_version}" PARENT_SCOPE) + set(libav_include_dirs ${LIBAV_INCLUDE_DIRS} PARENT_SCOPE) endif() diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp index 4185d9b94..2531c54ff 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp @@ -73,9 +73,12 @@ int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) { AVIOBytesContext::AVIOBytesContext( const void* data, - size_t dataSize, - size_t bufferSize) - : bufferData_{static_cast(data), dataSize, 0} { + int64_t dataSize, + int bufferSize) + : dataContext_{static_cast(data), dataSize, 0} { + TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); + TORCH_CHECK(dataSize > 0, "Video data size must be positive"); + auto buffer = static_cast(av_malloc(bufferSize)); TORCH_CHECK( buffer != nullptr, @@ -85,7 +88,7 @@ AVIOBytesContext::AVIOBytesContext( buffer, bufferSize, 0, - &bufferData_, + &dataContext_, &AVIOBytesContext::read, nullptr, &AVIOBytesContext::seek)); @@ -102,50 +105,50 @@ AVIOBytesContext::~AVIOBytesContext() { } } -AVIOContext* AVIOBytesContext::getAVIO() { +AVIOContext* AVIOBytesContext::getAVIOContext() const { return avioContext_.get(); } -// The signature of this function is defined by FFMPEG. +// The signature of this function is defined by FFmpeg. 
int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) { - auto bufferData = static_cast(opaque); + auto dataContext = static_cast(opaque); TORCH_CHECK( - bufferData->current <= bufferData->size, + dataContext->current <= dataContext->size, "Tried to read outside of the buffer: current=", - bufferData->current, + dataContext->current, ", size=", - bufferData->size); + dataContext->size); - buf_size = - FFMIN(buf_size, static_cast(bufferData->size - bufferData->current)); + buf_size = FFMIN( + buf_size, static_cast(dataContext->size - dataContext->current)); TORCH_CHECK( buf_size >= 0, "Tried to read negative bytes: buf_size=", buf_size, ", size=", - bufferData->size, + dataContext->size, ", current=", - bufferData->current); + dataContext->current); if (!buf_size) { return AVERROR_EOF; } - memcpy(buf, bufferData->data + bufferData->current, buf_size); - bufferData->current += buf_size; + memcpy(buf, dataContext->data + dataContext->current, buf_size); + dataContext->current += buf_size; return buf_size; } -// The signature of this function is defined by FFMPEG. +// The signature of this function is defined by FFmpeg. int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) { - auto bufferData = static_cast(opaque); + auto dataContext = static_cast(opaque); int64_t ret = -1; switch (whence) { case AVSEEK_SIZE: - ret = bufferData->size; + ret = dataContext->size; break; case SEEK_SET: - bufferData->current = offset; + dataContext->current = offset; ret = offset; break; default: diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index 0454058bc..b07112588 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -144,24 +144,27 @@ int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext); // Returns true if sws_scale can handle unaligned data. 
bool canSwsScaleHandleUnalignedData(); +// TODO: explain purpose of context holder +class AVIOContextHolder { + public: + virtual ~AVIOContextHolder(){}; + virtual AVIOContext* getAVIOContext() const = 0; +}; + +// TODO: make comment below better // A struct that holds state for reading bytes from an IO context. // We give this to FFMPEG and it will pass it back to us when it needs to read // or seek in the memory buffer. -struct AVIOBufferData { - const uint8_t* data; - size_t size; - size_t current; -}; - +// // A class that can be used as AVFormatContext's IO context. It reads from a // memory buffer that is passed in. -class AVIOBytesContext { +class AVIOBytesContext : public AVIOContextHolder { public: - AVIOBytesContext(const void* data, size_t dataSize, size_t bufferSize); - ~AVIOBytesContext(); + AVIOBytesContext(const void* data, int64_t dataSize, int bufferSize); + virtual ~AVIOBytesContext(); // Returns the AVIOContext that can be passed to FFMPEG. - AVIOContext* getAVIO(); + virtual AVIOContext* getAVIOContext() const override; // The signature of this function is defined by FFMPEG. 
static int read(void* opaque, uint8_t* buf, int buf_size); @@ -170,8 +173,14 @@ class AVIOBytesContext { static int64_t seek(void* opaque, int64_t offset, int whence); private: + struct DataContext { + const uint8_t* data; + int64_t size; + int64_t current; + }; + UniqueAVIOContext avioContext_; - struct AVIOBufferData bufferData_; + DataContext dataContext_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 6e7e72f27..d9b8b4c29 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -80,15 +80,13 @@ VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seekMode) initializeDecoder(); } -VideoDecoder::VideoDecoder(const void* data, size_t length, SeekMode seekMode) - : seekMode_(seekMode) { - TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); - +VideoDecoder::VideoDecoder( + std::unique_ptr context, + SeekMode seekMode) + : seekMode_(seekMode), avioContextHolder_(std::move(context)) { av_log_set_level(AV_LOG_QUIET); - constexpr int bufferSize = 64 * 1024; - ioBytesContext_.reset(new AVIOBytesContext(data, length, bufferSize)); - TORCH_CHECK(ioBytesContext_, "Failed to create AVIOBytesContext"); + TORCH_CHECK(avioContextHolder_, "Context holder cannot be null"); // Because FFmpeg requires a reference to a pointer in the call to open, we // can't use a unique pointer here. 
Note that means we must call free if open @@ -96,7 +94,7 @@ VideoDecoder::VideoDecoder(const void* data, size_t length, SeekMode seekMode) AVFormatContext* rawContext = avformat_alloc_context(); TORCH_CHECK(rawContext != nullptr, "Unable to alloc avformat context"); - rawContext->pb = ioBytesContext_->getAVIO(); + rawContext->pb = avioContextHolder_->getAVIOContext(); int status = avformat_open_input(&rawContext, nullptr, nullptr, nullptr); if (status != 0) { avformat_free_context(rawContext); @@ -1747,4 +1745,14 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame( videoStreamOptions.width.value_or(avFrame.width)); } +VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) { + if (seekMode == "exact") { + return VideoDecoder::SeekMode::exact; + } else if (seekMode == "approximate") { + return VideoDecoder::SeekMode::approximate; + } else { + TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); + } +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a28dcf9cb..47dd2caaa 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -34,11 +34,9 @@ class VideoDecoder { const std::string& videoFilePath, SeekMode seekMode = SeekMode::exact); - // Creates a VideoDecoder from a given buffer of data. Note that the data is - // not owned by the VideoDecoder. + // TODO: make comment accurate explicit VideoDecoder( - const void* data, - size_t length, + std::unique_ptr context, SeekMode seekMode = SeekMode::exact); // -------------------------------------------------------------------------- @@ -472,7 +470,7 @@ class VideoDecoder { // Stores various internal decoding stats. DecodeStats decodeStats_; // Stores the AVIOContext for the input buffer. - std::unique_ptr ioBytesContext_; + std::unique_ptr avioContextHolder_; // Whether or not we have already scanned all streams to update the metadata. 
bool scannedAllStreams_ = false; // Tracks that we've already been initialized. @@ -554,4 +552,6 @@ std::ostream& operator<<( std::ostream& os, const VideoDecoder::DecodeStats& stats); +VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index bb13e113d..0578d5771 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -13,6 +13,8 @@ #include "c10/util/Exception.h" #include "src/torchcodec/decoders/_core/VideoDecoder.h" +namespace py = pybind11; + namespace facebook::torchcodec { // ============================== @@ -30,6 +32,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); + m.def("_convert_to_tensor(int decoder_ptr) -> Tensor"); m.def( "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? 
color_conversion_library=None) -> ()"); m.def( @@ -68,7 +71,7 @@ at::Tensor wrapDecoderPointerToTensor( auto deleter = [decoder](void*) { delete decoder; }; at::Tensor tensor = - at::from_blob(decoder, {sizeof(VideoDecoder)}, deleter, {at::kLong}); + at::from_blob(decoder, {sizeof(VideoDecoder*)}, deleter, {at::kLong}); auto videoDecoder = static_cast(tensor.mutable_data_ptr()); TORCH_CHECK_EQ(videoDecoder, decoder) << "videoDecoder=" << videoDecoder; return tensor; @@ -93,16 +96,6 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput( return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); } -VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) { - if (seekMode == "exact") { - return VideoDecoder::SeekMode::exact; - } else if (seekMode == "approximate") { - return VideoDecoder::SeekMode::approximate; - } else { - throw std::runtime_error("Invalid seek mode: " + std::string(seekMode)); - } -} - } // namespace // ============================== @@ -129,7 +122,7 @@ at::Tensor create_from_tensor( at::Tensor video_tensor, std::optional seek_mode) { TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); - void* buffer = video_tensor.mutable_data_ptr(); + void* data = video_tensor.mutable_data_ptr(); size_t length = video_tensor.numel(); VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; @@ -137,8 +130,18 @@ at::Tensor create_from_tensor( realSeek = seekModeFromString(seek_mode.value()); } + constexpr int bufferSize = 64 * 1024; + auto contextHolder = + std::make_unique(data, length, bufferSize); + std::unique_ptr uniqueDecoder = - std::make_unique(buffer, length, realSeek); + std::make_unique(std::move(contextHolder), realSeek); + return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); +} + +at::Tensor _convert_to_tensor(int64_t decoder_ptr) { + auto decoder = reinterpret_cast(decoder_ptr); + std::unique_ptr uniqueDecoder(decoder); return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } @@ 
-521,6 +524,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) { TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); + m.impl("_convert_to_tensor", &_convert_to_tensor); m.impl( "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions); } diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 7dcb866c9..0834ee394 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -20,6 +20,7 @@ add_video_stream, create_from_bytes, create_from_file, + create_from_file_like, create_from_tensor, get_ffmpeg_library_versions, get_frame_at_index, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 190384684..ce1ed1aa4 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import importlib +import io import json import warnings from typing import List, Optional, Tuple @@ -15,6 +17,8 @@ _get_extension_path, ) +_pybind_ops = None + def load_torchcodec_extension(): # Successively try to load libtorchcodec7.so, libtorchcodec6.so, @@ -27,9 +31,20 @@ def load_torchcodec_extension(): exceptions = [] for ffmpeg_major_version in (7, 6, 5, 4): - library_name = f"libtorchcodec{ffmpeg_major_version}" + decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" + custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" + pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" try: - torch.ops.load_library(_get_extension_path(library_name)) + torch.ops.load_library(_get_extension_path(decoder_library_name)) + torch.ops.load_library(_get_extension_path(custom_ops_library_name)) + torch.ops.load_library(_get_extension_path(pybind_ops_library_name)) + spec = importlib.util.spec_from_file_location( + f"_torchcodec_pybind_ops{ffmpeg_major_version}", + _get_extension_path(pybind_ops_library_name), + ) + global _pybind_ops + _pybind_ops = importlib.util.module_from_spec(spec) + assert _pybind_ops is not None return except Exception as e: # TODO: recording and reporting exceptions this way is OK for now as it's just for debugging, @@ -67,6 +82,9 @@ def load_torchcodec_extension(): create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) +_convert_to_tensor = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns._convert_to_tensor.default +) add_video_stream = torch.ops.torchcodec_ns.add_video_stream.default _add_video_stream = torch.ops.torchcodec_ns._add_video_stream.default add_audio_stream = torch.ops.torchcodec_ns.add_audio_stream.default @@ -107,6 +125,13 @@ def create_from_bytes( return create_from_tensor(buffer, seek_mode) +def create_from_file_like( + file_like: io.RawIOBase, seek_mode: Optional[str] = None +) -> torch.Tensor: + assert 
_pybind_ops is not None + return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) + + # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== @@ -122,6 +147,11 @@ def create_from_tensor_abstract( return torch.empty([], dtype=torch.long) +@register_fake("torchcodec_ns::_convert_to_tensor") +def _convert_to_tensor_abstract(decoder_ptr: int) -> torch.Tensor: + return torch.empty([], dtype=torch.long) + + @register_fake("torchcodec_ns::_add_video_stream") def _add_video_stream_abstract( decoder: torch.Tensor, diff --git a/test/decoders/CMakeLists.txt b/test/decoders/CMakeLists.txt index 3350c92c5..1dd6ce153 100644 --- a/test/decoders/CMakeLists.txt +++ b/test/decoders/CMakeLists.txt @@ -22,11 +22,12 @@ add_executable( ) target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) +target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${libav_include_dirs}) target_include_directories(VideoDecoderTest PRIVATE ../../) target_link_libraries( VideoDecoderTest - ${libtorchcodec_target_name} + ${libtorchcodec_library_name} GTest::gtest_main ) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 145663227..ca9f9c64b 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -48,10 +48,15 @@ class VideoDecoderTest : public testing::TestWithParam { std::ifstream input(filepath, std::ios::binary); outputStringStream << input.rdbuf(); content_ = outputStringStream.str(); + void* buffer = content_.data(); - size_t length = outputStringStream.str().length(); + size_t length = content_.length(); + constexpr int bufferSize = 64 * 1024; + auto contextHolder = + std::make_unique(buffer, length, bufferSize); + return std::make_unique( - buffer, length, VideoDecoder::SeekMode::approximate); + std::move(contextHolder), VideoDecoder::SeekMode::approximate); } else { return std::make_unique( filepath, 
VideoDecoder::SeekMode::approximate); diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index ab0e2bb09..916a1b5db 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -23,6 +23,7 @@ add_video_stream, create_from_bytes, create_from_file, + create_from_file_like, create_from_tensor, get_ffmpeg_library_versions, get_frame_at_index, @@ -340,7 +341,7 @@ def get_frame1_and_frame_time6(decoder): assert_frames_equal(frame_time6, reference_frame_time6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) - @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes")) + @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes", "file_like")) def test_create_decoder(self, create_from, device): path = str(NASA_VIDEO.path) if create_from == "file": @@ -349,10 +350,14 @@ def test_create_decoder(self, create_from, device): arr = np.fromfile(path, dtype=np.uint8) video_tensor = torch.from_numpy(arr) decoder = create_from_tensor(video_tensor) - else: # bytes + elif create_from == "bytes": with open(path, "rb") as f: video_bytes = f.read() decoder = create_from_bytes(video_bytes) + elif create_from == "file_like": + decoder = create_from_file_like(open(path, mode="rb", buffering=0), "exact") + else: + raise ValueError("Oops, double check the parametrization of this test!") add_video_stream(decoder, device=device) frame0, _, _ = get_next_frame(decoder) From 53d0729a38c28d541ffb5ebcce8c5d543ed97f4c Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 12:54:32 -0700 Subject: [PATCH 03/54] Forgot the new file. 
:/ --- src/torchcodec/decoders/_core/PyBindOps.cpp | 166 ++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 src/torchcodec/decoders/_core/PyBindOps.cpp diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp new file mode 100644 index 000000000..de8ca6f4a --- /dev/null +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -0,0 +1,166 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +#include "src/torchcodec/decoders/_core/VideoDecoder.h" + +namespace py = pybind11; + +namespace facebook::torchcodec { + +namespace { + +// Necessary to make sure that we hold the GIL when we delete a py::object. +struct PyObjectDeleter { + inline void operator()(py::object* obj) const { + if (obj) { + py::gil_scoped_acquire gil; + delete obj; + } + } +}; + +class AVIOFileLikeContext : public AVIOContextHolder { + public: + AVIOFileLikeContext(py::object fileLike, int bufferSize) + : fileLikeContext_{std::unique_ptr(new py::object(fileLike)), bufferSize} { + { + // TODO: Is it necessary to acquire the GIL here? Is it maybe even harmful? At + // the moment, this is only called from within a pybind function, and pybind + // guarantees we have the GIL. 
+ py::gil_scoped_acquire gil; + TORCH_CHECK( + py::hasattr(fileLike, "read"), + "File like object must implement a read method."); + TORCH_CHECK( + py::hasattr(fileLike, "seek"), + "File like object must implement a seek method."); + } + + auto buffer = static_cast(av_malloc(bufferSize)); + TORCH_CHECK( + buffer != nullptr, + "Failed to allocate buffer of size " + std::to_string(bufferSize)); + + avioContext_.reset(avio_alloc_context( + buffer, + bufferSize, + 0, + &fileLikeContext_, + &AVIOFileLikeContext::read, + nullptr, + &AVIOFileLikeContext::seek)); + + if (!avioContext_) { + av_freep(&buffer); + TORCH_CHECK(false, "Failed to allocate AVIOContext"); + } + } + + virtual ~AVIOFileLikeContext() { + if (avioContext_) { + av_freep(&avioContext_->buffer); + } + } + + virtual AVIOContext* getAVIOContext() const override { + return avioContext_.get(); + } + + static int read(void* opaque, uint8_t* buf, int buf_size) { + auto fileLikeContext = static_cast(opaque); + buf_size = FFMIN(buf_size, fileLikeContext->bufferSize); + + int num_read = 0; + while (num_read < buf_size) { + int request = buf_size - num_read; + py::gil_scoped_acquire gil; + auto chunk = static_cast(static_cast( + fileLikeContext->fileLike->attr("read")(request))); + int chunk_len = static_cast(chunk.length()); + if (chunk_len == 0) { + break; + } + TORCH_CHECK( + chunk_len <= request, + "Requested up to ", + request, + " bytes but, received ", + chunk_len, + " bytes. The given object does not confirm to read protocol of file object."); + memcpy(buf, chunk.data(), chunk_len); + buf += chunk_len; + num_read += chunk_len; + } + return num_read == 0 ? AVERROR_EOF : num_read; + } + + static int64_t seek(void* opaque, int64_t offset, int whence) { + // We do not know the file size. 
+ if (whence == AVSEEK_SIZE) { + return AVERROR(EIO); + } + auto fileLikeContext = static_cast(opaque); + py::gil_scoped_acquire gil; + return py::cast( + fileLikeContext->fileLike->attr("seek")(offset, whence)); + } + + private: + struct FileLikeContext { + // Note that we keep a pointer to the Python object because we need to + // strictly control when its destructor is called. We must hold the GIL + // when its destructor gets called, as it needs to update the reference + // count. It's easiest to control that when it's a pointer. Otherwise, we'd + // have to ensure whatever enclosing scope holds the object has the GIL, + // and that's, at least, hard. For all of the common pitfalls, see: + // + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + std::unique_ptr fileLike; + int bufferSize; + }; + + UniqueAVIOContext avioContext_; + FileLikeContext fileLikeContext_; +}; + +} // namespace + +// In principle, this should be able to return a tensor. But when we try that, we +// run into the bug reported here: +// +// https://github.com/pytorch/pytorch/issues/136664 +// +// So we instead launder the pointer through an int, and then use a conversion +// function on the custom ops side to launder that int into a tensor. +int64_t create_from_file_like( + py::object file_like, + std::optional seek_mode) { + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + + constexpr int bufferSize = 64 * 1024; + auto contextHolder = + std::make_unique(file_like, bufferSize); + + VideoDecoder* decoder = new VideoDecoder(std::move(contextHolder), realSeek); + return reinterpret_cast(decoder); +} + +#ifndef TORCHCODEC_PYBIND +#error TORCHCODEC_PYBIND must be defined. 
+#endif + +PYBIND11_MODULE(TORCHCODEC_PYBIND, m) { + m.def("create_from_file_like", &create_from_file_like); +} + +} // namespace facebook::torchcodec From 6bae1728932e437d1080e3510c13ff2886bbb8cc Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 12:57:46 -0700 Subject: [PATCH 04/54] Lint. --- src/torchcodec/decoders/_core/PyBindOps.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index de8ca6f4a..fb3d8a870 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -4,9 +4,9 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. +#include #include #include -#include #include "src/torchcodec/decoders/_core/VideoDecoder.h" @@ -29,11 +29,14 @@ struct PyObjectDeleter { class AVIOFileLikeContext : public AVIOContextHolder { public: AVIOFileLikeContext(py::object fileLike, int bufferSize) - : fileLikeContext_{std::unique_ptr(new py::object(fileLike)), bufferSize} { + : fileLikeContext_{ + std::unique_ptr( + new py::object(fileLike)), + bufferSize} { { - // TODO: Is it necessary to acquire the GIL here? Is it maybe even harmful? At - // the moment, this is only called from within a pybind function, and pybind - // guarantees we have the GIL. + // TODO: Is it necessary to acquire the GIL here? Is it maybe even + // harmful? At the moment, this is only called from within a pybind + // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; TORCH_CHECK( py::hasattr(fileLike, "read"), @@ -132,8 +135,8 @@ class AVIOFileLikeContext : public AVIOContextHolder { } // namespace -// In principle, this should be able to return a tensor. But when we try that, we -// run into the bug reported here: +// In principle, this should be able to return a tensor. 
But when we try that, +// we run into the bug reported here: // // https://github.com/pytorch/pytorch/issues/136664 // From 70a8364477d68329ee31081120ae9f1ec1861383 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 13:08:32 -0700 Subject: [PATCH 05/54] Remove unneeded namespace alias. --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 9a84f1a09..98028b08c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -13,8 +13,6 @@ #include "c10/util/Exception.h" #include "src/torchcodec/decoders/_core/VideoDecoder.h" -namespace py = pybind11; - namespace facebook::torchcodec { // ============================== From edce04b9eac07edd95f44c996084274d86ea7fbb Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 13:12:18 -0700 Subject: [PATCH 06/54] Remove asserts.
--- src/torchcodec/decoders/_core/ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 1406139b8..a145ae197 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -44,7 +44,6 @@ def load_torchcodec_extension(): ) global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) - assert _pybind_ops is not None return except Exception as e: # TODO: recording and reporting exceptions this way is OK for now as it's just for debugging, @@ -131,7 +130,6 @@ def create_from_bytes( def create_from_file_like( file_like: io.RawIOBase, seek_mode: Optional[str] = None ) -> torch.Tensor: - assert _pybind_ops is not None return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) From 7741ae437da598ca8cce6cfb00c5197c868709b9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 13:17:48 -0700 Subject: [PATCH 07/54] Cleanup pybind ops loading. 
--- src/torchcodec/decoders/_core/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index a145ae197..e935ff134 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -34,12 +34,12 @@ def load_torchcodec_extension(): decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" + pybind_ops_module_name = f"_torchcodec_pybind_ops{ffmpeg_major_version}" try: torch.ops.load_library(_get_extension_path(decoder_library_name)) torch.ops.load_library(_get_extension_path(custom_ops_library_name)) - torch.ops.load_library(_get_extension_path(pybind_ops_library_name)) spec = importlib.util.spec_from_file_location( - f"_torchcodec_pybind_ops{ffmpeg_major_version}", + pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) global _pybind_ops From 0117a785f0aa87c02835495757344bb757ceb67a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 14 Mar 2025 13:27:31 -0700 Subject: [PATCH 08/54] Explicitly say _pybind_ops is a module type --- src/torchcodec/decoders/_core/ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index e935ff134..48565bafe 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -8,6 +8,7 @@ import io import json import warnings +from types import ModuleType from typing import List, Optional, Tuple import torch @@ -17,7 +18,7 @@ _get_extension_path, ) -_pybind_ops = None +_pybind_ops: Optional[ModuleType] = None def load_torchcodec_extension(): @@ -130,6 +131,7 @@ def create_from_bytes( def create_from_file_like( file_like: io.RawIOBase, seek_mode: Optional[str] = None ) -> torch.Tensor: + assert 
_pybind_ops is not None return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) From 681b9cc1587ac117e1dc442289cf4cb8fd5e58c4 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Sat, 15 Mar 2025 12:48:07 -0700 Subject: [PATCH 09/54] Refactor AVIOContextHolder --- .../decoders/_core/FFMPEGCommon.cpp | 77 ++++-------------- src/torchcodec/decoders/_core/FFMPEGCommon.h | 44 ++++------ src/torchcodec/decoders/_core/PyBindOps.cpp | 45 ++--------- .../decoders/_core/VideoDecoderOps.cpp | 80 ++++++++++++++++++- 4 files changed, 110 insertions(+), 136 deletions(-) diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp index df3d3a363..f0e2a5c6b 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp @@ -78,14 +78,14 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) { #endif } -AVIOBytesContext::AVIOBytesContext( - const void* data, - int64_t dataSize, - int bufferSize) - : dataContext_{static_cast(data), dataSize, 0} { - TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); - TORCH_CHECK(dataSize > 0, "Video data size must be positive"); - +void AVIOContextHolder::createAVIOContext( + AVIOReadFunction read, + AVIOSeekFunction seek, + void* heldData, + int bufferSize) { + TORCH_CHECK( + bufferSize > 0, + "Buffer size must be greater than 0; is " + std::to_string(bufferSize)); auto buffer = static_cast(av_malloc(bufferSize)); TORCH_CHECK( buffer != nullptr, @@ -95,10 +95,10 @@ AVIOBytesContext::AVIOBytesContext( buffer, bufferSize, 0, - &dataContext_, - &AVIOBytesContext::read, - nullptr, - &AVIOBytesContext::seek)); + heldData, + read, + nullptr, // write function; not supported yet + seek)); if (!avioContext_) { av_freep(&buffer); @@ -106,63 +106,14 @@ AVIOBytesContext::AVIOBytesContext( } } -AVIOBytesContext::~AVIOBytesContext() { +AVIOContextHolder::~AVIOContextHolder() { if (avioContext_) { 
av_freep(&avioContext_->buffer); } } -AVIOContext* AVIOBytesContext::getAVIOContext() const { +AVIOContext* AVIOContextHolder::getAVIOContext() { return avioContext_.get(); } -// The signature of this function is defined by FFmpeg. -int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) { - auto dataContext = static_cast(opaque); - TORCH_CHECK( - dataContext->current <= dataContext->size, - "Tried to read outside of the buffer: current=", - dataContext->current, - ", size=", - dataContext->size); - - buf_size = FFMIN( - buf_size, static_cast(dataContext->size - dataContext->current)); - TORCH_CHECK( - buf_size >= 0, - "Tried to read negative bytes: buf_size=", - buf_size, - ", size=", - dataContext->size, - ", current=", - dataContext->current); - - if (!buf_size) { - return AVERROR_EOF; - } - memcpy(buf, dataContext->data + dataContext->current, buf_size); - dataContext->current += buf_size; - return buf_size; -} - -// The signature of this function is defined by FFmpeg. -int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) { - auto dataContext = static_cast(opaque); - int64_t ret = -1; - - switch (whence) { - case AVSEEK_SIZE: - ret = dataContext->size; - break; - case SEEK_SET: - dataContext->current = offset; - ret = offset; - break; - default: - break; - } - - return ret; -} - } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index 991d7145d..8cbc7c479 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -145,43 +145,27 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext); // Returns true if sws_scale can handle unaligned data. 
bool canSwsScaleHandleUnalignedData(); +using AVIOReadFunction = int (*)(void*, uint8_t*, int); +using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); + // TODO: explain purpose of context holder class AVIOContextHolder { public: - virtual ~AVIOContextHolder(){}; - virtual AVIOContext* getAVIOContext() const = 0; -}; - -// TODO: make comment below better -// A struct that holds state for reading bytes from an IO context. -// We give this to FFMPEG and it will pass it back to us when it needs to read -// or seek in the memory buffer. -// -// A class that can be used as AVFormatContext's IO context. It reads from a -// memory buffer that is passed in. -class AVIOBytesContext : public AVIOContextHolder { - public: - AVIOBytesContext(const void* data, int64_t dataSize, int bufferSize); - virtual ~AVIOBytesContext(); + virtual ~AVIOContextHolder(); + AVIOContext* getAVIOContext(); - // Returns the AVIOContext that can be passed to FFMPEG. - virtual AVIOContext* getAVIOContext() const override; - - // The signature of this function is defined by FFMPEG. - static int read(void* opaque, uint8_t* buf, int buf_size); - - // The signature of this function is defined by FFMPEG. 
- static int64_t seek(void* opaque, int64_t offset, int whence); + protected: + void createAVIOContext( + AVIOReadFunction read, + AVIOSeekFunction seek, + void* heldData, + int bufferSize = defaultBufferSize); private: - struct DataContext { - const uint8_t* data; - int64_t size; - int64_t current; - }; - UniqueAVIOContext avioContext_; - DataContext dataContext_; + + // Defaults to 64 KB + static const int defaultBufferSize = 64 * 1014; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index fb3d8a870..fdc678b9d 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -28,11 +28,9 @@ struct PyObjectDeleter { class AVIOFileLikeContext : public AVIOContextHolder { public: - AVIOFileLikeContext(py::object fileLike, int bufferSize) - : fileLikeContext_{ - std::unique_ptr( - new py::object(fileLike)), - bufferSize} { + explicit AVIOFileLikeContext(py::object fileLike) + : fileLikeContext_{std::unique_ptr( + new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? 
At the moment, this is only called from within a pybind @@ -45,40 +43,11 @@ class AVIOFileLikeContext : public AVIOContextHolder { py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - - auto buffer = static_cast(av_malloc(bufferSize)); - TORCH_CHECK( - buffer != nullptr, - "Failed to allocate buffer of size " + std::to_string(bufferSize)); - - avioContext_.reset(avio_alloc_context( - buffer, - bufferSize, - 0, - &fileLikeContext_, - &AVIOFileLikeContext::read, - nullptr, - &AVIOFileLikeContext::seek)); - - if (!avioContext_) { - av_freep(&buffer); - TORCH_CHECK(false, "Failed to allocate AVIOContext"); - } - } - - virtual ~AVIOFileLikeContext() { - if (avioContext_) { - av_freep(&avioContext_->buffer); - } - } - - virtual AVIOContext* getAVIOContext() const override { - return avioContext_.get(); + createAVIOContext(&read, &seek, &fileLikeContext_); } static int read(void* opaque, uint8_t* buf, int buf_size) { auto fileLikeContext = static_cast(opaque); - buf_size = FFMIN(buf_size, fileLikeContext->bufferSize); int num_read = 0; while (num_read < buf_size) { @@ -126,10 +95,8 @@ class AVIOFileLikeContext : public AVIOContextHolder { // // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors std::unique_ptr fileLike; - int bufferSize; }; - UniqueAVIOContext avioContext_; FileLikeContext fileLikeContext_; }; @@ -150,9 +117,7 @@ int64_t create_from_file_like( realSeek = seekModeFromString(seek_mode.value()); } - constexpr int bufferSize = 64 * 1024; - auto contextHolder = - std::make_unique(file_like, bufferSize); + auto contextHolder = std::make_unique(file_like); VideoDecoder* decoder = new VideoDecoder(std::move(contextHolder), realSeek); return reinterpret_cast(decoder); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 98028b08c..0014a6791 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ 
b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -64,6 +64,82 @@ TORCH_LIBRARY(torchcodec_ns, m) { } namespace { + +// TODO: make comment below better +// A struct that holds state for reading bytes from an IO context. +// We give this to FFMPEG and it will pass it back to us when it needs to read +// or seek in the memory buffer. +// +// A class that can be used as AVFormatContext's IO context. It reads from a +// memory buffer that is passed in. +class AVIOBytesContext : public AVIOContextHolder { + public: + explicit AVIOBytesContext(const void* data, int64_t dataSize) + : dataContext_{static_cast(data), dataSize, 0} { + TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); + TORCH_CHECK(dataSize > 0, "Video data size must be positive"); + createAVIOContext(&read, &seek, &dataContext_); + } + + // The signature of this function is defined by FFMPEG. + static int read(void* opaque, uint8_t* buf, int buf_size) { + auto dataContext = static_cast(opaque); + TORCH_CHECK( + dataContext->current <= dataContext->size, + "Tried to read outside of the buffer: current=", + dataContext->current, + ", size=", + dataContext->size); + + buf_size = FFMIN( + buf_size, static_cast(dataContext->size - dataContext->current)); + TORCH_CHECK( + buf_size >= 0, + "Tried to read negative bytes: buf_size=", + buf_size, + ", size=", + dataContext->size, + ", current=", + dataContext->current); + + if (!buf_size) { + return AVERROR_EOF; + } + memcpy(buf, dataContext->data + dataContext->current, buf_size); + dataContext->current += buf_size; + return buf_size; + } + + // The signature of this function is defined by FFMPEG. 
+ static int64_t seek(void* opaque, int64_t offset, int whence) { + auto dataContext = static_cast(opaque); + int64_t ret = -1; + + switch (whence) { + case AVSEEK_SIZE: + ret = dataContext->size; + break; + case SEEK_SET: + dataContext->current = offset; + ret = offset; + break; + default: + break; + } + + return ret; + } + + private: + struct DataContext { + const uint8_t* data; + int64_t size; + int64_t current; + }; + + DataContext dataContext_; +}; + at::Tensor wrapDecoderPointerToTensor( std::unique_ptr uniqueDecoder) { VideoDecoder* decoder = uniqueDecoder.release(); @@ -135,9 +211,7 @@ at::Tensor create_from_tensor( realSeek = seekModeFromString(seek_mode.value()); } - constexpr int bufferSize = 64 * 1024; - auto contextHolder = - std::make_unique(data, length, bufferSize); + auto contextHolder = std::make_unique(data, length); std::unique_ptr uniqueDecoder = std::make_unique(std::move(contextHolder), realSeek); From 43d6ddec4102502e19d82d189830d24edf9cd07f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Sat, 15 Mar 2025 12:56:32 -0700 Subject: [PATCH 10/54] AVIOFileLikeContext refactoring --- src/torchcodec/decoders/_core/PyBindOps.cpp | 40 ++++++++++----------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index fdc678b9d..d67138cea 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -26,11 +26,12 @@ struct PyObjectDeleter { } }; +using UniquePyObject = std::unique_ptr; + class AVIOFileLikeContext : public AVIOContextHolder { public: explicit AVIOFileLikeContext(py::object fileLike) - : fileLikeContext_{std::unique_ptr( - new py::object(fileLike))} { + : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? 
At the moment, this is only called from within a pybind @@ -43,18 +44,18 @@ class AVIOFileLikeContext : public AVIOContextHolder { py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, &seek, &fileLikeContext_); + createAVIOContext(&read, &seek, &fileLike_); } static int read(void* opaque, uint8_t* buf, int buf_size) { - auto fileLikeContext = static_cast(opaque); + auto fileLike = static_cast(opaque); int num_read = 0; while (num_read < buf_size) { int request = buf_size - num_read; py::gil_scoped_acquire gil; - auto chunk = static_cast(static_cast( - fileLikeContext->fileLike->attr("read")(request))); + auto chunk = static_cast( + static_cast((*fileLike)->attr("read")(request))); int chunk_len = static_cast(chunk.length()); if (chunk_len == 0) { break; @@ -78,26 +79,21 @@ class AVIOFileLikeContext : public AVIOContextHolder { if (whence == AVSEEK_SIZE) { return AVERROR(EIO); } - auto fileLikeContext = static_cast(opaque); + auto fileLike = static_cast(opaque); py::gil_scoped_acquire gil; - return py::cast( - fileLikeContext->fileLike->attr("seek")(offset, whence)); + return py::cast((*fileLike)->attr("seek")(offset, whence)); } private: - struct FileLikeContext { - // Note that we keep a pointer to the Python object because we need to - // strictly control when its destructor is called. We must hold the GIL - // when its destructor gets called, as it needs to update the reference - // count. It's easiest to control that when it's a pointer. Otherwise, we'd - // have to ensure whatever enclosing scope holds the object has the GIL, - // and that's, at least, hard. 
For all of the common pitfalls, see: - // - // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors - std::unique_ptr fileLike; - }; - - FileLikeContext fileLikeContext_; + // Note that we keep a pointer to the Python object because we need to + // strictly control when its destructor is called. We must hold the GIL + // when its destructor gets called, as it needs to update the reference + // count. It's easiest to control that when it's a pointer. Otherwise, we'd + // have to ensure whatever enclosing scope holds the object has the GIL, + // and that's, at least, hard. For all of the common pitfalls, see: + // + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + UniquePyObject fileLike_; }; } // namespace From d301f5317dc721fe547cbea5947850d75daa9e34 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 07:48:22 -0700 Subject: [PATCH 11/54] Better comment for AVIOContextHolder. --- src/torchcodec/decoders/_core/FFMPEGCommon.h | 25 ++++++++++++++++++-- src/torchcodec/decoders/_core/PyBindOps.cpp | 8 ++++--- test/decoders/VideoDecoderTest.cpp | 5 +--- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index 8cbc7c479..d1e0a04a1 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -148,7 +148,28 @@ bool canSwsScaleHandleUnalignedData(); using AVIOReadFunction = int (*)(void*, uint8_t*, int); using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); -// TODO: explain purpose of context holder +// The AVIOContextHolder serves several purposes: +// +// 1. It is a smart pointer for the AVIOContext. It has the logic to create +// a new AVIOContext and will appropriately free the AVIOContext when it +// goes out of scope. 
Note that this requires more than just the having a +// UniqueAVIOContext, as the AVIOContext points to a buffer which must be +// freed. +// 2. It is a base class for AVIOContext specializations. When specializing a +// AVIOContext, we need to provide four things: +// 1. A read callback function. +// 2. A seek callback function. +// 3. A write callback function. (Not supported yet; it's for encoding.) +// 4. A pointer to some context object that has the same lifetime as the +// AVIOContext itself. This context object holds the custom state that +// tracks the custom behavior of reading, seeking and writing. It is +// provided upon AVIOContext creation and to the read, seek and +// write callback functions. +// While it's not required, it is natural for the derived classes to make +// all of the above members. Base classes need to call +// createAVIOContext(), ideally in there constructor. +// 3. A generic handle for those that just need to manage having access to an +// AVIOContext, but aren't necessarily concerned with how it was customized. class AVIOContextHolder { public: virtual ~AVIOContextHolder(); @@ -165,7 +186,7 @@ class AVIOContextHolder { UniqueAVIOContext avioContext_; // Defaults to 64 KB - static const int defaultBufferSize = 64 * 1014; + static const int defaultBufferSize = 64 * 1024; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index d67138cea..d22004dc5 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -53,6 +53,8 @@ class AVIOFileLikeContext : public AVIOContextHolder { int num_read = 0; while (num_read < buf_size) { int request = buf_size - num_read; + // TODO: It is maybe more efficient to grab the lock once in the + // surrounding scope? 
py::gil_scoped_acquire gil; auto chunk = static_cast( static_cast((*fileLike)->attr("read")(request))); @@ -85,11 +87,11 @@ class AVIOFileLikeContext : public AVIOContextHolder { } private: - // Note that we keep a pointer to the Python object because we need to + // Note that we dynamically allocate the Python object because we need to // strictly control when its destructor is called. We must hold the GIL // when its destructor gets called, as it needs to update the reference - // count. It's easiest to control that when it's a pointer. Otherwise, we'd - // have to ensure whatever enclosing scope holds the object has the GIL, + // count. It's easiest to control that when it's dynamic memory. Otherwise, + // we'd have to ensure whatever enclosing scope holds the object has the GIL, // and that's, at least, hard. For all of the common pitfalls, see: // // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index ca9f9c64b..9380276ea 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -51,10 +51,7 @@ class VideoDecoderTest : public testing::TestWithParam { void* buffer = content_.data(); size_t length = content_.length(); - constexpr int bufferSize = 64 * 1024; - auto contextHolder = - std::make_unique(buffer, length, bufferSize); - + auto contextHolder = std::make_unique(buffer, length); return std::make_unique( std::move(contextHolder), VideoDecoder::SeekMode::approximate); } else { From a76d6a00288318d75122c31b58f477bc7f1cb8d9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 09:50:52 -0700 Subject: [PATCH 12/54] Break out AVIOContext stuff into their own header and source files --- .../decoders/_core/AVIOBytesContext.cpp | 68 ++++++++++++++ .../decoders/_core/AVIOBytesContext.h | 32 +++++++ .../decoders/_core/AVIOContextHolder.cpp | 50 +++++++++++ 
.../decoders/_core/AVIOContextHolder.h | 64 ++++++++++++++ .../decoders/_core/AVIOFileLikeContext.cpp | 68 ++++++++++++++ .../decoders/_core/AVIOFileLikeContext.h | 49 +++++++++++ src/torchcodec/decoders/_core/CMakeLists.txt | 26 +++++- .../decoders/_core/FFMPEGCommon.cpp | 38 -------- src/torchcodec/decoders/_core/FFMPEGCommon.h | 44 ---------- src/torchcodec/decoders/_core/PyBindOps.cpp | 88 +------------------ src/torchcodec/decoders/_core/VideoDecoder.h | 6 +- .../decoders/_core/VideoDecoderOps.cpp | 76 +--------------- test/decoders/CMakeLists.txt | 1 + test/decoders/VideoDecoderTest.cpp | 1 + 14 files changed, 363 insertions(+), 248 deletions(-) create mode 100644 src/torchcodec/decoders/_core/AVIOBytesContext.cpp create mode 100644 src/torchcodec/decoders/_core/AVIOBytesContext.h create mode 100644 src/torchcodec/decoders/_core/AVIOContextHolder.cpp create mode 100644 src/torchcodec/decoders/_core/AVIOContextHolder.h create mode 100644 src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp create mode 100644 src/torchcodec/decoders/_core/AVIOFileLikeContext.h diff --git a/src/torchcodec/decoders/_core/AVIOBytesContext.cpp b/src/torchcodec/decoders/_core/AVIOBytesContext.cpp new file mode 100644 index 000000000..ce4c32872 --- /dev/null +++ b/src/torchcodec/decoders/_core/AVIOBytesContext.cpp @@ -0,0 +1,68 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include "src/torchcodec/decoders/_core/AVIOBytesContext.h" + +namespace facebook::torchcodec { + +AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize) + : dataContext_{static_cast(data), dataSize, 0} { + TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); + TORCH_CHECK(dataSize > 0, "Video data size must be positive"); + createAVIOContext(&read, &seek, &dataContext_); +} + +// The signature of this function is defined by FFMPEG. +int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) { + auto dataContext = static_cast(opaque); + TORCH_CHECK( + dataContext->current <= dataContext->size, + "Tried to read outside of the buffer: current=", + dataContext->current, + ", size=", + dataContext->size); + + buf_size = FFMIN( + buf_size, static_cast(dataContext->size - dataContext->current)); + TORCH_CHECK( + buf_size >= 0, + "Tried to read negative bytes: buf_size=", + buf_size, + ", size=", + dataContext->size, + ", current=", + dataContext->current); + + if (!buf_size) { + return AVERROR_EOF; + } + memcpy(buf, dataContext->data + dataContext->current, buf_size); + dataContext->current += buf_size; + return buf_size; +} + +// The signature of this function is defined by FFMPEG. +int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) { + auto dataContext = static_cast(opaque); + int64_t ret = -1; + + switch (whence) { + case AVSEEK_SIZE: + ret = dataContext->size; + break; + case SEEK_SET: + dataContext->current = offset; + ret = offset; + break; + default: + break; + } + + return ret; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/AVIOBytesContext.h b/src/torchcodec/decoders/_core/AVIOBytesContext.h new file mode 100644 index 000000000..dd4d68555 --- /dev/null +++ b/src/torchcodec/decoders/_core/AVIOBytesContext.h @@ -0,0 +1,32 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include "src/torchcodec/decoders/_core/AVIOContextHolder.h" + +namespace facebook::torchcodec { + +// TODO: make comment below better +// memory buffer that is passed in. +class AVIOBytesContext : public AVIOContextHolder { + public: + explicit AVIOBytesContext(const void* data, int64_t dataSize); + + private: + struct DataContext { + const uint8_t* data; + int64_t size; + int64_t current; + }; + + static int read(void* opaque, uint8_t* buf, int buf_size); + static int64_t seek(void* opaque, int64_t offset, int whence); + + DataContext dataContext_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/AVIOContextHolder.cpp b/src/torchcodec/decoders/_core/AVIOContextHolder.cpp new file mode 100644 index 000000000..863d41e28 --- /dev/null +++ b/src/torchcodec/decoders/_core/AVIOContextHolder.cpp @@ -0,0 +1,50 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include "src/torchcodec/decoders/_core/AVIOContextHolder.h" + +namespace facebook::torchcodec { + +void AVIOContextHolder::createAVIOContext( + AVIOReadFunction read, + AVIOSeekFunction seek, + void* heldData, + int bufferSize) { + TORCH_CHECK( + bufferSize > 0, + "Buffer size must be greater than 0; is " + std::to_string(bufferSize)); + auto buffer = static_cast(av_malloc(bufferSize)); + TORCH_CHECK( + buffer != nullptr, + "Failed to allocate buffer of size " + std::to_string(bufferSize)); + + avioContext_.reset(avio_alloc_context( + buffer, + bufferSize, + 0, + heldData, + read, + nullptr, // write function; not supported yet + seek)); + + if (!avioContext_) { + av_freep(&buffer); + TORCH_CHECK(false, "Failed to allocate AVIOContext"); + } +} + +AVIOContextHolder::~AVIOContextHolder() { + if (avioContext_) { + av_freep(&avioContext_->buffer); + } +} + +AVIOContext* AVIOContextHolder::getAVIOContext() { + return avioContext_.get(); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/AVIOContextHolder.h b/src/torchcodec/decoders/_core/AVIOContextHolder.h new file mode 100644 index 000000000..c7a8d8ac6 --- /dev/null +++ b/src/torchcodec/decoders/_core/AVIOContextHolder.h @@ -0,0 +1,64 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include "src/torchcodec/decoders/_core/FFMPEGCommon.h" + +namespace facebook::torchcodec { + +// These signatures are defined by FFmpeg. +using AVIOReadFunction = int (*)(void*, uint8_t*, int); +using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); + +// The AVIOContextHolder serves several purposes: +// +// 1. It is a smart pointer for the AVIOContext. It has the logic to create +// a new AVIOContext and will appropriately free the AVIOContext when it +// goes out of scope. 
Note that this requires more than just the having a +// UniqueAVIOContext, as the AVIOContext points to a buffer which must be +// freed. +// 2. It is a base class for AVIOContext specializations. When specializing a +// AVIOContext, we need to provide four things: +// 1. A read callback function. +// 2. A seek callback function. +// 3. A write callback function. (Not supported yet; it's for encoding.) +// 4. A pointer to some context object that has the same lifetime as the +// AVIOContext itself. This context object holds the custom state that +// tracks the custom behavior of reading, seeking and writing. It is +// provided upon AVIOContext creation and to the read, seek and +// write callback functions. +// While it's not required, it is natural for the derived classes to make +// all of the above members. Base classes need to call +// createAVIOContext(), ideally in there constructor. +// 3. A generic handle for those that just need to manage having access to an +// AVIOContext, but aren't necessarily concerned with how it was customized. +class AVIOContextHolder { + public: + virtual ~AVIOContextHolder(); + AVIOContext* getAVIOContext(); + + protected: + // Make constructor protected to prevent anyone from constructing + // an AVIOContextHolder without deriving it. (Ordinarily this would be + // enforced by having a pure virtual methods, but we don't have any.) + AVIOContextHolder() = default; + + // Deriving classes should call this function in their constructor. 
+ void createAVIOContext( + AVIOReadFunction read, + AVIOSeekFunction seek, + void* heldData, + int bufferSize = defaultBufferSize); + + private: + UniqueAVIOContext avioContext_; + + // Defaults to 64 KB + static const int defaultBufferSize = 64 * 1024; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp new file mode 100644 index 000000000..9343695a4 --- /dev/null +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp @@ -0,0 +1,68 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "src/torchcodec/decoders/_core/AVIOFileLikeContext.h" + +namespace facebook::torchcodec { + +AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) + : fileLike_{UniquePyObject(new py::object(fileLike))} { + { + // TODO: Is it necessary to acquire the GIL here? Is it maybe even + // harmful? At the moment, this is only called from within a pybind + // function, and pybind guarantees we have the GIL. + py::gil_scoped_acquire gil; + TORCH_CHECK( + py::hasattr(fileLike, "read"), + "File like object must implement a read method."); + TORCH_CHECK( + py::hasattr(fileLike, "seek"), + "File like object must implement a seek method."); + } + createAVIOContext(&read, &seek, &fileLike_); +} + +int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { + auto fileLike = static_cast(opaque); + + int num_read = 0; + while (num_read < buf_size) { + int request = buf_size - num_read; + // TODO: It is maybe more efficient to grab the lock once in the + // surrounding scope? 
+ py::gil_scoped_acquire gil; + auto chunk = static_cast( + static_cast((*fileLike)->attr("read")(request))); + int chunk_len = static_cast(chunk.length()); + if (chunk_len == 0) { + break; + } + TORCH_CHECK( + chunk_len <= request, + "Requested up to ", + request, + " bytes but, received ", + chunk_len, + " bytes. The given object does not confirm to read protocol of file object."); + memcpy(buf, chunk.data(), chunk_len); + buf += chunk_len; + num_read += chunk_len; + } + return num_read == 0 ? AVERROR_EOF : num_read; +} + +int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { + // We do not know the file size. + if (whence == AVSEEK_SIZE) { + return AVERROR(EIO); + } + auto fileLike = static_cast(opaque); + py::gil_scoped_acquire gil; + return py::cast((*fileLike)->attr("seek")(offset, whence)); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h new file mode 100644 index 000000000..de8b5f615 --- /dev/null +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h @@ -0,0 +1,49 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +#include "src/torchcodec/decoders/_core/AVIOContextHolder.h" + +namespace py = pybind11; + +namespace facebook::torchcodec { + +// Necessary to make sure that we hold the GIL when we delete a py::object. 
+struct PyObjectDeleter { + inline void operator()(py::object* obj) const { + if (obj) { + py::gil_scoped_acquire gil; + delete obj; + } + } +}; + +using UniquePyObject = std::unique_ptr; + +class AVIOFileLikeContext : public AVIOContextHolder { + public: + explicit AVIOFileLikeContext(py::object fileLike); + + private: + static int read(void* opaque, uint8_t* buf, int buf_size); + static int64_t seek(void* opaque, int64_t offset, int whence); + + // Note that we dynamically allocate the Python object because we need to + // strictly control when its destructor is called. We must hold the GIL + // when its destructor gets called, as it needs to update the reference + // count. It's easiest to control that when it's dynamic memory. Otherwise, + // we'd have to ensure whatever enclosing scope holds the object has the GIL, + // and that's, at least, hard. For all of the common pitfalls, see: + // + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + UniquePyObject fileLike_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 2fcd5f55a..69a913e94 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -40,7 +40,11 @@ function(make_torchcodec_libraries # Create libtorchcodec_decoderN.so set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") - set(decoder_sources FFMPEGCommon.cpp VideoDecoder.cpp) + set(decoder_sources + AVIOContextHolder.cpp + FFMPEGCommon.cpp + VideoDecoder.cpp + ) if(ENABLE_CUDA) list(APPEND decoder_sources CudaDevice.cpp) @@ -70,7 +74,10 @@ function(make_torchcodec_libraries # Create libtorchcodec_custom_opsN.so set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") - set(custom_ops_sources VideoDecoderOps.cpp) + set(custom_ops_sources + AVIOBytesContext.cpp + VideoDecoderOps.cpp + ) 
make_torchcodec_sublibrary( "${custom_ops_library_name}" "${custom_ops_sources}" @@ -80,7 +87,10 @@ function(make_torchcodec_libraries # Create libtorchcodec_pybind_opsN.so set(pybind_ops_library_name "libtorchcodec_pybind_ops${ffmpeg_major_version}") - set(pybind_ops_sources PyBindOps.cpp) + set(pybind_ops_sources + AVIOFileLikeContext.cpp + PyBindOps.cpp + ) make_torchcodec_sublibrary( "${pybind_ops_library_name}" "${pybind_ops_sources}" @@ -92,6 +102,13 @@ function(make_torchcodec_libraries PUBLIC TORCHCODEC_PYBIND=_torchcodec_pybind_ops${ffmpeg_major_version} ) + # pybind11 quirk, see: + # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes + target_compile_options( + ${pybind_ops_library_name} + PUBLIC + "-fvisibility=hidden" + ) # Install all libraries. set( @@ -172,7 +189,8 @@ else() # Expose these values updwards so that the test compilation does not need # to re-figure it out. FIXME: it's not great that we just copy-paste the - # library name. + # library names. 
set(libtorchcodec_library_name "libtorchcodec_decoder${ffmpeg_major_version}" PARENT_SCOPE) + set(libtorchcodec_custom_ops_name "libtorchcodec_custom_ops${ffmpeg_major_version}" PARENT_SCOPE) set(libav_include_dirs ${LIBAV_INCLUDE_DIRS} PARENT_SCOPE) endif() diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp index f0e2a5c6b..9b5729651 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.cpp @@ -78,42 +78,4 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext) { #endif } -void AVIOContextHolder::createAVIOContext( - AVIOReadFunction read, - AVIOSeekFunction seek, - void* heldData, - int bufferSize) { - TORCH_CHECK( - bufferSize > 0, - "Buffer size must be greater than 0; is " + std::to_string(bufferSize)); - auto buffer = static_cast(av_malloc(bufferSize)); - TORCH_CHECK( - buffer != nullptr, - "Failed to allocate buffer of size " + std::to_string(bufferSize)); - - avioContext_.reset(avio_alloc_context( - buffer, - bufferSize, - 0, - heldData, - read, - nullptr, // write function; not supported yet - seek)); - - if (!avioContext_) { - av_freep(&buffer); - TORCH_CHECK(false, "Failed to allocate AVIOContext"); - } -} - -AVIOContextHolder::~AVIOContextHolder() { - if (avioContext_) { - av_freep(&avioContext_->buffer); - } -} - -AVIOContext* AVIOContextHolder::getAVIOContext() { - return avioContext_.get(); -} - } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index d1e0a04a1..665fc480b 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -145,48 +145,4 @@ int getNumChannels(const UniqueAVCodecContext& avCodecContext); // Returns true if sws_scale can handle unaligned data. 
bool canSwsScaleHandleUnalignedData(); -using AVIOReadFunction = int (*)(void*, uint8_t*, int); -using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); - -// The AVIOContextHolder serves several purposes: -// -// 1. It is a smart pointer for the AVIOContext. It has the logic to create -// a new AVIOContext and will appropriately free the AVIOContext when it -// goes out of scope. Note that this requires more than just the having a -// UniqueAVIOContext, as the AVIOContext points to a buffer which must be -// freed. -// 2. It is a base class for AVIOContext specializations. When specializing a -// AVIOContext, we need to provide four things: -// 1. A read callback function. -// 2. A seek callback function. -// 3. A write callback function. (Not supported yet; it's for encoding.) -// 4. A pointer to some context object that has the same lifetime as the -// AVIOContext itself. This context object holds the custom state that -// tracks the custom behavior of reading, seeking and writing. It is -// provided upon AVIOContext creation and to the read, seek and -// write callback functions. -// While it's not required, it is natural for the derived classes to make -// all of the above members. Base classes need to call -// createAVIOContext(), ideally in there constructor. -// 3. A generic handle for those that just need to manage having access to an -// AVIOContext, but aren't necessarily concerned with how it was customized. 
-class AVIOContextHolder { - public: - virtual ~AVIOContextHolder(); - AVIOContext* getAVIOContext(); - - protected: - void createAVIOContext( - AVIOReadFunction read, - AVIOSeekFunction seek, - void* heldData, - int bufferSize = defaultBufferSize); - - private: - UniqueAVIOContext avioContext_; - - // Defaults to 64 KB - static const int defaultBufferSize = 64 * 1024; -}; - } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index d22004dc5..0c19b253d 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -4,102 +4,18 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. +#include #include #include #include +#include "src/torchcodec/decoders/_core/AVIOFileLikeContext.h" #include "src/torchcodec/decoders/_core/VideoDecoder.h" namespace py = pybind11; namespace facebook::torchcodec { -namespace { - -// Necessary to make sure that we hold the GIL when we delete a py::object. -struct PyObjectDeleter { - inline void operator()(py::object* obj) const { - if (obj) { - py::gil_scoped_acquire gil; - delete obj; - } - } -}; - -using UniquePyObject = std::unique_ptr; - -class AVIOFileLikeContext : public AVIOContextHolder { - public: - explicit AVIOFileLikeContext(py::object fileLike) - : fileLike_{UniquePyObject(new py::object(fileLike))} { - { - // TODO: Is it necessary to acquire the GIL here? Is it maybe even - // harmful? At the moment, this is only called from within a pybind - // function, and pybind guarantees we have the GIL. 
- py::gil_scoped_acquire gil; - TORCH_CHECK( - py::hasattr(fileLike, "read"), - "File like object must implement a read method."); - TORCH_CHECK( - py::hasattr(fileLike, "seek"), - "File like object must implement a seek method."); - } - createAVIOContext(&read, &seek, &fileLike_); - } - - static int read(void* opaque, uint8_t* buf, int buf_size) { - auto fileLike = static_cast(opaque); - - int num_read = 0; - while (num_read < buf_size) { - int request = buf_size - num_read; - // TODO: It is maybe more efficient to grab the lock once in the - // surrounding scope? - py::gil_scoped_acquire gil; - auto chunk = static_cast( - static_cast((*fileLike)->attr("read")(request))); - int chunk_len = static_cast(chunk.length()); - if (chunk_len == 0) { - break; - } - TORCH_CHECK( - chunk_len <= request, - "Requested up to ", - request, - " bytes but, received ", - chunk_len, - " bytes. The given object does not confirm to read protocol of file object."); - memcpy(buf, chunk.data(), chunk_len); - buf += chunk_len; - num_read += chunk_len; - } - return num_read == 0 ? AVERROR_EOF : num_read; - } - - static int64_t seek(void* opaque, int64_t offset, int whence) { - // We do not know the file size. - if (whence == AVSEEK_SIZE) { - return AVERROR(EIO); - } - auto fileLike = static_cast(opaque); - py::gil_scoped_acquire gil; - return py::cast((*fileLike)->attr("seek")(offset, whence)); - } - - private: - // Note that we dynamically allocate the Python object because we need to - // strictly control when its destructor is called. We must hold the GIL - // when its destructor gets called, as it needs to update the reference - // count. It's easiest to control that when it's dynamic memory. Otherwise, - // we'd have to ensure whatever enclosing scope holds the object has the GIL, - // and that's, at least, hard. 
For all of the common pitfalls, see: - // - // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors - UniquePyObject fileLike_; -}; - -} // namespace - // In principle, this should be able to return a tensor. But when we try that, // we run into the bug reported here: // diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 667686e17..8fb06dcc9 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -12,6 +12,7 @@ #include #include +#include "src/torchcodec/decoders/_core/AVIOContextHolder.h" #include "src/torchcodec/decoders/_core/FFMPEGCommon.h" namespace facebook::torchcodec { @@ -34,7 +35,10 @@ class VideoDecoder { const std::string& videoFilePath, SeekMode seekMode = SeekMode::exact); - // TODO: make comment accurate + // Creates a VideoDecoder using the provided AVIOContext inside the + // AVIOContextHolder. The AVIOContextHolder is the base class, and the + // derived class will have specialized how the custom read, seek and writes + // work. explicit VideoDecoder( std::unique_ptr context, SeekMode seekMode = SeekMode::exact); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 0014a6791..4ce31c23c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -11,6 +11,7 @@ #include #include "c10/core/SymIntArrayRef.h" #include "c10/util/Exception.h" +#include "src/torchcodec/decoders/_core/AVIOBytesContext.h" #include "src/torchcodec/decoders/_core/VideoDecoder.h" namespace facebook::torchcodec { @@ -65,81 +66,6 @@ TORCH_LIBRARY(torchcodec_ns, m) { namespace { -// TODO: make comment below better -// A struct that holds state for reading bytes from an IO context. 
-// We give this to FFMPEG and it will pass it back to us when it needs to read -// or seek in the memory buffer. -// -// A class that can be used as AVFormatContext's IO context. It reads from a -// memory buffer that is passed in. -class AVIOBytesContext : public AVIOContextHolder { - public: - explicit AVIOBytesContext(const void* data, int64_t dataSize) - : dataContext_{static_cast(data), dataSize, 0} { - TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); - TORCH_CHECK(dataSize > 0, "Video data size must be positive"); - createAVIOContext(&read, &seek, &dataContext_); - } - - // The signature of this function is defined by FFMPEG. - static int read(void* opaque, uint8_t* buf, int buf_size) { - auto dataContext = static_cast(opaque); - TORCH_CHECK( - dataContext->current <= dataContext->size, - "Tried to read outside of the buffer: current=", - dataContext->current, - ", size=", - dataContext->size); - - buf_size = FFMIN( - buf_size, static_cast(dataContext->size - dataContext->current)); - TORCH_CHECK( - buf_size >= 0, - "Tried to read negative bytes: buf_size=", - buf_size, - ", size=", - dataContext->size, - ", current=", - dataContext->current); - - if (!buf_size) { - return AVERROR_EOF; - } - memcpy(buf, dataContext->data + dataContext->current, buf_size); - dataContext->current += buf_size; - return buf_size; - } - - // The signature of this function is defined by FFMPEG. 
- static int64_t seek(void* opaque, int64_t offset, int whence) { - auto dataContext = static_cast(opaque); - int64_t ret = -1; - - switch (whence) { - case AVSEEK_SIZE: - ret = dataContext->size; - break; - case SEEK_SET: - dataContext->current = offset; - ret = offset; - break; - default: - break; - } - - return ret; - } - - private: - struct DataContext { - const uint8_t* data; - int64_t size; - int64_t current; - }; - - DataContext dataContext_; -}; - at::Tensor wrapDecoderPointerToTensor( std::unique_ptr uniqueDecoder) { VideoDecoder* decoder = uniqueDecoder.release(); diff --git a/test/decoders/CMakeLists.txt b/test/decoders/CMakeLists.txt index 1dd6ce153..126dd2794 100644 --- a/test/decoders/CMakeLists.txt +++ b/test/decoders/CMakeLists.txt @@ -28,6 +28,7 @@ target_include_directories(VideoDecoderTest PRIVATE ../../) target_link_libraries( VideoDecoderTest ${libtorchcodec_library_name} + ${libtorchcodec_custom_ops_name} GTest::gtest_main ) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 9380276ea..b2a090d11 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. 
#include "src/torchcodec/decoders/_core/VideoDecoder.h" +#include "src/torchcodec/decoders/_core/AVIOBytesContext.h" #include #include From fa2445ea16b800e42636bbf71128cdcafd629a0d Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 10:13:59 -0700 Subject: [PATCH 13/54] Lint --- src/torchcodec/decoders/_core/AVIOBytesContext.cpp | 2 +- src/torchcodec/decoders/_core/AVIOContextHolder.cpp | 2 +- src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/AVIOBytesContext.cpp b/src/torchcodec/decoders/_core/AVIOBytesContext.cpp index ce4c32872..c1851b9ca 100644 --- a/src/torchcodec/decoders/_core/AVIOBytesContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOBytesContext.cpp @@ -4,8 +4,8 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include #include "src/torchcodec/decoders/_core/AVIOBytesContext.h" +#include namespace facebook::torchcodec { diff --git a/src/torchcodec/decoders/_core/AVIOContextHolder.cpp b/src/torchcodec/decoders/_core/AVIOContextHolder.cpp index 863d41e28..1fc4f5ecf 100644 --- a/src/torchcodec/decoders/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/decoders/_core/AVIOContextHolder.cpp @@ -4,8 +4,8 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include #include "src/torchcodec/decoders/_core/AVIOContextHolder.h" +#include namespace facebook::torchcodec { diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp index 9343695a4..1d63addc5 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp @@ -4,13 +4,13 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#include #include "src/torchcodec/decoders/_core/AVIOFileLikeContext.h" +#include namespace facebook::torchcodec { AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) - : fileLike_{UniquePyObject(new py::object(fileLike))} { + : fileLike_{UniquePyObject(new py::object(fileLike))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind From ffdbbfb377a8f98eb26a456a4ac565e0f5673385 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 10:28:53 -0700 Subject: [PATCH 14/54] Explicit assert on spec object --- src/torchcodec/decoders/_core/ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 48565bafe..b0c9a0389 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -39,10 +39,13 @@ def load_torchcodec_extension(): try: torch.ops.load_library(_get_extension_path(decoder_library_name)) torch.ops.load_library(_get_extension_path(custom_ops_library_name)) + spec = importlib.util.spec_from_file_location( pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) + assert spec is not None + global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) return From c7d9df39e6d01a6c7a45a8a5126c3477ee85f8a2 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 11:48:23 -0700 
Subject: [PATCH 15/54] Manual exception raising --- src/torchcodec/decoders/_core/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index b0c9a0389..062049398 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -44,7 +44,8 @@ def load_torchcodec_extension(): pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) - assert spec is not None + if spec is None: + raise RuntimeError("spec is None") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 5134aff3491ff1ba42a60f1708aae77b52a59ef8 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 11:50:02 -0700 Subject: [PATCH 16/54] Undo in order to merge --- src/torchcodec/decoders/_core/ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 062049398..d3a3600b5 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -44,8 +44,6 @@ def load_torchcodec_extension(): pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) - if spec is None: - raise RuntimeError("spec is None") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 799307051da6f69d623b492ab53fa23ec3b03fbc Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 12:02:33 -0700 Subject: [PATCH 17/54] Raise ImportError on spec failure --- src/torchcodec/decoders/_core/ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index d3a3600b5..9f47aec0b 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -44,6 +44,8 @@ def load_torchcodec_extension(): pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) + if spec is None: + raise ImportError("Unable to load 
spec for pybind_ops") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From f4ece883a42a7d6d4bb643f33b5c28089d4086d5 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 12:31:38 -0700 Subject: [PATCH 18/54] Print path --- src/torchcodec/decoders/_core/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 9f47aec0b..5033561ba 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -45,7 +45,7 @@ def load_torchcodec_extension(): _get_extension_path(pybind_ops_library_name), ) if spec is None: - raise ImportError("Unable to load spec for pybind_ops") + raise ImportError(f"Unable to load spec for pybind_ops {_get_extension_path(pybind_ops_library_name}") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 2b4f21336202baf9692ee14b037584c6d59d22df Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 12:38:14 -0700 Subject: [PATCH 19/54] Close paren --- src/torchcodec/decoders/_core/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 5033561ba..69867ce13 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -45,7 +45,7 @@ def load_torchcodec_extension(): _get_extension_path(pybind_ops_library_name), ) if spec is None: - raise ImportError(f"Unable to load spec for pybind_ops {_get_extension_path(pybind_ops_library_name}") + raise ImportError(f"Unable to load spec for pybind_ops {_get_extension_path(pybind_ops_library_name)}") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 01884b34bcadc248f48d6515563fed6041a27ce3 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 14:28:50 -0700 Subject: [PATCH 20/54] Load and importlib --- src/torchcodec/decoders/_core/CMakeLists.txt 
| 5 ----- src/torchcodec/decoders/_core/PyBindOps.cpp | 6 +----- src/torchcodec/decoders/_core/ops.py | 5 +++-- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 1ecfc651e..5e0ac05a1 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -97,11 +97,6 @@ function(make_torchcodec_libraries "${decoder_library_name}" "${ffmpeg_include_dirs}" ) - target_compile_definitions( - ${pybind_ops_library_name} - PUBLIC - TORCHCODEC_PYBIND=_torchcodec_pybind_ops${ffmpeg_major_version} - ) # pybind11 quirk, see: # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes target_compile_options( diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index 0c19b253d..a8f3d4cba 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -37,11 +37,7 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } -#ifndef TORCHCODEC_PYBIND -#error TORCHCODEC_PYBIND must be defined. 
-#endif - -PYBIND11_MODULE(TORCHCODEC_PYBIND, m) { +PYBIND11_MODULE(_torchcodec_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); } diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 69867ce13..9d47367c2 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -35,17 +35,18 @@ def load_torchcodec_extension(): decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" - pybind_ops_module_name = f"_torchcodec_pybind_ops{ffmpeg_major_version}" + pybind_ops_module_name = "torchcodec._torchcodec_pybind_ops" try: torch.ops.load_library(_get_extension_path(decoder_library_name)) torch.ops.load_library(_get_extension_path(custom_ops_library_name)) + torch.ops.load_library(_get_extension_path(pybind_ops_library_name)) spec = importlib.util.spec_from_file_location( pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) if spec is None: - raise ImportError(f"Unable to load spec for pybind_ops {_get_extension_path(pybind_ops_library_name)}") + raise ImportError(f"Unable to load spec for pybind_ops") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 45342a77c456678be47b89bcc7705bce8f521ec5 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 14:29:28 -0700 Subject: [PATCH 21/54] Lint --- src/torchcodec/decoders/_core/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 9d47367c2..59003ad3b 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -46,7 +46,7 @@ def load_torchcodec_extension(): _get_extension_path(pybind_ops_library_name), ) if spec is None: - raise ImportError(f"Unable to load spec for pybind_ops") + 
raise ImportError("Unable to load spec for pybind_ops") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 3608b50d9554a0ccea31ebe78cee37d5df62248b Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 17:43:29 -0700 Subject: [PATCH 22/54] Add FFmpeg version in exception traceback message --- src/torchcodec/decoders/_core/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 59003ad3b..febd2e1b0 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -54,11 +54,11 @@ def load_torchcodec_extension(): except Exception as e: # TODO: recording and reporting exceptions this way is OK for now as it's just for debugging, # but we should probably handle that via a proper logging mechanism. - exceptions.append(e) + exceptions.append(ffmpeg_major_version, e) traceback = ( "\n[start of libtorchcodec loading traceback]\n" - + "\n".join(str(e) for e in exceptions) + + "\n".join(f"FFmpeg version {v}: {str(e)}" for v, e in exceptions) + "\n[end of libtorchcodec loading traceback]." ) raise RuntimeError( From f36d050fcc8938f6cfa47bd0007e264cfa295bcc Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 17:55:49 -0700 Subject: [PATCH 23/54] Make exception args tuple; refactor visiblity of context stuff --- .../decoders/_core/AVIOContextHolder.h | 10 ++++----- .../decoders/_core/AVIOFileLikeContext.h | 22 +++++++++---------- src/torchcodec/decoders/_core/ops.py | 2 +- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/torchcodec/decoders/_core/AVIOContextHolder.h b/src/torchcodec/decoders/_core/AVIOContextHolder.h index c7a8d8ac6..e66cf0ca1 100644 --- a/src/torchcodec/decoders/_core/AVIOContextHolder.h +++ b/src/torchcodec/decoders/_core/AVIOContextHolder.h @@ -10,10 +10,6 @@ namespace facebook::torchcodec { -// These signatures are defined by FFmpeg. 
-using AVIOReadFunction = int (*)(void*, uint8_t*, int); -using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); - // The AVIOContextHolder serves several purposes: // // 1. It is a smart pointer for the AVIOContext. It has the logic to create @@ -33,7 +29,7 @@ using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); // write callback functions. // While it's not required, it is natural for the derived classes to make // all of the above members. Base classes need to call -// createAVIOContext(), ideally in there constructor. +// createAVIOContext(), ideally in their constructor. // 3. A generic handle for those that just need to manage having access to an // AVIOContext, but aren't necessarily concerned with how it was customized. class AVIOContextHolder { @@ -47,6 +43,10 @@ class AVIOContextHolder { // enforced by having a pure virtual methods, but we don't have any.) AVIOContextHolder() = default; + // These signatures are defined by FFmpeg. + using AVIOReadFunction = int (*)(void*, uint8_t*, int); + using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); + // Deriving classes should call this function in their constructor. void createAVIOContext( AVIOReadFunction read, diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h index de8b5f615..45e45425b 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h @@ -15,18 +15,6 @@ namespace py = pybind11; namespace facebook::torchcodec { -// Necessary to make sure that we hold the GIL when we delete a py::object. 
-struct PyObjectDeleter { - inline void operator()(py::object* obj) const { - if (obj) { - py::gil_scoped_acquire gil; - delete obj; - } - } -}; - -using UniquePyObject = std::unique_ptr; - class AVIOFileLikeContext : public AVIOContextHolder { public: explicit AVIOFileLikeContext(py::object fileLike); @@ -43,6 +31,16 @@ class AVIOFileLikeContext : public AVIOContextHolder { // and that's, at least, hard. For all of the common pitfalls, see: // // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + struct PyObjectDeleter { + inline void operator()(py::object* obj) const { + if (obj) { + py::gil_scoped_acquire gil; + delete obj; + } + } + }; + + using UniquePyObject = std::unique_ptr; UniquePyObject fileLike_; }; diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index febd2e1b0..738150138 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -54,7 +54,7 @@ def load_torchcodec_extension(): except Exception as e: # TODO: recording and reporting exceptions this way is OK for now as it's just for debugging, # but we should probably handle that via a proper logging mechanism. 
- exceptions.append(ffmpeg_major_version, e) + exceptions.append((ffmpeg_major_version, e)) traceback = ( "\n[start of libtorchcodec loading traceback]\n" From 6819070d75e9917661cffe194f1ebb90302f642f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 18:20:13 -0700 Subject: [PATCH 24/54] Try find_spec --- src/torchcodec/decoders/_core/ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 738150138..1b30c701b 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -46,7 +46,9 @@ def load_torchcodec_extension(): _get_extension_path(pybind_ops_library_name), ) if spec is None: - raise ImportError("Unable to load spec for pybind_ops") + spec = importlib.util.find_spec(pybind_ops_module_name) + if spec is None: + raise ImportError("Unable to load spec for pybind_ops") global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) From 89c86980d614550b4f566d1b44539f28c5a19162 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 18:54:04 -0700 Subject: [PATCH 25/54] Trying import_module as backup --- src/torchcodec/decoders/_core/ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 1b30c701b..391856d25 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -45,12 +45,12 @@ def load_torchcodec_extension(): pybind_ops_module_name, _get_extension_path(pybind_ops_library_name), ) + global _pybind_ops if spec is None: - spec = importlib.util.find_spec(pybind_ops_module_name) - if spec is None: + _pybind_ops = importlib.util.import_module(pybind_ops_module_name) + if _pybind_ops is None: raise ImportError("Unable to load spec for pybind_ops") - global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) return except Exception as e: 
From 59c129f772f0a0769f7b184941b23c444f69fa29 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 19:02:22 -0700 Subject: [PATCH 26/54] Using plain _trochcodec_pybind_ops --- src/torchcodec/decoders/_core/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 391856d25..542e809fe 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -35,7 +35,7 @@ def load_torchcodec_extension(): decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" - pybind_ops_module_name = "torchcodec._torchcodec_pybind_ops" + pybind_ops_module_name = "_torchcodec_pybind_ops" try: torch.ops.load_library(_get_extension_path(decoder_library_name)) torch.ops.load_library(_get_extension_path(custom_ops_library_name)) @@ -47,7 +47,7 @@ def load_torchcodec_extension(): ) global _pybind_ops if spec is None: - _pybind_ops = importlib.util.import_module(pybind_ops_module_name) + _pybind_ops = importlib.import_module(pybind_ops_module_name) if _pybind_ops is None: raise ImportError("Unable to load spec for pybind_ops") From e3d08e3b5e841e2da6aef347fddc068345c9e02f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Mar 2025 19:37:45 -0700 Subject: [PATCH 27/54] Better module loading error reporting --- src/torchcodec/decoders/_core/ops.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 542e809fe..c2bc47f53 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -31,26 +31,26 @@ def load_torchcodec_extension(): # correct .so file, so this for-loop succeeds on the first iteration. 
exceptions = [] + pybind_ops_module_name = "_torchcodec_pybind_ops" for ffmpeg_major_version in (7, 6, 5, 4): decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" pybind_ops_library_name = f"libtorchcodec_pybind_ops{ffmpeg_major_version}" - pybind_ops_module_name = "_torchcodec_pybind_ops" try: torch.ops.load_library(_get_extension_path(decoder_library_name)) torch.ops.load_library(_get_extension_path(custom_ops_library_name)) - torch.ops.load_library(_get_extension_path(pybind_ops_library_name)) + pybind_ops_library_path = _get_extension_path(pybind_ops_library_name) spec = importlib.util.spec_from_file_location( pybind_ops_module_name, - _get_extension_path(pybind_ops_library_name), + pybind_ops_library_path, ) - global _pybind_ops if spec is None: - _pybind_ops = importlib.import_module(pybind_ops_module_name) - if _pybind_ops is None: - raise ImportError("Unable to load spec for pybind_ops") + raise ImportError( + f"Unable to load spec for module {pybind_ops_module_name} from path {pybind_ops_library_path}" + ) + global _pybind_ops _pybind_ops = importlib.util.module_from_spec(spec) return except Exception as e: From e9a726fb5a1641bd2f63e71f88994ee99f094b3b Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Mar 2025 05:58:32 -0700 Subject: [PATCH 28/54] Do both load and dynamic import --- src/torchcodec/decoders/_core/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index c2bc47f53..bd261ebd6 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -41,6 +41,7 @@ def load_torchcodec_extension(): torch.ops.load_library(_get_extension_path(custom_ops_library_name)) pybind_ops_library_path = _get_extension_path(pybind_ops_library_name) + torch.ops.load_library(pybind_ops_library_path) spec = importlib.util.spec_from_file_location( 
pybind_ops_module_name, pybind_ops_library_path, From 591995fcdff3f85cee53a2e637670234da369bab Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Mar 2025 06:14:43 -0700 Subject: [PATCH 29/54] Support both RawIOBase and BytesIO --- src/torchcodec/decoders/_core/ops.py | 2 +- test/decoders/test_ops.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index bd261ebd6..187539723 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -137,7 +137,7 @@ def create_from_bytes( def create_from_file_like( - file_like: io.RawIOBase, seek_mode: Optional[str] = None + file_like: io.RawIOBase | io.BytesIO, seek_mode: Optional[str] = None ) -> torch.Tensor: assert _pybind_ops is not None return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 4da7c17e6..88cfa176b 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -343,7 +343,10 @@ def get_frame1_and_frame_time6(decoder): assert_frames_equal(frame_time6, reference_frame_time6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) - @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes", "file_like")) + @pytest.mark.parametrize( + "create_from", + ("file", "tensor", "bytes", "file_like_rawio", "file_like_bufferedio"), + ) def test_create_decoder(self, create_from, device): path = str(NASA_VIDEO.path) if create_from == "file": @@ -356,8 +359,12 @@ def test_create_decoder(self, create_from, device): with open(path, "rb") as f: video_bytes = f.read() decoder = create_from_bytes(video_bytes) - elif create_from == "file_like": + elif create_from == "file_like_rawio": decoder = create_from_file_like(open(path, mode="rb", buffering=0), "exact") + elif create_from == "file_like_bufferedio": + decoder = create_from_file_like( + open(path, 
mode="rb", buffering=-4096), "exact" + ) else: raise ValueError("Oops, double check the parametrization of this test!") From 0ff2e69e6330ddf35df7fd9218e13815fa22d58f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Mar 2025 06:41:12 -0700 Subject: [PATCH 30/54] Use Union instead of pipe --- src/torchcodec/decoders/_core/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 187539723..87b334247 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -9,7 +9,7 @@ import json import warnings from types import ModuleType -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch.library import get_ctx, register_fake @@ -137,7 +137,7 @@ def create_from_bytes( def create_from_file_like( - file_like: io.RawIOBase | io.BytesIO, seek_mode: Optional[str] = None + file_like: Union[io.RawIOBase, io.BytesIO], seek_mode: Optional[str] = None ) -> torch.Tensor: assert _pybind_ops is not None return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) From c1555c2f76d280d7afe5f48831ff6399e5373c1d Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Mar 2025 07:29:02 -0700 Subject: [PATCH 31/54] Comments --- src/torchcodec/decoders/_core/AVIOBytesContext.h | 4 ++-- src/torchcodec/decoders/_core/AVIOFileLikeContext.h | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/AVIOBytesContext.h b/src/torchcodec/decoders/_core/AVIOBytesContext.h index dd4d68555..411866dc0 100644 --- a/src/torchcodec/decoders/_core/AVIOBytesContext.h +++ b/src/torchcodec/decoders/_core/AVIOBytesContext.h @@ -10,8 +10,8 @@ namespace facebook::torchcodec { -// TODO: make comment below better -// memory buffer that is passed in. +// Enables users to pass in the entire video as bytes. 
Our read and seek +// functions then traverse the bytes in memory. class AVIOBytesContext : public AVIOContextHolder { public: explicit AVIOBytesContext(const void* data, int64_t dataSize); diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h index 45e45425b..4613b0a33 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h @@ -15,6 +15,8 @@ namespace py = pybind11; namespace facebook::torchcodec { +// Enables uers to pass in a Python file-like object. We then forward all read +// and seek calls back up to the methods on the Python object. class AVIOFileLikeContext : public AVIOContextHolder { public: explicit AVIOFileLikeContext(py::object fileLike); From a3f6b9e4c5aad89e0fa8f7f3c37e88a0466ad9bc Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Mar 2025 15:37:09 -0400 Subject: [PATCH 32/54] Update src/torchcodec/decoders/_core/AVIOContextHolder.h Co-authored-by: Nicolas Hug --- src/torchcodec/decoders/_core/AVIOContextHolder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/AVIOContextHolder.h b/src/torchcodec/decoders/_core/AVIOContextHolder.h index e66cf0ca1..f5e4236df 100644 --- a/src/torchcodec/decoders/_core/AVIOContextHolder.h +++ b/src/torchcodec/decoders/_core/AVIOContextHolder.h @@ -14,7 +14,7 @@ namespace facebook::torchcodec { // // 1. It is a smart pointer for the AVIOContext. It has the logic to create // a new AVIOContext and will appropriately free the AVIOContext when it -// goes out of scope. Note that this requires more than just the having a +// goes out of scope. Note that this requires more than just having a // UniqueAVIOContext, as the AVIOContext points to a buffer which must be // freed. // 2. It is a base class for AVIOContext specializations. 
When specializing a From 040321a734288b99db4fea2ae4b0b272ae83c204 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Mar 2025 15:48:55 -0400 Subject: [PATCH 33/54] Update src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp Co-authored-by: Nicolas Hug --- src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp index 1d63addc5..15448345d 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp @@ -47,7 +47,7 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { request, " bytes but, received ", chunk_len, - " bytes. The given object does not confirm to read protocol of file object."); + " bytes. The given object does not conform to read protocol of file object."); memcpy(buf, chunk.data(), chunk_len); buf += chunk_len; num_read += chunk_len; From edbb5e722a2d4792012c741360dfb83bb69f76fb Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 19 Mar 2025 16:13:29 -0400 Subject: [PATCH 34/54] Update src/torchcodec/decoders/_core/AVIOContextHolder.h Co-authored-by: Nicolas Hug --- src/torchcodec/decoders/_core/AVIOContextHolder.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/AVIOContextHolder.h b/src/torchcodec/decoders/_core/AVIOContextHolder.h index f5e4236df..26bb06f08 100644 --- a/src/torchcodec/decoders/_core/AVIOContextHolder.h +++ b/src/torchcodec/decoders/_core/AVIOContextHolder.h @@ -31,7 +31,8 @@ namespace facebook::torchcodec { // all of the above members. Base classes need to call // createAVIOContext(), ideally in their constructor. // 3. A generic handle for those that just need to manage having access to an -// AVIOContext, but aren't necessarily concerned with how it was customized. 
+// AVIOContext, but aren't necessarily concerned with how it was customized: +// typically, the VideoDecoder. class AVIOContextHolder { public: virtual ~AVIOContextHolder(); From 7e6667c5e1e1eb03dd0a5cfc8d370fd3075bef81 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 19 Mar 2025 14:36:32 -0700 Subject: [PATCH 35/54] Address comments --- setup.py | 2 +- .../decoders/_core/AVIOFileLikeContext.cpp | 6 ++-- .../decoders/_core/AVIOFileLikeContext.h | 5 ++++ src/torchcodec/decoders/_core/CMakeLists.txt | 30 ++++++++----------- src/torchcodec/decoders/_core/PyBindOps.cpp | 4 +-- .../decoders/_core/VideoDecoderOps.cpp | 1 + src/torchcodec/decoders/_core/ops.py | 4 +-- test/decoders/test_ops.py | 2 +- 8 files changed, 28 insertions(+), 26 deletions(-) diff --git a/setup.py b/setup.py index 9120c7fe0..91f99c22c 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ def run(self): super().run() def build_extension(self, ext): - """Call our CMake build system to build libtorchcodec?.so""" + """Call our CMake build system to build libtorchcodec*.so""" # Setuptools was designed to build one extension (.so file) at a time, # calling this method for each Extension object. We're using a # CMake-based build where all our extensions are built together at once. diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp index 15448345d..9ed5dedef 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp @@ -29,12 +29,12 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { auto fileLike = static_cast(opaque); + // Note that we acquire the GIL outside of the loop. This is likely more + // efficient than releasing and acquiring it each loop iteration. 
+ py::gil_scoped_acquire gil; int num_read = 0; while (num_read < buf_size) { int request = buf_size - num_read; - // TODO: It is maybe more efficient to grab the lock once in the - // surrounding scope? - py::gil_scoped_acquire gil; auto chunk = static_cast( static_cast((*fileLike)->attr("read")(request))); int chunk_len = static_cast(chunk.length()); diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h index 4613b0a33..7be07f2b6 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.h @@ -33,6 +33,11 @@ class AVIOFileLikeContext : public AVIOContextHolder { // and that's, at least, hard. For all of the common pitfalls, see: // // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + // + // We maintain a reference to the file-like object because the file-like + // object that was created on the Python side must live as long as our + // potential use. That is, even if there are no more references to the object + // on the Python side, we require that the object is still live. 
struct PyObjectDeleter { inline void operator()(py::object* obj) const { if (obj) { diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 5e0ac05a1..9d1bc2de7 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -10,8 +10,7 @@ find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) function(make_torchcodec_sublibrary library_name sources - dependent_libraries - ffmpeg_include_dirs) + library_dependencies) add_library(${library_name} SHARED ${sources}) set_target_properties(${library_name} PROPERTIES CXX_STANDARD 17) @@ -20,7 +19,6 @@ function(make_torchcodec_sublibrary ./../../../../ "${TORCH_INSTALL_PREFIX}/include" ${Python3_INCLUDE_DIRS} - ${ffmpeg_include_dirs} ) # Avoid adding the "lib" prefix which we already add explicitly. @@ -29,14 +27,15 @@ function(make_torchcodec_sublibrary target_link_libraries( ${library_name} PUBLIC - ${dependent_libraries} + ${library_dependencies} ) endfunction() function(make_torchcodec_libraries ffmpeg_major_version - ffmpeg_target - ffmpeg_include_dirs) + ffmpeg_target) + + # TODO: List each library and its purpose. 
# Create libtorchcodec_decoderN.so set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") @@ -52,14 +51,14 @@ function(make_torchcodec_libraries list(APPEND decoder_sources CPUOnlyDevice.cpp) endif() - set(decoder_dependent_libraries + set(decoder_library_dependencies ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES} ) if(ENABLE_CUDA) - list(APPEND decoder_dependent_libraries + list(APPEND decoder_library_dependencies ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) @@ -68,8 +67,7 @@ function(make_torchcodec_libraries make_torchcodec_sublibrary( "${decoder_library_name}" "${decoder_sources}" - "${decoder_dependent_libraries}" - "${ffmpeg_include_dirs}" + "${decoder_library_dependencies}" ) # Create libtorchcodec_custom_opsN.so @@ -82,7 +80,6 @@ function(make_torchcodec_libraries "${custom_ops_library_name}" "${custom_ops_sources}" "${decoder_library_name}" - "${ffmpeg_include_dirs}" ) # Create libtorchcodec_pybind_opsN.so @@ -95,7 +92,6 @@ function(make_torchcodec_libraries "${pybind_ops_library_name}" "${pybind_ops_sources}" "${decoder_library_name}" - "${ffmpeg_include_dirs}" ) # pybind11 quirk, see: # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes @@ -137,10 +133,10 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3}) ${CMAKE_CURRENT_SOURCE_DIR}/fetch_and_expose_non_gpl_ffmpeg_libs.cmake ) - make_torchcodec_libraries(7 ffmpeg7 $ffmpeg7_INCLUDE_DIRs) - make_torchcodec_libraries(6 ffmpeg6 $ffmpeg6_INCLUDE_DIRS) - make_torchcodec_libraries(4 ffmpeg4 $ffmpeg4_INCLUDE_DIRS) - make_torchcodec_libraries(5 ffmpeg5 $ffmpeg5_INCLUDE_DIRS) + make_torchcodec_libraries(7 ffmpeg7) + make_torchcodec_libraries(6 ffmpeg6) + make_torchcodec_libraries(4 ffmpeg4) + make_torchcodec_libraries(5 ffmpeg5) else() message( STATUS @@ -180,7 +176,7 @@ else() ) endif() - make_torchcodec_libraries(${ffmpeg_major_version} PkgConfig::LIBAV ${LIBAV_INCLUDE_DIRS}) + 
make_torchcodec_libraries(${ffmpeg_major_version} PkgConfig::LIBAV) # Expose these values updwards so that the test compilation does not need # to re-figure it out. FIXME: it's not great that we just copy-paste the diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index a8f3d4cba..638031403 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -31,9 +31,9 @@ int64_t create_from_file_like( realSeek = seekModeFromString(seek_mode.value()); } - auto contextHolder = std::make_unique(file_like); + auto avioContextHolder = std::make_unique(file_like); - VideoDecoder* decoder = new VideoDecoder(std::move(contextHolder), realSeek); + VideoDecoder* decoder = new VideoDecoder(std::move(avioContextHolder), realSeek); return reinterpret_cast(decoder); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 4ce31c23c..1b98107de 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -129,6 +129,7 @@ at::Tensor create_from_tensor( at::Tensor video_tensor, std::optional seek_mode) { TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); + TORCH_CHECK(video_tensor.scalar_type() == torch::kUInt8, "video_tensor must be kUInt8"); void* data = video_tensor.mutable_data_ptr(); size_t length = video_tensor.numel(); diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 87b334247..977916ef9 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -21,7 +21,7 @@ _pybind_ops: Optional[ModuleType] = None -def load_torchcodec_extension(): +def load_torchcodec_shared_libraries(): # Successively try to load libtorchcodec7.so, libtorchcodec6.so, # libtorchcodec5.so, and libtorchcodec4.so. Each of these correspond to an # ffmpeg major version. 
This should cover all potential ffmpeg versions @@ -79,7 +79,7 @@ def load_torchcodec_extension(): ) -load_torchcodec_extension() +load_torchcodec_shared_libraries() # Note: We use disallow_in_graph because PyTorch does constant propagation of diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 88cfa176b..64c41cd60 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -363,7 +363,7 @@ def test_create_decoder(self, create_from, device): decoder = create_from_file_like(open(path, mode="rb", buffering=0), "exact") elif create_from == "file_like_bufferedio": decoder = create_from_file_like( - open(path, mode="rb", buffering=-4096), "exact" + open(path, mode="rb", buffering=4096), "exact" ) else: raise ValueError("Oops, double check the parametrization of this test!") From ca28adcb03e8a11e307c14746165e660e5684808 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 19 Mar 2025 14:38:54 -0700 Subject: [PATCH 36/54] Lint --- src/torchcodec/decoders/_core/PyBindOps.cpp | 3 ++- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index 638031403..26eadc8fb 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -33,7 +33,8 @@ int64_t create_from_file_like( auto avioContextHolder = std::make_unique(file_like); - VideoDecoder* decoder = new VideoDecoder(std::move(avioContextHolder), realSeek); + VideoDecoder* decoder = + new VideoDecoder(std::move(avioContextHolder), realSeek); return reinterpret_cast(decoder); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 1b98107de..3739a71b3 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -129,7 +129,9 @@ at::Tensor create_from_tensor( 
at::Tensor video_tensor, std::optional seek_mode) { TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); - TORCH_CHECK(video_tensor.scalar_type() == torch::kUInt8, "video_tensor must be kUInt8"); + TORCH_CHECK( + video_tensor.scalar_type() == torch::kUInt8, + "video_tensor must be kUInt8"); void* data = video_tensor.mutable_data_ptr(); size_t length = video_tensor.numel(); From ceaa1a6b15a1090a745f4e3ac29b52483ffa1157 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Mar 2025 13:00:58 -0400 Subject: [PATCH 37/54] Test pass on Mac --- setup.py | 17 +++++---- src/torchcodec/decoders/_core/CMakeLists.txt | 40 +++++++++++++------- src/torchcodec/decoders/_core/ops.py | 1 - 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/setup.py b/setup.py index 91f99c22c..8b7ab6d1c 100644 --- a/setup.py +++ b/setup.py @@ -136,21 +136,22 @@ def copy_extensions_to_source(self): This is called by setuptools at the end of .run() during editable installs. """ self.get_finalized_command("build_py") - extension = "" + extensions = [] if sys.platform == "linux": - extension = "so" + extensions = ["so"] elif sys.platform == "darwin": - extension = "dylib" + extensions = ["dylib", "so"] else: raise NotImplementedError( "Platforms other than linux/darwin are not supported yet" ) - for so_file in self._install_prefix.glob(f"*.{extension}"): - assert "libtorchcodec" in so_file.name - destination = Path("src/torchcodec/") / so_file.name - print(f"Copying {so_file} to {destination}") - self.copy_file(so_file, destination, level=self.verbose) + for ext in extensions: + for lib_file in self._install_prefix.glob(f"*.{ext}"): + assert "libtorchcodec" in lib_file.name + destination = Path("src/torchcodec/") / lib_file.name + print(f"Copying {lib_file} to {destination}") + self.copy_file(lib_file, destination, level=self.verbose) NOT_A_LICENSE_VIOLATION_VAR = "I_CONFIRM_THIS_IS_NOT_A_LICENSE_VIOLATION" diff --git 
a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 9d1bc2de7..7b3273743 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -3,16 +3,19 @@ project(TorchCodec) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic -Werror ${TORCH_CXX_FLAGS}") + function(make_torchcodec_sublibrary library_name + type sources library_dependencies) - add_library(${library_name} SHARED ${sources}) + add_library(${library_name} ${type} ${sources}) set_target_properties(${library_name} PROPERTIES CXX_STANDARD 17) target_include_directories(${library_name} PRIVATE @@ -40,9 +43,9 @@ function(make_torchcodec_libraries # Create libtorchcodec_decoderN.so set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") set(decoder_sources - AVIOContextHolder.cpp - FFMPEGCommon.cpp - VideoDecoder.cpp + AVIOContextHolder.cpp + FFMPEGCommon.cpp + VideoDecoder.cpp ) if(ENABLE_CUDA) @@ -54,7 +57,6 @@ function(make_torchcodec_libraries set(decoder_library_dependencies ${ffmpeg_target} ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} ) if(ENABLE_CUDA) @@ -66,6 +68,7 @@ function(make_torchcodec_libraries make_torchcodec_sublibrary( "${decoder_library_name}" + SHARED "${decoder_sources}" "${decoder_library_dependencies}" ) @@ -73,11 +76,12 @@ function(make_torchcodec_libraries # Create libtorchcodec_custom_opsN.so set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") set(custom_ops_sources - AVIOBytesContext.cpp - VideoDecoderOps.cpp + AVIOBytesContext.cpp + VideoDecoderOps.cpp ) make_torchcodec_sublibrary( "${custom_ops_library_name}" + SHARED "${custom_ops_sources}" 
"${decoder_library_name}" ) @@ -85,21 +89,31 @@ function(make_torchcodec_libraries # Create libtorchcodec_pybind_opsN.so set(pybind_ops_library_name "libtorchcodec_pybind_ops${ffmpeg_major_version}") set(pybind_ops_sources - AVIOFileLikeContext.cpp - PyBindOps.cpp + AVIOFileLikeContext.cpp + PyBindOps.cpp + ) + set(pybind_ops_dependencies + ${decoder_library_name} + pybind11::module ) make_torchcodec_sublibrary( "${pybind_ops_library_name}" + MODULE "${pybind_ops_sources}" - "${decoder_library_name}" + "${pybind_ops_dependencies}" ) # pybind11 quirk, see: # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes target_compile_options( - ${pybind_ops_library_name} - PUBLIC + ${pybind_ops_library_name} + PUBLIC "-fvisibility=hidden" ) + target_link_options( + ${pybind_ops_library_name} + PUBLIC + "-undefined dynamic_lookup" + ) # Install all libraries. set( diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 977916ef9..fdaa6155c 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -41,7 +41,6 @@ def load_torchcodec_shared_libraries(): torch.ops.load_library(_get_extension_path(custom_ops_library_name)) pybind_ops_library_path = _get_extension_path(pybind_ops_library_name) - torch.ops.load_library(pybind_ops_library_path) spec = importlib.util.spec_from_file_location( pybind_ops_module_name, pybind_ops_library_path, From 9b06e79155f2f8c4bc4017e554d6d25430dc9668 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Mar 2025 11:35:03 -0700 Subject: [PATCH 38/54] Better comments --- setup.py | 5 ++++ src/torchcodec/decoders/_core/CMakeLists.txt | 31 +++++++++++++++----- src/torchcodec/decoders/_core/ops.py | 4 +-- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 8b7ab6d1c..f16521764 100644 --- a/setup.py +++ b/setup.py @@ -140,6 +140,11 @@ def 
copy_extensions_to_source(self): if sys.platform == "linux": extensions = ["so"] elif sys.platform == "darwin": + # Mac has BOTH .dylib and .so as library extensions. Short version + # is that a .dylib is a shared library that can be both dynamically + # loaded and depended on by other libraries; a .so can only be a + # dynamically loaded module. For more, see: + # https://stackoverflow.com/a/2339910 extensions = ["dylib", "so"] else: raise NotImplementedError( diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 7b3273743..904115a4b 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -38,9 +38,21 @@ function(make_torchcodec_libraries ffmpeg_major_version ffmpeg_target) - # TODO: List each library and its purpose. - - # Create libtorchcodec_decoderN.so + # We create three shared libraries per version of FFmpeg, where the version + # is denoted by N: + # + # 1. libtorchcodec_decoderN.so: Base library which contains the + # implementation of VideoDecoder and everything VideoDecoder needs. + # + # 2. libtorchcodec_custom_opsN.so: Implementation of the PyTorch custom + # ops. Depends on libtorchcodec_decoderN.so. + # + # 3. libtorchcodec_pybind_opsN.so: Implementation of the pybind11 ops. We + # keep these separate from the PyTorch custom ops because we have to + # load these libraries separately on the Python side. Depends on + # libtorchcodec_decoderN.so. + + # 1. Create libtorchcodec_decoderN.so. set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") set(decoder_sources AVIOContextHolder.cpp @@ -73,7 +85,7 @@ function(make_torchcodec_libraries "${decoder_library_dependencies}" ) - # Create libtorchcodec_custom_opsN.so + # 2. Create libtorchcodec_custom_opsN.so. 
set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") set(custom_ops_sources AVIOBytesContext.cpp @@ -86,7 +98,7 @@ function(make_torchcodec_libraries "${decoder_library_name}" ) - # Create libtorchcodec_pybind_opsN.so + # 3. Create libtorchcodec_pybind_opsN.so. set(pybind_ops_library_name "libtorchcodec_pybind_ops${ffmpeg_major_version}") set(pybind_ops_sources AVIOFileLikeContext.cpp @@ -98,7 +110,9 @@ function(make_torchcodec_libraries ) make_torchcodec_sublibrary( "${pybind_ops_library_name}" - MODULE + MODULE # Note that this not SHARED; otherwise we build the wrong kind + # of library on Mac. On Mac, SHARED becomes .dylib and MODULE becomes + # a .so. We want pybind11 libraries to become .so. "${pybind_ops_sources}" "${pybind_ops_dependencies}" ) @@ -109,6 +123,9 @@ function(make_torchcodec_libraries PUBLIC "-fvisibility=hidden" ) + # If we don't make sure this flag is set, we run into segfaults at import + # time on Mac. See: + # https://github.com/pybind/pybind11/issues/3907#issuecomment-1170412764 target_link_options( ${pybind_ops_library_name} PUBLIC @@ -124,7 +141,7 @@ function(make_torchcodec_libraries ) # The install step is invoked within CMakeBuild.build_library() in - # setup.py and just copies the built .so files from the temp + # setup.py and just copies the built files from the temp # cmake/setuptools build folder into the CMAKE_INSTALL_PREFIX folder. We # still need to manually pass "DESTINATION ..." for cmake to copy those # files in CMAKE_INSTALL_PREFIX instead of CMAKE_INSTALL_PREFIX/lib. diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index fdaa6155c..1272e72ac 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -22,8 +22,8 @@ def load_torchcodec_shared_libraries(): - # Successively try to load libtorchcodec7.so, libtorchcodec6.so, - # libtorchcodec5.so, and libtorchcodec4.so. 
Each of these correspond to an + # Successively try to load libtorchcodec_*7.so, libtorchcodec_*6.so, + # libtorchcodec_*5.so, and libtorchcodec_*4.so. Each of these correspond to an # ffmpeg major version. This should cover all potential ffmpeg versions # installed on the user's machine. # From 3896a706a04684b573290a3640fae4f8a185ca5e Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Mar 2025 11:40:03 -0700 Subject: [PATCH 39/54] More comments --- src/torchcodec/decoders/_core/CMakeLists.txt | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 904115a4b..1382dc941 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -41,16 +41,18 @@ function(make_torchcodec_libraries # We create three shared libraries per version of FFmpeg, where the version # is denoted by N: # - # 1. libtorchcodec_decoderN.so: Base library which contains the - # implementation of VideoDecoder and everything VideoDecoder needs. + # 1. libtorchcodec_decoderN.{ext}: Base library which contains the + # implementation of VideoDecoder and everything VideoDecoder needs. On + # Linux, {ext} is so. On Mac, it is dylib. # - # 2. libtorchcodec_custom_opsN.so: Implementation of the PyTorch custom - # ops. Depends on libtorchcodec_decoderN.so. + # 2. libtorchcodec_custom_opsN.{ext}: Implementation of the PyTorch custom + # ops. Depends on libtorchcodec_decoderN.{ext}. On Linux, {ext} is so. + # On Mac, it is dylib. # - # 3. libtorchcodec_pybind_opsN.so: Implementation of the pybind11 ops. We + # 3. libtorchcodec_pybind_opsN.{ext}: Implementation of the pybind11 ops. We # keep these separate from the PyTorch custom ops because we have to # load these libraries separately on the Python side. Depends on - # libtorchcodec_decoderN.so. + # libtorchcodec_decoderN.{ext}. On BOTH Linux and Mac {ext} is so. # 1. 
Create libtorchcodec_decoderN.so. set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") From e9b6c76da5246541d6e0fdc08be0cddd4b0b2500 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Mar 2025 11:41:24 -0700 Subject: [PATCH 40/54] More more comments --- src/torchcodec/decoders/_core/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 1382dc941..8b28fbadc 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -54,7 +54,7 @@ function(make_torchcodec_libraries # load these libraries separately on the Python side. Depends on # libtorchcodec_decoderN.{ext}. On BOTH Linux and Mac {ext} is so. - # 1. Create libtorchcodec_decoderN.so. + # 1. Create libtorchcodec_decoderN.{ext}. set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}") set(decoder_sources AVIOContextHolder.cpp @@ -87,7 +87,7 @@ function(make_torchcodec_libraries "${decoder_library_dependencies}" ) - # 2. Create libtorchcodec_custom_opsN.so. + # 2. Create libtorchcodec_custom_opsN.{ext}. 
set(custom_ops_library_name "libtorchcodec_custom_ops${ffmpeg_major_version}") set(custom_ops_sources AVIOBytesContext.cpp From d94b97c6220daab5f5f6b102fad61586fa613e07 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Mar 2025 13:50:30 -0700 Subject: [PATCH 41/54] Add pybind11 in some workflows --- .github/workflows/cpp_tests.yaml | 6 +++--- .github/workflows/lint.yaml | 2 +- CONTRIBUTING.md | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/cpp_tests.yaml b/.github/workflows/cpp_tests.yaml index 6ae56c70f..5e31babc7 100644 --- a/.github/workflows/cpp_tests.yaml +++ b/.github/workflows/cpp_tests.yaml @@ -34,12 +34,12 @@ jobs: python-version: '3.12' - name: Update pip run: python -m pip install --upgrade pip - - name: Install dependencies + - name: Install torch dependencies run: | python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - - name: Install ffmpeg and pkg-config + - name: Install ffmpeg, pkg-config and pybind11 run: | - conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" pkg-config -c conda-forge + conda install "ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" pkg-config pybind11 -c conda-forge ffmpeg -version - name: Build and run C++ tests run: | diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 71cc071c8..c156a833c 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -63,7 +63,7 @@ jobs: - name: Install dependencies and FFmpeg run: | python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - conda install "ffmpeg=7.0.1" pkg-config -c conda-forge + conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge ffmpeg -version - name: Build and install torchcodec run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bc3ec3bb0..d516bc272 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,7 @@ test locally you will need the following dependencies: 
 installation already. - cmake - pkg-config +- pybind11 - FFmpeg - PyTorch nightly @@ -29,7 +30,7 @@ Start by installing the **nightly** build of PyTorch following the Then, the easiest way to install the rest of the dependencies is to run: ```bash -conda install cmake pkg-config ffmpeg -c conda-forge +conda install cmake pkg-config pybind11 ffmpeg -c conda-forge ``` ### Clone and build From bd598c358898c4731553d5c7d9de101e82ff0584 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Mar 2025 20:51:49 -0700 Subject: [PATCH 42/54] Make sure custom_ops has Python dependencies --- src/torchcodec/decoders/_core/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 8b28fbadc..ef64d7260 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -3,6 +3,7 @@ project(TorchCodec) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(PYBIND11_FINDPYTHON ON) find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) @@ -93,11 +94,15 @@ function(make_torchcodec_libraries AVIOBytesContext.cpp VideoDecoderOps.cpp ) + set(custom_ops_dependencies + ${decoder_library_name} + ${Python3_LIBRARIES} + ) make_torchcodec_sublibrary( "${custom_ops_library_name}" SHARED "${custom_ops_sources}" - "${decoder_library_name}" + "${custom_ops_dependencies}" ) # 3. Create libtorchcodec_pybind_opsN.so. 
From bd3ecab97f498adbe5fac37a32c29c0cfa821ea8 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 07:35:53 -0700 Subject: [PATCH 43/54] Add pre-build script to wheel building --- .github/workflows/linux_wheel.yaml | 1 + packaging/pre_build_script.sh | 5 +++++ src/torchcodec/decoders/_core/ops.py | 1 + 3 files changed, 7 insertions(+) create mode 100644 packaging/pre_build_script.sh diff --git a/.github/workflows/linux_wheel.yaml b/.github/workflows/linux_wheel.yaml index f5e665f55..ae8acd2aa 100644 --- a/.github/workflows/linux_wheel.yaml +++ b/.github/workflows/linux_wheel.yaml @@ -49,6 +49,7 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: packagin/pre_build_script.sh post-script: packaging/post_build_script.sh smoke-test-script: packaging/fake_smoke_test.py package-name: torchcodec diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh new file mode 100644 index 000000000..6f03138b7 --- /dev/null +++ b/packaging/pre_build_script.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +set -ex + +conda install -y pybind11 -c conda-forge diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 1272e72ac..b2571842a 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -20,6 +20,7 @@ _pybind_ops: Optional[ModuleType] = None +# TODO: More on why we need two different ways of loading external modules. 
def load_torchcodec_shared_libraries(): # Successively try to load libtorchcodec_*7.so, libtorchcodec_*6.so, From e7f49c4765b9bede7715374b6f9e71849d045b07 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 07:39:46 -0700 Subject: [PATCH 44/54] Forgot a g --- .github/workflows/linux_wheel.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux_wheel.yaml b/.github/workflows/linux_wheel.yaml index ae8acd2aa..cd53219f7 100644 --- a/.github/workflows/linux_wheel.yaml +++ b/.github/workflows/linux_wheel.yaml @@ -49,7 +49,7 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} - pre-script: packagin/pre_build_script.sh + pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh smoke-test-script: packaging/fake_smoke_test.py package-name: torchcodec From 0f8556a79860e8b69ba0c0ac6d5574a167e8c62d Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 07:53:21 -0700 Subject: [PATCH 45/54] Add pre-build script to rest of workflows --- .github/workflows/docs.yaml | 1 + .github/workflows/linux_cuda_wheel.yaml | 1 + .github/workflows/macos_wheel.yaml | 1 + 3 files changed, 3 insertions(+) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 89e8401de..60bfbfa2e 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -38,6 +38,7 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh smoke-test-script: packaging/fake_smoke_test.py package-name: torchcodec diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml index 65b06e933..53b5bfc20 100644 --- a/.github/workflows/linux_cuda_wheel.yaml +++ b/.github/workflows/linux_cuda_wheel.yaml @@ -48,6 +48,7 @@ jobs: 
test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh smoke-test-script: packaging/fake_smoke_test.py package-name: torchcodec diff --git a/.github/workflows/macos_wheel.yaml b/.github/workflows/macos_wheel.yaml index b7cc965eb..d9472765c 100644 --- a/.github/workflows/macos_wheel.yaml +++ b/.github/workflows/macos_wheel.yaml @@ -49,6 +49,7 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: packaging/pre_build_script.sh post-script: packaging/post_build_script.sh smoke-test-script: packaging/fake_smoke_test.py runner-type: macos-m1-stable From 66db2724c2b868939f209d02b12780f0c4d85ec0 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 07:53:56 -0700 Subject: [PATCH 46/54] Lint --- src/torchcodec/decoders/_core/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index b2571842a..54a59078f 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -22,6 +22,7 @@ # TODO: More on why we need two different ways of loading external modules. + def load_torchcodec_shared_libraries(): # Successively try to load libtorchcodec_*7.so, libtorchcodec_*6.so, # libtorchcodec_*5.so, and libtorchcodec_*4.so. 
Each of these correspond to an From 4ca294b84ae4c47a7443f33b61fdf1bca68788e6 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 10:55:17 -0700 Subject: [PATCH 47/54] Better comments --- packaging/pre_build_script.sh | 4 ++++ src/torchcodec/decoders/_core/CMakeLists.txt | 14 +++++++++++--- src/torchcodec/decoders/_core/PyBindOps.cpp | 2 +- src/torchcodec/decoders/_core/ops.py | 14 +++++++++++++- 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 6f03138b7..f22244e9c 100644 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -2,4 +2,8 @@ set -ex +# We need to install pybind11 because we need its CMake helpers in order to +# compile correctly on Mac. Pybind11 is actually a C++ header-only library, +# and PyTorch actually has it included. PyTorch, however, does not have the +# CMake helpers. conda install -y pybind11 -c conda-forge diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index ef64d7260..c6dc15385 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -113,17 +113,25 @@ function(make_torchcodec_libraries ) set(pybind_ops_dependencies ${decoder_library_name} - pybind11::module + pybind11::module # This library dependency makes sure we have the right + # Python libraries included as well as all of the right + # settings so that we can successfully load the shared + # library as a Python module. ) make_torchcodec_sublibrary( "${pybind_ops_library_name}" MODULE # Note that this not SHARED; otherwise we build the wrong kind # of library on Mac. On Mac, SHARED becomes .dylib and MODULE becomes - # a .so. We want pybind11 libraries to become .so. + # a .so. We want pybind11 libraries to become .so. 
If this is + # changed to SHARED, we will be able to successfully compile a + # .dylib, but we will not be able to successfully import that as + # a Python module on Mac. "${pybind_ops_sources}" "${pybind_ops_dependencies}" ) - # pybind11 quirk, see: + # pybind11 limits the visibility of symbols in the shared library to prevent + # stray initialization of py::objects. The rest of the object code must + # match. See: # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes target_compile_options( ${pybind_ops_library_name} diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index 26eadc8fb..b998e084f 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -38,7 +38,7 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } -PYBIND11_MODULE(_torchcodec_pybind_ops, m) { +PYBIND11_MODULE(torchcodec_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); } diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 54a59078f..b710ca135 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -31,9 +31,21 @@ def load_torchcodec_shared_libraries(): # # On fbcode, _get_extension_path() is overridden and directly points to the # correct .so file, so this for-loop succeeds on the first iteration. + # + # Note that we use two different methods for loading shared libraries: + # + # 1. torch.ops.load_library(): For PyTorch custom ops. Loading libraries + # through PyTorch registers the custom ops with PyTorch's runtime and + # the ops can be accessed through torch.ops after loading. + # + # 2. importlib: For pybind11 modules. We load them dynamically, rather + # than using a plain import statement. 
A plain import statement only + # works when the module name and file name match exactly, and the + # shared library file is in the import path. Our shared libraries do + # not meet those conditions. exceptions = [] - pybind_ops_module_name = "_torchcodec_pybind_ops" + pybind_ops_module_name = "torchcodec_pybind_ops" for ffmpeg_major_version in (7, 6, 5, 4): decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" From d28088814bdbfe150cea56115c2dd9bd5a11e81b Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 12:02:05 -0700 Subject: [PATCH 48/54] Use string_view instead of string for bytes --- .../decoders/_core/AVIOFileLikeContext.cpp | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp index 9ed5dedef..c5a5eab81 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp @@ -32,27 +32,38 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { // Note that we acquire the GIL outside of the loop. This is likely more // efficient than releasing and acquiring it each loop iteration. py::gil_scoped_acquire gil; - int num_read = 0; - while (num_read < buf_size) { - int request = buf_size - num_read; - auto chunk = static_cast( - static_cast((*fileLike)->attr("read")(request))); - int chunk_len = static_cast(chunk.length()); - if (chunk_len == 0) { + + int totalNumRead = 0; + while (totalNumRead < buf_size) { + int request = buf_size - totalNumRead; + + // The Python method returns the actual bytes, which we access through the + // py::bytes wrapper. That wrapper, however, does not provide us access to + // the underlying data pointer, which we need for the memcpy below. 
So we + // convert the bytes to a string_view to get access to the data pointer. + // Becauase it's a view and not a copy, it should be cheap. + auto bytesRead = static_cast((*fileLike)->attr("read")(request)); + auto bytesView = static_cast(bytes); + + int numBytesRead = static_cast(bytesView.size()); + if (numBytesRead == 0) { break; } + TORCH_CHECK( - chunk_len <= request, + numBytesRead <= request, "Requested up to ", request, " bytes but, received ", - chunk_len, + numBytesRead, " bytes. The given object does not conform to read protocol of file object."); - memcpy(buf, chunk.data(), chunk_len); - buf += chunk_len; - num_read += chunk_len; + + std::memcpy(buf, bytesView.data(), numBytesRead); + buf += numBytesRead; + totalNumRead += numBytesRead; } - return num_read == 0 ? AVERROR_EOF : num_read; + + return totalNumRead == 0 ? AVERROR_EOF : totalNumRead; } int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { @@ -60,6 +71,7 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { if (whence == AVSEEK_SIZE) { return AVERROR(EIO); } + auto fileLike = static_cast(opaque); py::gil_scoped_acquire gil; return py::cast((*fileLike)->attr("seek")(offset, whence)); From 52d5a6f4bf9c26d2c37775f001d32e9a3d430a9e Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 12:08:42 -0700 Subject: [PATCH 49/54] Remove todo --- src/torchcodec/decoders/_core/ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index b710ca135..6ce509133 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -20,8 +20,6 @@ _pybind_ops: Optional[ModuleType] = None -# TODO: More on why we need two different ways of loading external modules. 
- def load_torchcodec_shared_libraries(): # Successively try to load libtorchcodec_*7.so, libtorchcodec_*6.so, From 9f2469e4288866a7c54fbbe9095b09e9722f2fc2 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 12:54:13 -0700 Subject: [PATCH 50/54] Avoid negative buffer sizes --- .../decoders/_core/AVIOBytesContext.cpp | 20 ++++++++++--------- .../decoders/_core/AVIOFileLikeContext.cpp | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/decoders/_core/AVIOBytesContext.cpp b/src/torchcodec/decoders/_core/AVIOBytesContext.cpp index c1851b9ca..0d1e9d413 100644 --- a/src/torchcodec/decoders/_core/AVIOBytesContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOBytesContext.cpp @@ -26,23 +26,25 @@ int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) { ", size=", dataContext->size); - buf_size = FFMIN( - buf_size, static_cast(dataContext->size - dataContext->current)); + int64_t numBytesRead = std::min( + static_cast(buf_size), dataContext->size - dataContext->current); + TORCH_CHECK( - buf_size >= 0, - "Tried to read negative bytes: buf_size=", - buf_size, + numBytesRead >= 0, + "Tried to read negative bytes: numBytesRead=", + numBytesRead, ", size=", dataContext->size, ", current=", dataContext->current); - if (!buf_size) { + if (numBytesRead == 0) { return AVERROR_EOF; } - memcpy(buf, dataContext->data + dataContext->current, buf_size); - dataContext->current += buf_size; - return buf_size; + + std::memcpy(buf, dataContext->data + dataContext->current, numBytesRead); + dataContext->current += numBytesRead; + return numBytesRead; } // The signature of this function is defined by FFMPEG. 
diff --git a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp index c5a5eab81..60d1503ae 100644 --- a/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/decoders/_core/AVIOFileLikeContext.cpp @@ -43,7 +43,7 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { // convert the bytes to a string_view to get access to the data pointer. // Becauase it's a view and not a copy, it should be cheap. auto bytesRead = static_cast((*fileLike)->attr("read")(request)); - auto bytesView = static_cast(bytes); + auto bytesView = static_cast(bytesRead); int numBytesRead = static_cast(bytesView.size()); if (numBytesRead == 0) { From 72f4ffad826693c35150492d8b73b9a75ff52621 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Mar 2025 13:13:38 -0700 Subject: [PATCH 51/54] Better comment --- src/torchcodec/decoders/_core/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index c6dc15385..f0a8568fe 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -116,7 +116,9 @@ function(make_torchcodec_libraries pybind11::module # This library dependency makes sure we have the right # Python libraries included as well as all of the right # settings so that we can successfully load the shared - # library as a Python module. + # library as a Python module on Mac. If we instead use + # ${Python3_LIBRARIES}, it works on Linux but not on + # Mac. 
) make_torchcodec_sublibrary( "${pybind_ops_library_name}" From 9e84c98d64f3ffda6d17f35e83e5176e1a71b5e9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 24 Mar 2025 06:49:47 -0700 Subject: [PATCH 52/54] Update comments --- src/torchcodec/decoders/_core/ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 6ce509133..289cad78c 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -32,15 +32,15 @@ def load_torchcodec_shared_libraries(): # # Note that we use two different methods for loading shared libraries: # - # 1. torch.ops.load_library(): For PyTorch custom ops. Loading libraries - # through PyTorch registers the custom ops with PyTorch's runtime and - # the ops can be accessed through torch.ops after loading. + # 1. torch.ops.load_library(): For PyTorch custom ops and the C++ only + # libraries the custom ops depend on. Loading libraries through PyTorch + # registers the custom ops with PyTorch's runtime and the ops can be + # accessed through torch.ops after loading. # # 2. importlib: For pybind11 modules. We load them dynamically, rather # than using a plain import statement. A plain import statement only - # works when the module name and file name match exactly, and the - # shared library file is in the import path. Our shared libraries do - # not meet those conditions. + # works when the module name and file name match exactly. Our shared + # libraries do not meet those conditions. 
exceptions = [] pybind_ops_module_name = "torchcodec_pybind_ops" From b034fff18e8ed9f65813d2dd7bc5f750f935b897 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 26 Mar 2025 11:14:44 -0700 Subject: [PATCH 53/54] More generic way to import pybind11 --- src/torchcodec/_internally_replaced_utils.py | 16 +++++++++++++++- src/torchcodec/decoders/_core/PyBindOps.cpp | 2 +- src/torchcodec/decoders/_core/ops.py | 17 +++++------------ 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/_internally_replaced_utils.py b/src/torchcodec/_internally_replaced_utils.py index 0833eb82f..4ee78f3e5 100644 --- a/src/torchcodec/_internally_replaced_utils.py +++ b/src/torchcodec/_internally_replaced_utils.py @@ -7,11 +7,12 @@ import importlib import sys from pathlib import Path +from types import ModuleType # Copy pasted from torchvision # https://github.com/pytorch/vision/blob/947ae1dc71867f28021d5bc0ff3a19c249236e2a/torchvision/_internally_replaced_utils.py#L25 -def _get_extension_path(lib_name): +def _get_extension_path(lib_name: str) -> str: extension_suffixes = [] if sys.platform == "linux": extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES @@ -34,3 +35,16 @@ def _get_extension_path(lib_name): raise ImportError return ext_specs.origin + + +def _load_pybind11_module(module_name: str, library_path: str) -> ModuleType: + spec = importlib.util.spec_from_file_location( + module_name, + library_path, + ) + if spec is None: + raise ImportError( + f"Unable to load spec for module {module_name} from path {library_path}" + ) + + return importlib.util.module_from_spec(spec) diff --git a/src/torchcodec/decoders/_core/PyBindOps.cpp b/src/torchcodec/decoders/_core/PyBindOps.cpp index b998e084f..0b0f6f177 100644 --- a/src/torchcodec/decoders/_core/PyBindOps.cpp +++ b/src/torchcodec/decoders/_core/PyBindOps.cpp @@ -38,7 +38,7 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } -PYBIND11_MODULE(torchcodec_pybind_ops, m) { 
+PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); } diff --git a/src/torchcodec/decoders/_core/ops.py b/src/torchcodec/decoders/_core/ops.py index 289cad78c..0f0bdfe25 100644 --- a/src/torchcodec/decoders/_core/ops.py +++ b/src/torchcodec/decoders/_core/ops.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import importlib import io import json import warnings @@ -16,6 +15,7 @@ from torchcodec._internally_replaced_utils import ( # @manual=//pytorch/torchcodec/src:internally_replaced_utils _get_extension_path, + _load_pybind11_module, ) _pybind_ops: Optional[ModuleType] = None @@ -43,7 +43,7 @@ def load_torchcodec_shared_libraries(): # libraries do not meet those conditions. exceptions = [] - pybind_ops_module_name = "torchcodec_pybind_ops" + pybind_ops_module_name = "decoder_core_pybind_ops" for ffmpeg_major_version in (7, 6, 5, 4): decoder_library_name = f"libtorchcodec_decoder{ffmpeg_major_version}" custom_ops_library_name = f"libtorchcodec_custom_ops{ffmpeg_major_version}" @@ -53,17 +53,10 @@ def load_torchcodec_shared_libraries(): torch.ops.load_library(_get_extension_path(custom_ops_library_name)) pybind_ops_library_path = _get_extension_path(pybind_ops_library_name) - spec = importlib.util.spec_from_file_location( - pybind_ops_module_name, - pybind_ops_library_path, - ) - if spec is None: - raise ImportError( - f"Unable to load spec for module {pybind_ops_module_name} from path {pybind_ops_library_path}" - ) - global _pybind_ops - _pybind_ops = importlib.util.module_from_spec(spec) + _pybind_ops = _load_pybind11_module( + pybind_ops_module_name, pybind_ops_library_path + ) return except Exception as e: # TODO: recording and reporting exceptions this way is OK for now as it's just for debugging, From 0ac90ba0401391c594d5ab4e8dbc0121c998e537 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 26 Mar 2025 
11:28:43 -0700 Subject: [PATCH 54/54] Assert origin is there --- src/torchcodec/_internally_replaced_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_internally_replaced_utils.py b/src/torchcodec/_internally_replaced_utils.py index 4ee78f3e5..a5a3ffa39 100644 --- a/src/torchcodec/_internally_replaced_utils.py +++ b/src/torchcodec/_internally_replaced_utils.py @@ -32,7 +32,10 @@ def _get_extension_path(lib_name: str) -> str: ) ext_specs = extfinder.find_spec(lib_name) if ext_specs is None: - raise ImportError + raise ImportError(f"No spec found for {lib_name}") + + if ext_specs.origin is None: + raise ImportError(f"Existing spec found for {lib_name} does not have an origin") return ext_specs.origin