diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py new file mode 100644 index 000000000..8655f889d --- /dev/null +++ b/benchmarks/decoders/gpu_benchmark.py @@ -0,0 +1,206 @@ +import argparse +import os +import pathlib +import time +from concurrent.futures import ThreadPoolExecutor + +import torch + +import torch.utils.benchmark as benchmark + +import torchcodec +import torchvision.transforms.v2.functional as F + +RESIZED_WIDTH = 256 +RESIZED_HEIGHT = 256 + + +def transfer_and_resize_frame(frame, resize_device_string): + # This should be a no-op if the frame is already on the target device. + frame = frame.to(resize_device_string) + frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH)) + return frame + + +def decode_full_video(video_path, decode_device_string, resize_device_string): + # We use the core API instead of SimpleVideoDecoder because the core API + # allows us to natively resize as part of the decode step. + print(f"{decode_device_string=} {resize_device_string=}") + decoder = torchcodec.decoders._core.create_from_file(video_path) + num_threads = None + if "cuda" in decode_device_string: + num_threads = 1 + width = None + height = None + if "native" in resize_device_string: + width = RESIZED_WIDTH + height = RESIZED_HEIGHT + torchcodec.decoders._core._add_video_stream( + decoder, + stream_index=-1, + device=decode_device_string, + num_threads=num_threads, + width=width, + height=height, + ) + + start_time = time.time() + frame_count = 0 + while True: + try: + frame, *_ = torchcodec.decoders._core.get_next_frame(decoder) + if resize_device_string != "none" and "native" not in resize_device_string: + frame = transfer_and_resize_frame(frame, resize_device_string) + + frame_count += 1 + except Exception as e: + print("EXCEPTION", e) + break + + end_time = time.time() + elapsed = end_time - start_time + fps = frame_count / (end_time - start_time) + print( + f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}" + ) + return frame_count, end_time - start_time + + +def decode_videos_using_threads( + video_path, + decode_device_string, + resize_device_string, + num_videos, + num_threads, + use_multiple_gpus, +): + executor = ThreadPoolExecutor(max_workers=num_threads) + for i in range(num_videos): + actual_decode_device = decode_device_string + if "cuda" in decode_device_string and use_multiple_gpus: + actual_decode_device = f"cuda:{i % torch.cuda.device_count()}" + executor.submit( + decode_full_video, video_path, actual_decode_device, resize_device_string + ) + executor.shutdown(wait=True) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--devices", + default="cuda:0,cpu", + type=str, + help="Comma-separated devices to test decoding on.", + ) + parser.add_argument( + "--resize_devices", + default="cuda:0,cpu,native,none", + type=str, + help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.", + ) + parser.add_argument( + "--video", + type=str, + default=str( + pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4" + ), + ) + parser.add_argument( + "--use_torch_benchmark", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Use pytorch benchmark to measure decode time with warmup and " + "autorange. Without this we just run one iteration without warmup " + "to measure the cold start time." + ), + ) + parser.add_argument( + "--num_threads", + type=int, + default=1, + help="Number of threads to use for decoding. Only used when --use_torch_benchmark is set.", + ) + parser.add_argument( + "--num_videos", + type=int, + default=50, + help="Number of videos to decode in parallel. Only used when --num_threads is set.", + ) + parser.add_argument( + "--use_multiple_gpus", + action=argparse.BooleanOptionalAction, + default=True, + help=("Use multiple GPUs to decode multiple videos in multi-threaded mode."), + ) + args = parser.parse_args() + video_path = args.video + + if not args.use_torch_benchmark: + for device in args.devices.split(","): + print("Testing on", device) + decode_full_video(video_path, device) + return + + resize_devices = args.resize_devices.split(",") + resize_devices = [d for d in resize_devices if d != ""] + if len(resize_devices) == 0: + resize_devices.append("none") + + label = "Decode+Resize Time" + + results = [] + for decode_device_string in args.devices.split(","): + for resize_device_string in resize_devices: + decode_label = decode_device_string + if "cuda" in decode_label: + # Shorten "cuda:0" to "cuda" + decode_label = "cuda" + resize_label = resize_device_string + if "cuda" in resize_device_string: + # Shorten "cuda:0" to "cuda" + resize_label = "cuda" + print("decode_device", decode_device_string) + print("resize_device", resize_device_string) + if args.num_threads > 1: + t = benchmark.Timer( + stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads, use_multiple_gpus)", + globals={ + "decode_device_string": decode_device_string, + "video_path": video_path, + "decode_full_video": decode_full_video, + "decode_videos_using_threads": decode_videos_using_threads, + "resize_device_string": resize_device_string, + "num_videos": args.num_videos, + "num_threads": args.num_threads, + "use_multiple_gpus": args.use_multiple_gpus, + }, + label=label, + description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}", + sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}", + ).blocked_autorange() + results.append(t) + else: + t = benchmark.Timer( + stmt="decode_full_video(video_path, decode_device_string, resize_device_string)", + globals={ + "decode_device_string": decode_device_string, + "video_path": video_path, + "decode_full_video": decode_full_video, + "resize_device_string": resize_device_string, + }, + label=label, + description=f"video={os.path.basename(video_path)}", + sub_label=f"D={decode_label} R={resize_label}", + ).blocked_autorange() + results.append(t) + compare = benchmark.Compare(results) + compare.print() + print("Key: D=Decode, R=Resize T=threads W=work (number of videos to decode)") + print("Native resize is done as part of the decode step") + print("none resize means there is no resize step -- native or otherwise") + + +if __name__ == "__main__": + main() diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index eff5b1f66..2527c2177 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -34,12 +34,16 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_INCLUDE_DIRS} ) + set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} + ${Python3_LIBRARIES}) + if(ENABLE_CUDA) + list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY} + ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} ) + endif() target_link_libraries( ${library_name} PUBLIC - ${ffmpeg_target} - ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} + ${NEEDED_LIBRARIES} ) # We already set the library_name to be libtorchcodecN, so we don't want diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 8676fefda..404d87502 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -1,4 +1,5 @@ #include +#include "src/torchcodec/decoders/_core/DeviceInterface.h" namespace facebook::torchcodec { @@ -6,14 +7,25 @@ namespace facebook::torchcodec { // So all functions will throw an error because they should only be called if // the device is not CPU. -void throwUnsupportedDeviceError(const torch::Device& device) { +[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) { TORCH_CHECK( device.type() != torch::kCPU, "Device functions should only be called if the device is not CPU.") - throw std::runtime_error("Unsupported device: " + device.str()); + TORCH_CHECK(false, "Unsupported device: " + device.str()); } -void initializeDeviceContext(const torch::Device& device) { +void convertAVFrameToDecodedOutputOnCuda( + const torch::Device& device, + const VideoDecoder::VideoStreamDecoderOptions& options, + AVCodecContext* codecContext, + VideoDecoder::RawDecodedOutput& rawOutput, + VideoDecoder::DecodedOutput& output) { + throwUnsupportedDeviceError(device); +} + +void initializeContextOnCuda( + const torch::Device& device, + AVCodecContext* codecContext) { throwUnsupportedDeviceError(device); } diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 9aa7d8ab8..58234922b 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -1,6 +1,52 @@ +#include +#include +#include #include +#include "src/torchcodec/decoders/_core/DeviceInterface.h" +#include "src/torchcodec/decoders/_core/FFMPEGCommon.h" +#include "src/torchcodec/decoders/_core/VideoDecoder.h" + +extern "C" { +#include +#include +#include +} namespace facebook::torchcodec { +namespace { +AVBufferRef* getCudaContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + torch::DeviceIndex deviceIndex = device.index(); + // FFMPEG cannot handle negative device indices. + // For single GPU- machines libtorch returns -1 for the device index. So for + // that case we set the device index to 0. + // TODO: Double check if this works for multi-GPU machines correctly. + deviceIndex = std::max(deviceIndex, 0); + std::string deviceOrdinal = std::to_string(deviceIndex); + AVBufferRef* hw_device_ctx; + int err = av_hwdevice_ctx_create( + &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device", + getFFMPEGErrorStringFromErrorCode(err)); + } + return hw_device_ctx; +} + +torch::Tensor allocateDeviceTensor( + at::IntArrayRef shape, + torch::Device device, + const torch::Dtype dtype = torch::kUInt8) { + return torch::empty( + shape, + torch::TensorOptions() + .dtype(dtype) + .layout(torch::kStrided) + .device(device)); +} void throwErrorIfNonCudaDevice(const torch::Device& device) { TORCH_CHECK( @@ -10,13 +56,70 @@ void throwErrorIfNonCudaDevice(const torch::Device& device) { throw std::runtime_error("Unsupported device: " + device.str()); } } +} // namespace -void initializeDeviceContext(const torch::Device& device) { +void initializeContextOnCuda( + const torch::Device& device, + AVCodecContext* codecContext) { throwErrorIfNonCudaDevice(device); - // TODO: https://github.com/pytorch/torchcodec/issues/238: Implement CUDA - // device. - throw std::runtime_error( - "CUDA device is unimplemented. Follow this issue for tracking progress: https://github.com/pytorch/torchcodec/issues/238"); + // It is important for pytorch itself to create the cuda context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the cuda context. + torch::Tensor dummyTensorForCudaInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); + codecContext->hw_device_ctx = getCudaContext(device); + return; +} + +void convertAVFrameToDecodedOutputOnCuda( + const torch::Device& device, + const VideoDecoder::VideoStreamDecoderOptions& options, + AVCodecContext* codecContext, + VideoDecoder::RawDecodedOutput& rawOutput, + VideoDecoder::DecodedOutput& output) { + AVFrame* src = rawOutput.frame.get(); + + TORCH_CHECK( + src->format == AV_PIX_FMT_CUDA, + "Expected format to be AV_PIX_FMT_CUDA, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); + int width = options.width.value_or(codecContext->width); + int height = options.height.value_or(codecContext->height); + NppiSize oSizeROI = {width, height}; + Npp8u* input[2] = {src->data[0], src->data[1]}; + torch::Tensor& dst = output.frame; + dst = allocateDeviceTensor({height, width, 3}, options.device); + + // Use the user-requested GPU for running the NPP kernel. + c10::cuda::CUDAGuard deviceGuard(device); + + auto start = std::chrono::high_resolution_clock::now(); + + NppStatus status = nppiNV12ToRGB_8u_P2C3R( + input, + src->linesize[0], + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI); + TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + // Make the pytorch stream wait for the npp kernel to finish before using the + // output. + at::cuda::CUDAEvent nppDoneEvent; + at::cuda::CUDAStream nppStreamWrapper = + c10::cuda::getStreamFromExternal(nppGetStream(), device.index()); + nppDoneEvent.record(nppStreamWrapper); + nppDoneEvent.block(at::cuda::getCurrentCUDAStream()); + + auto end = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end - start; + VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; + if (options.dimensionOrder == "NCHW") { + // The docs guaranty this to return a view: + // https://pytorch.org/docs/stable/generated/torch.permute.html + dst = dst.permute({2, 0, 1}); + } } } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index dfe3247ae..3ef428fde 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -10,6 +10,11 @@ #include #include #include +#include "src/torchcodec/decoders/_core/VideoDecoder.h" + +extern "C" { +#include +} namespace facebook::torchcodec { @@ -23,6 +28,15 @@ namespace facebook::torchcodec { // Initialize the hardware device that is specified in `device`. Some builds // support CUDA and others only support CPU. -void initializeDeviceContext(const torch::Device& device); +void initializeContextOnCuda( + const torch::Device& device, + AVCodecContext* codecContext); + +void convertAVFrameToDecodedOutputOnCuda( + const torch::Device& device, + const VideoDecoder::VideoStreamDecoderOptions& options, + AVCodecContext* codecContext, + VideoDecoder::RawDecodedOutput& rawOutput, + VideoDecoder::DecodedOutput& output); } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 446b12a4b..c16009cd2 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -428,8 +428,12 @@ void VideoDecoder::addVideoStreamDecoder( streamInfo.codecContext.reset(codecContext); int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); - if (options.device.type() != torch::kCPU) { - initializeDeviceContext(options.device); + if (options.device.type() == torch::kCPU) { + // No more initialization needed for CPU. + } else if (options.device.type() == torch::kCUDA) { + initializeContextOnCuda(options.device, codecContext); + } else { + throw std::invalid_argument("Invalid device type: " + options.device.str()); } TORCH_CHECK_EQ(retVal, AVSUCCESS); retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); @@ -856,6 +860,28 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.duration = getDuration(frame); output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); + if (streamInfo.options.device.type() == torch::kCPU) { + convertAVFrameToDecodedOutputOnCPU(rawOutput, output); + } else if (streamInfo.options.device.type() == torch::kCUDA) { + convertAVFrameToDecodedOutputOnCuda( + streamInfo.options.device, + streamInfo.options, + streamInfo.codecContext.get(), + rawOutput, + output); + } else { + throw std::invalid_argument( + "Invalid device type: " + streamInfo.options.device.str()); + } + return output; +} + +void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( + VideoDecoder::RawDecodedOutput& rawOutput, + DecodedOutput& output) { + int streamIndex = rawOutput.streamIndex; + AVFrame* frame = rawOutput.frame.get(); + auto& streamInfo = streams_[streamIndex]; if (output.streamType == AVMEDIA_TYPE_VIDEO) { if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { int width = streamInfo.options.width.value_or(frame->width); @@ -884,7 +910,6 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( // audio decoding. throw std::runtime_error("Audio is not supported yet."); } - return output; } VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a3a1888b5..41509adc0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -364,6 +364,9 @@ class VideoDecoder { const AVFrame* frame); void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput); DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput); + void convertAVFrameToDecodedOutputOnCPU( + RawDecodedOutput& rawOutput, + DecodedOutput& output); DecoderOptions options_; ContainerMetadata containerMetadata_; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 8b3a373d1..03800a043 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -168,8 +168,9 @@ void _add_video_stream( if (device.has_value()) { if (device.value() == "cpu") { options.device = torch::Device(torch::kCPU); - } else if (device.value() == "cuda") { - options.device = torch::Device(torch::kCUDA); + } else if (device.value().starts_with("cuda")) { + std::string deviceStr(device.value()); + options.device = torch::Device(deviceStr); } else { throw std::runtime_error( "Invalid device=" + std::string(device.value()) + diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 2fbe8d9ad..1bb28feb5 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -464,8 +464,19 @@ def test_color_conversion_library_with_generated_videos( def test_cuda_decoder(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) - with pytest.raises(RuntimeError, match="CUDA device is unimplemented"): - add_video_stream(decoder, device="cuda") + add_video_stream(decoder, device="cuda") + frame0, pts, duration = get_next_frame(decoder) + assert frame0.device.type == "cuda" + frame0_cpu = frame0.to("cpu") + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + # GPU decode is not bit-accurate. In the following assertion we ensure + # not more than 0.3% of values have a difference greater than 20. + diff = (reference_frame0.float() - frame0_cpu.float()).abs() + assert (diff > 20).float().mean() <= 0.003 + assert pts == torch.tensor([0]) + torch.testing.assert_close( + duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 + ) if __name__ == "__main__":