From f51d283531b5b7697ee4b2e6c58fbded993c1d7d Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 13 Nov 2024 08:45:15 -0800 Subject: [PATCH 1/2] Fix colorspace --- src/torchcodec/decoders/_core/CudaDevice.cpp | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 8aa464e4e..75dea7d8f 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -223,13 +223,24 @@ void convertAVFrameToDecodedOutputOnCuda( Npp8u* input[2] = {src->data[0], src->data[1]}; auto start = std::chrono::high_resolution_clock::now(); - NppStatus status = nppiNV12ToRGB_8u_P2C3R( - input, - src->linesize[0], - static_cast<Npp8u*>(dst.data_ptr()), - dst.stride(0), - oSizeROI); + NppStatus status; + if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) { + status = nppiNV12ToRGB_709HDTV_8u_P2C3R( + input, + src->linesize[0], + static_cast<Npp8u*>(dst.data_ptr()), + dst.stride(0), + oSizeROI); + } else { + status = nppiNV12ToRGB_8u_P2C3R( + input, + src->linesize[0], + static_cast<Npp8u*>(dst.data_ptr()), + dst.stride(0), + oSizeROI); + } TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + // Make the pytorch stream wait for the npp kernel to finish before using the // output.
at::cuda::CUDAEvent nppDoneEvent; From d7db262127b25e9f9969db4b319a7a55ce5c00b5 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 13 Nov 2024 09:14:34 -0800 Subject: [PATCH 2/2] Improve color accuracy of BT709 frames on CUDA --- test/decoders/test_video_decoder_ops.py | 6 ++++-- test/utils.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index eb3b55db6..ec56d9fe9 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -118,8 +118,10 @@ def test_get_frame_at_pts(self, device): # return the next frame since the right boundary of the interval is # open. next_frame, _, _ = get_frame_at_pts(decoder, 6.039367) - with pytest.raises(AssertionError): - frame_compare_function(next_frame, reference_frame6.to(device)) + if device == "cpu": + # We can only compare exact equality on CPU. + with pytest.raises(AssertionError): + frame_compare_function(next_frame, reference_frame6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_index(self, device): diff --git a/test/utils.py b/test/utils.py index 89957da7a..0fa843851 100644 --- a/test/utils.py +++ b/test/utils.py @@ -44,7 +44,7 @@ def assert_tensor_equal(*args, **kwargs): # Asserts that at least `percentage`% of the values are within the absolute tolerance. def assert_tensor_close_on_at_least( - actual_tensor, ref_tensor, percentage=90, abs_tolerance=20 + actual_tensor, ref_tensor, percentage=90, abs_tolerance=19 ): assert ( actual_tensor.device == ref_tensor.device