diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 75dea7d8f..5d48d26a1 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -225,7 +225,7 @@ void convertAVFrameToDecodedOutputOnCuda( auto start = std::chrono::high_resolution_clock::now(); NppStatus status; if (src->colorspace == AVColorSpace::AVCOL_SPC_BT709) { - status = nppiNV12ToRGB_709HDTV_8u_P2C3R( + status = nppiNV12ToRGB_709CSC_8u_P2C3R( input, src->linesize[0], static_cast(dst.data_ptr()), diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 4ad6fed13..854ea91ca 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -15,7 +15,6 @@ assert_tensor_close, assert_tensor_equal, cpu_and_cuda, - get_frame_compare_function, H265_VIDEO, NASA_VIDEO, ) @@ -74,11 +73,10 @@ def test_getitem_int(self, num_ffmpeg_threads, device): ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180).to(device) ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289).to(device) - frame_compare_function = get_frame_compare_function(device) - frame_compare_function(ref_frame0, decoder[0]) - frame_compare_function(ref_frame1, decoder[1]) - frame_compare_function(ref_frame180, decoder[180]) - frame_compare_function(ref_frame_last, decoder[-1]) + assert_tensor_equal(ref_frame0, decoder[0]) + assert_tensor_equal(ref_frame1, decoder[1]) + assert_tensor_equal(ref_frame180, decoder[180]) + assert_tensor_equal(ref_frame_last, decoder[-1]) def test_getitem_numpy_int(self): decoder = VideoDecoder(NASA_VIDEO.path) @@ -113,7 +111,6 @@ def test_getitem_numpy_int(self): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_getitem_slice(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) - frame_compare_function = get_frame_compare_function(device) # ensure that the degenerate case of a range of size 1 works @@ -127,7 +124,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref0, slice0) + assert_tensor_equal(ref0, slice0) ref4 = NASA_VIDEO.get_frame_data_by_range(4, 5).to(device) slice4 = decoder[4:5] @@ -139,7 +136,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref4, slice4) + assert_tensor_equal(ref4, slice4) ref8 = NASA_VIDEO.get_frame_data_by_range(8, 9).to(device) slice8 = decoder[8:9] @@ -151,7 +148,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref8, slice8) + assert_tensor_equal(ref8, slice8) ref180 = NASA_VIDEO.get_frame_data_by_index(180).to(device) slice180 = decoder[180:181] @@ -163,7 +160,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref180, slice180[0]) + assert_tensor_equal(ref180, slice180[0]) # contiguous ranges ref0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9).to(device) @@ -176,7 +173,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref0_9, slice0_9) + assert_tensor_equal(ref0_9, slice0_9) ref4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8).to(device) slice4_8 = decoder[4:8] @@ -188,7 +185,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref4_8, slice4_8) + assert_tensor_equal(ref4_8, slice4_8) # ranges with a stride ref15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5).to(device) @@ -201,7 +198,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref15_35, slice15_35) + assert_tensor_equal(ref15_35, slice15_35) ref0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2).to(device) slice0_9_2 = decoder[0:9:2] @@ -213,7 +210,7 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref0_9_2, slice0_9_2) + assert_tensor_equal(ref0_9_2, slice0_9_2) # negative numbers in the slice ref386_389 = NASA_VIDEO.get_frame_data_by_range(386, 390).to(device) @@ -226,15 +223,15 @@ def test_getitem_slice(self, device): NASA_VIDEO.width, ] ) - frame_compare_function(ref386_389, slice386_389) + assert_tensor_equal(ref386_389, slice386_389) # an empty range is valid! empty_frame = decoder[5:5] - frame_compare_function(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) + assert_tensor_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) # slices that are out-of-range are also valid - they return an empty tensor also_empty = decoder[10000:] - frame_compare_function(also_empty, NASA_VIDEO.empty_chw_tensor.to(device)) + assert_tensor_equal(also_empty, NASA_VIDEO.empty_chw_tensor.to(device)) # should be just a copy all_frames = decoder[:].to(device) @@ -247,7 +244,7 @@ def test_getitem_slice(self, device): ] ) for sliced, ref in zip(all_frames, decoder): - frame_compare_function(sliced, ref) + assert_tensor_equal(sliced, ref) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_getitem_fails(self, device): @@ -275,27 +272,26 @@ def test_iteration(self, device): ref_frame35 = NASA_VIDEO.get_frame_data_by_index(35).to(device) ref_frame180 = NASA_VIDEO.get_frame_data_by_index(180).to(device) ref_frame_last = NASA_VIDEO.get_frame_data_by_index(289).to(device) - frame_compare_function = get_frame_compare_function(device) # Access an arbitrary frame to make sure that the later iteration # still works as expected. The underlying C++ decoder object is # actually stateful, and accessing a frame will move its internal # cursor. - frame_compare_function(ref_frame35, decoder[35]) + assert_tensor_equal(ref_frame35, decoder[35]) for i, frame in enumerate(decoder): if i == 0: - frame_compare_function(ref_frame0, frame) + assert_tensor_equal(ref_frame0, frame) elif i == 1: - frame_compare_function(ref_frame1, frame) + assert_tensor_equal(ref_frame1, frame) elif i == 9: - frame_compare_function(ref_frame9, frame) + assert_tensor_equal(ref_frame9, frame) elif i == 35: - frame_compare_function(ref_frame35, frame) + assert_tensor_equal(ref_frame35, frame) elif i == 180: - frame_compare_function(ref_frame180, frame) + assert_tensor_equal(ref_frame180, frame) elif i == 389: - frame_compare_function(ref_frame_last, frame) + assert_tensor_equal(ref_frame_last, frame) def test_iteration_slow(self): decoder = VideoDecoder(NASA_VIDEO.path) @@ -314,12 +310,11 @@ def test_iteration_slow(self): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) - frame_compare_function = get_frame_compare_function(device) ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) frame9 = decoder.get_frame_at(9) - frame_compare_function(ref_frame9, frame9.data) + assert_tensor_equal(ref_frame9, frame9.data) assert isinstance(frame9.pts_seconds, float) expected_frame_info = NASA_VIDEO.get_frame_info(9) assert frame9.pts_seconds == pytest.approx(expected_frame_info.pts_seconds) @@ -330,19 +325,19 @@ def test_get_frame_at(self, device): # test numpy.int64 frame9 = decoder.get_frame_at(numpy.int64(9)) - frame_compare_function(ref_frame9, frame9.data) + assert_tensor_equal(ref_frame9, frame9.data) # test numpy.int32 frame9 = decoder.get_frame_at(numpy.int32(9)) - frame_compare_function(ref_frame9, frame9.data) + assert_tensor_equal(ref_frame9, frame9.data) # test numpy.uint64 frame9 = decoder.get_frame_at(numpy.uint64(9)) - frame_compare_function(ref_frame9, frame9.data) + assert_tensor_equal(ref_frame9, frame9.data) # test numpy.uint32 frame9 = decoder.get_frame_at(numpy.uint32(9)) - frame_compare_function(ref_frame9, frame9.data) + assert_tensor_equal(ref_frame9, frame9.data) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_tuple_unpacking(self, device): @@ -368,16 +363,15 @@ def test_get_frame_at_fails(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) - frame_compare_function = get_frame_compare_function(device) frames = decoder.get_frames_at([35, 25]) assert isinstance(frames, FrameBatch) - frame_compare_function( + assert_tensor_equal( frames[0].data, NASA_VIDEO.get_frame_data_by_index(35).to(device) ) - frame_compare_function( + assert_tensor_equal( frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device) ) @@ -421,16 +415,15 @@ def test_get_frames_at_fails(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) - frame_compare_function = get_frame_compare_function(device) ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device) - frame_compare_function( + assert_tensor_equal( ref_frame_played_at_6, decoder.get_frame_played_at(6.006).data ) - frame_compare_function( + assert_tensor_equal( ref_frame_played_at_6, decoder.get_frame_played_at(6.02).data ) - frame_compare_function( + assert_tensor_equal( ref_frame_played_at_6, decoder.get_frame_played_at(6.039366).data ) assert isinstance(decoder.get_frame_played_at(6.02).pts_seconds, float) @@ -459,7 +452,6 @@ def test_get_frame_played_at_fails(self, device): def test_get_frames_played_at(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) - frame_compare_function = get_frame_compare_function(device) # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has # index 35. We use those indices as reference to test against. @@ -470,7 +462,7 @@ def test_get_frames_played_at(self, device): assert isinstance(frames, FrameBatch) for i in range(len(reference_indices)): - frame_compare_function( + assert_tensor_equal( frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i]).to(device), ) @@ -512,7 +504,6 @@ def test_get_frames_in_range(self, stream_index, device): decoder = VideoDecoder( NASA_VIDEO.path, stream_index=stream_index, device=device ) - frame_compare_function = get_frame_compare_function(device) # test degenerate case where we only actually get 1 frame ref_frames9 = NASA_VIDEO.get_frame_data_by_range( @@ -520,7 +511,7 @@ def test_get_frames_in_range(self, stream_index, device): ).to(device) frames9 = decoder.get_frames_in_range(start=9, stop=10) - frame_compare_function(ref_frames9, frames9.data) + assert_tensor_equal(ref_frames9, frames9.data) assert frames9.pts_seconds.device.type == "cpu" assert frames9.pts_seconds[0].item() == pytest.approx( @@ -546,7 +537,7 @@ def test_get_frames_in_range(self, stream_index, device): NASA_VIDEO.get_width(stream_index=stream_index), ] ) - frame_compare_function(ref_frames0_9, frames0_9.data) + assert_tensor_equal(ref_frames0_9, frames0_9.data) assert_tensor_close( NASA_VIDEO.get_pts_seconds_by_range(0, 10, stream_index=stream_index), frames0_9.pts_seconds, @@ -569,7 +560,7 @@ def test_get_frames_in_range(self, stream_index, device): NASA_VIDEO.get_width(stream_index=stream_index), ] ) - frame_compare_function(ref_frames0_8_2, frames0_8_2.data) + assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data) assert_tensor_close( NASA_VIDEO.get_pts_seconds_by_range(0, 10, 2, stream_index=stream_index), frames0_8_2.pts_seconds, @@ -585,7 +576,7 @@ def test_get_frames_in_range(self, stream_index, device): frames0_8_2 = decoder.get_frames_in_range( start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2) ) - frame_compare_function(ref_frames0_8_2, frames0_8_2.data) + assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data) # an empty range is valid! empty_frames = decoder.get_frames_in_range(5, 5) @@ -640,7 +631,6 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): decoder = VideoDecoder( NASA_VIDEO.path, stream_index=stream_index, device=device ) - frame_compare_function = get_frame_compare_function(device) # Note that we are comparing the results of VideoDecoder's method: # get_frames_played_in_range() @@ -663,7 +653,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): frames0_4 = decoder.get_frames_played_in_range( decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds ) - frame_compare_function( + assert_tensor_equal( frames0_4.data, NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to( device @@ -675,7 +665,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(4).pts_seconds + HALF_DURATION, ) - frame_compare_function(also_frames0_4.data, frames0_4.data) + assert_tensor_equal(also_frames0_4.data, frames0_4.data) # Again, the intention here is to provide the exact values we care about. In practice, our # pts values are slightly smaller, so we nudge the start upwards. @@ -683,7 +673,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): decoder.get_frame_at(5).pts_seconds, decoder.get_frame_at(10).pts_seconds, ) - frame_compare_function( + assert_tensor_equal( frames5_9.data, NASA_VIDEO.get_frame_data_by_range(5, 10, stream_index=stream_index).to( device @@ -697,7 +687,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): decoder.get_frame_at(6).pts_seconds, decoder.get_frame_at(6).pts_seconds + HALF_DURATION, ) - frame_compare_function( + assert_tensor_equal( frame6.data, NASA_VIDEO.get_frame_data_by_range(6, 7, stream_index=stream_index).to( device @@ -709,7 +699,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): decoder.get_frame_at(35).pts_seconds, decoder.get_frame_at(35).pts_seconds + 1e-10, ) - frame_compare_function( + assert_tensor_equal( frame35.data, NASA_VIDEO.get_frame_data_by_range(35, 36, stream_index=stream_index).to( device @@ -725,7 +715,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds + HALF_DURATION, ) - frame_compare_function( + assert_tensor_equal( frames7_8.data, NASA_VIDEO.get_frame_data_by_range(7, 9, stream_index=stream_index).to( device @@ -737,7 +727,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds, NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds, ) - frame_compare_function( + assert_tensor_equal( empty_frame.data, NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device), ) @@ -755,7 +745,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds + HALF_DURATION, ) - frame_compare_function( + assert_tensor_equal( frame0.data, NASA_VIDEO.get_frame_data_by_range(0, 1, stream_index=stream_index).to( device @@ -767,7 +757,7 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): all_frames = decoder.get_frames_played_in_range( decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds ) - frame_compare_function(all_frames.data, decoder[:]) + assert_tensor_equal(all_frames.data, decoder[:]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_by_pts_in_range_fails(self, device): diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index ec56d9fe9..bc3d0e62d 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -37,10 +37,8 @@ ) from ..utils import ( - assert_tensor_close_on_at_least, assert_tensor_equal, cpu_and_cuda, - get_frame_compare_function, NASA_AUDIO, NASA_VIDEO, needs_cuda, @@ -70,50 +68,47 @@ class TestOps: def test_seek_and_next(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - frame_compare_function(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) reference_frame1 = NASA_VIDEO.get_frame_data_by_index(1) frame1, _, _ = get_next_frame(decoder) - frame_compare_function(frame1, reference_frame1.to(device)) + assert_tensor_equal(frame1, reference_frame1.to(device)) seek_to_pts(decoder, 6.0) frame_time6, _, _ = get_next_frame(decoder) reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame_time6, reference_frame_time6.to(device)) + assert_tensor_equal(frame_time6, reference_frame_time6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_to_negative_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - frame_compare_function(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) seek_to_pts(decoder, -1e-4) frame0, _, _ = get_next_frame(decoder) - frame_compare_function(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_pts(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) # This frame has pts=6.006 and duration=0.033367, so it should be visible # at timestamps in the range [6.006, 6.039367) (not including the last timestamp). frame6, _, _ = get_frame_at_pts(decoder, 6.006) reference_frame6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame6, reference_frame6.to(device)) + assert_tensor_equal(frame6, reference_frame6.to(device)) frame6, _, _ = get_frame_at_pts(decoder, 6.02) - frame_compare_function(frame6, reference_frame6.to(device)) + assert_tensor_equal(frame6, reference_frame6.to(device)) frame6, _, _ = get_frame_at_pts(decoder, 6.039366) - frame_compare_function(frame6, reference_frame6.to(device)) + assert_tensor_equal(frame6, reference_frame6.to(device)) # Note that this timestamp is exactly on a frame boundary, so it should # return the next frame since the right boundary of the interval is # open. @@ -121,43 +116,40 @@ def test_get_frame_at_pts(self, device): if device == "cpu": # We can only compare exact equality on CPU. with pytest.raises(AssertionError): - frame_compare_function(next_frame, reference_frame6.to(device)) + assert_tensor_equal(next_frame, reference_frame6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) frame0, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=0) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - frame_compare_function(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) # The frame that is played at 6 seconds is frame 180 from a 0-based index. frame6, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=180) reference_frame6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame6, reference_frame6.to(device)) + assert_tensor_equal(frame6, reference_frame6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_with_info_at_index(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) frame6, pts, duration = get_frame_at_index( decoder, stream_index=3, frame_index=180 ) reference_frame6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame6, reference_frame6.to(device)) + assert_tensor_equal(frame6, reference_frame6.to(device)) assert pts.item() == pytest.approx(6.006, rel=1e-3) assert duration.item() == pytest.approx(0.03337, rel=1e-3) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at_indices(self, device): - frame_compare_function = get_frame_compare_function(device) decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder, device=device) @@ -168,8 +160,8 @@ def test_get_frames_at_indices(self, device): reference_frame180 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frames0and180[0], reference_frame0.to(device)) - frame_compare_function(frames0and180[1], reference_frame180.to(device)) + assert_tensor_equal(frames0and180[0], reference_frame0.to(device)) + assert_tensor_equal(frames0and180[1], reference_frame180.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_at_indices_unsorted_indices(self, device): @@ -297,7 +289,6 @@ def test_pts_apis_against_index_ref(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frames_in_range(self, device): - frame_compare_function = get_frame_compare_function(device) decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) add_video_stream(decoder, device=device) @@ -305,57 +296,56 @@ def test_get_frames_in_range(self, device): # ensure that the degenerate case of a range of size 1 works ref_frame0 = NASA_VIDEO.get_frame_data_by_range(0, 1) bulk_frame0, *_ = get_frames_in_range(decoder, stream_index=3, start=0, stop=1) - frame_compare_function(bulk_frame0, ref_frame0.to(device)) + assert_tensor_equal(bulk_frame0, ref_frame0.to(device)) ref_frame1 = NASA_VIDEO.get_frame_data_by_range(1, 2) bulk_frame1, *_ = get_frames_in_range(decoder, stream_index=3, start=1, stop=2) - frame_compare_function(bulk_frame1, ref_frame1.to(device)) + assert_tensor_equal(bulk_frame1, ref_frame1.to(device)) ref_frame389 = NASA_VIDEO.get_frame_data_by_range(389, 390) bulk_frame389, *_ = get_frames_in_range( decoder, stream_index=3, start=389, stop=390 ) - frame_compare_function(bulk_frame389, ref_frame389.to(device)) + assert_tensor_equal(bulk_frame389, ref_frame389.to(device)) # contiguous ranges ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(0, 9) bulk_frames0_9, *_ = get_frames_in_range( decoder, stream_index=3, start=0, stop=9 ) - frame_compare_function(bulk_frames0_9, ref_frames0_9.to(device)) + assert_tensor_equal(bulk_frames0_9, ref_frames0_9.to(device)) ref_frames4_8 = NASA_VIDEO.get_frame_data_by_range(4, 8) bulk_frames4_8, *_ = get_frames_in_range( decoder, stream_index=3, start=4, stop=8 ) - frame_compare_function(bulk_frames4_8, ref_frames4_8.to(device)) + assert_tensor_equal(bulk_frames4_8, ref_frames4_8.to(device)) # ranges with a stride ref_frames15_35 = NASA_VIDEO.get_frame_data_by_range(15, 36, 5) bulk_frames15_35, *_ = get_frames_in_range( decoder, stream_index=3, start=15, stop=36, step=5 ) - frame_compare_function(bulk_frames15_35, ref_frames15_35.to(device)) + assert_tensor_equal(bulk_frames15_35, ref_frames15_35.to(device)) ref_frames0_9_2 = NASA_VIDEO.get_frame_data_by_range(0, 9, 2) bulk_frames0_9_2, *_ = get_frames_in_range( decoder, stream_index=3, start=0, stop=9, step=2 ) - frame_compare_function(bulk_frames0_9_2, ref_frames0_9_2.to(device)) + assert_tensor_equal(bulk_frames0_9_2, ref_frames0_9_2.to(device)) # an empty range is valid! empty_frame, *_ = get_frames_in_range(decoder, stream_index=3, start=5, stop=5) - frame_compare_function(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) + assert_tensor_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_throws_exception_at_eof(self, device): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) seek_to_pts(decoder, 12.979633) last_frame, _, _ = get_next_frame(decoder) reference_last_frame = NASA_VIDEO.get_frame_data_by_index(289) - frame_compare_function(last_frame, reference_last_frame.to(device)) + assert_tensor_equal(last_frame, reference_last_frame.to(device)) with pytest.raises(IndexError, match="no more frames"): get_next_frame(decoder) @@ -384,14 +374,13 @@ def get_frame1_and_frame_time6(decoder): # NB: create needs to happen outside the torch.compile region, # for now. Otherwise torch.compile constant-props it. decoder = create_from_file(str(NASA_VIDEO.path)) - frame_compare_function = get_frame_compare_function(device) frame0, frame_time6 = get_frame1_and_frame_time6(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame0, reference_frame0.to(device)) - frame_compare_function(frame_time6, reference_frame_time6.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame_time6, reference_frame_time6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_class_based_compile_seek_and_next(self, device): @@ -406,14 +395,13 @@ def class_based_get_frame1_and_frame_time6( return frame0, frame_time6 decoder = ReferenceDecoder(device=device) - frame_compare_function = get_frame_compare_function(device) frame0, frame_time6 = class_based_get_frame1_and_frame_time6(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame0, reference_frame0.to(device)) - frame_compare_function(frame_time6, reference_frame_time6.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame_time6, reference_frame_time6.to(device)) @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes")) @@ -431,19 +419,18 @@ def test_create_decoder(self, create_from, device): decoder = create_from_bytes(video_bytes) add_video_stream(decoder, device=device) - frame_compare_function = get_frame_compare_function(device) frame0, _, _ = get_next_frame(decoder) reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - frame_compare_function(frame0, reference_frame0.to(device)) + assert_tensor_equal(frame0, reference_frame0.to(device)) reference_frame1 = NASA_VIDEO.get_frame_data_by_index(1) frame1, _, _ = get_next_frame(decoder) - frame_compare_function(frame1, reference_frame1.to(device)) + assert_tensor_equal(frame1, reference_frame1.to(device)) seek_to_pts(decoder, 6.0) frame_time6, _, _ = get_next_frame(decoder) reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS ) - frame_compare_function(frame_time6, reference_frame_time6.to(device)) + assert_tensor_equal(frame_time6, reference_frame_time6.to(device)) # Keeping the metadata tests below for now, but we should remove them # once we remove get_json_metadata(). @@ -703,12 +690,8 @@ def test_cuda_decoder(self): add_video_stream(decoder, device="cuda") frame0, pts, duration = get_next_frame(decoder) assert frame0.device.type == "cuda" - frame0_cpu = frame0.to("cpu") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - # GPU decode is not bit-accurate. So we allow some tolerance. - assert_tensor_close_on_at_least(frame0_cpu, reference_frame0) - diff = (reference_frame0.float() - frame0_cpu.float()).abs() - assert (diff > 20).float().mean() <= 0.003 + assert_tensor_equal(frame0, reference_frame0.to("cuda")) assert pts == torch.tensor([0]) torch.testing.assert_close( duration, torch.tensor(0.0334).double(), atol=0, rtol=1e-3 diff --git a/test/samplers/test_video_clip_sampler.py b/test/samplers/test_video_clip_sampler.py index 963a1c9ad..01cd55f1e 100644 --- a/test/samplers/test_video_clip_sampler.py +++ b/test/samplers/test_video_clip_sampler.py @@ -10,7 +10,7 @@ VideoClipSampler, ) -from ..utils import assert_tensor_equal, NASA_VIDEO +from ..utils import NASA_VIDEO @pytest.mark.parametrize( @@ -36,7 +36,7 @@ def test_sampler(sampler_args): video_args = VideoArgs(desired_width=desired_width, desired_height=desired_height) sampler = VideoClipSampler(video_args, sampler_args) clips = sampler(NASA_VIDEO.to_tensor()) - assert_tensor_equal(len(clips), sampler_args.clips_per_video) + assert len(clips) == sampler_args.clips_per_video clip = clips[0] if isinstance(sampler_args, TimeBasedSamplerArgs): # Note: Looks like we have an API inconsistency. diff --git a/test/utils.py b/test/utils.py index 0fa843851..d0663b0a8 100644 --- a/test/utils.py +++ b/test/utils.py @@ -23,42 +23,24 @@ def cpu_and_cuda(): return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) -def get_frame_compare_function(device): - if device == "cpu": - return assert_tensor_equal - else: - return assert_tensor_close_on_at_least - - -# For use with decoded data frames. On Linux, we expect exact, bit-for-bit equality. On -# all other platforms, we allow a small tolerance. FFmpeg does not guarantee bit-for-bit -# equality across systems and architectures, so we also cannot. We currently use Linux -# on x86_64 as our reference system. +# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit +# equality. On CUDA Linux, we expect a small tolerance. +# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does +# not guarantee bit-for-bit equality across systems and architectures, so we +# also cannot. We currently use Linux on x86_64 as our reference system. def assert_tensor_equal(*args, **kwargs): if sys.platform == "linux": - absolute_tolerance = 0 + if args[0].device.type == "cuda": + # CUDA tensors are not exactly equal on Linux, so we need to use a + # higher tolerance. + absolute_tolerance = 2 + else: + absolute_tolerance = 0 else: absolute_tolerance = 3 torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0) -# Asserts that at least `percentage`% of the values are within the absolute tolerance. -def assert_tensor_close_on_at_least( - actual_tensor, ref_tensor, percentage=90, abs_tolerance=19 -): - assert ( - actual_tensor.device == ref_tensor.device - ), f"Devices don't match: {actual_tensor.device} vs {ref_tensor.device}" - diff = (ref_tensor.float() - actual_tensor.float()).abs() - max_diff_percentage = 100.0 - percentage - if diff.sum() == 0: - return - diff_percentage = (diff > abs_tolerance).float().mean() * 100.0 - assert ( - diff_percentage <= max_diff_percentage - ), f"Diff too high: {diff_percentage} > {max_diff_percentage}" - - # For use with floating point metadata, or in other instances where we are not confident # that reference and test tensors can be exactly equal. This is true for pts and duration # in seconds, as the reference values are from ffprobe's JSON output. In that case, it is