From 851b399b008e27e69f62fdfbfef255503f746360 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 29 Oct 2024 11:53:19 -0700 Subject: [PATCH 1/8] . --- src/torchcodec/decoders/_core/CudaDevice.cpp | 19 +++++++++++-- .../decoders/_core/DeviceInterface.h | 3 +- .../decoders/_core/VideoDecoder.cpp | 7 +++-- test/decoders/test_video_decoder_ops.py | 28 +++++++++++++++++-- 4 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index dea0e7293..d86c374bc 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -201,7 +201,8 @@ void convertAVFrameToDecodedOutputOnCuda( const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, VideoDecoder::RawDecodedOutput& rawOutput, - VideoDecoder::DecodedOutput& output) { + VideoDecoder::DecodedOutput& output, + std::optional preAllocatedOutputTensor) { AVFrame* src = rawOutput.frame.get(); TORCH_CHECK( @@ -213,7 +214,21 @@ void convertAVFrameToDecodedOutputOnCuda( NppiSize oSizeROI = {width, height}; Npp8u* input[2] = {src->data[0], src->data[1]}; torch::Tensor& dst = output.frame; - dst = allocateDeviceTensor({height, width, 3}, options.device); + if (preAllocatedOutputTensor.has_value()) { + dst = preAllocatedOutputTensor.value(); + auto shape = dst.sizes(); + TORCH_CHECK( + (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && + (shape[2] == 3), + "Expected tensor of shape ", + height, + "x", + width, + "x3, got ", + shape); + } else { + dst = allocateDeviceTensor({height, width, 3}, options.device); + } // Use the user-requested GPU for running the NPP kernel. c10::cuda::CUDAGuard deviceGuard(device); diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index e557e6232..772bdfe63 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -37,7 +37,8 @@ void convertAVFrameToDecodedOutputOnCuda( const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, VideoDecoder::RawDecodedOutput& rawOutput, - VideoDecoder::DecodedOutput& output); + VideoDecoder::DecodedOutput& output, + std::optional preAllocatedOutputTensor = std::nullopt); void releaseContextOnCuda( const torch::Device& device, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 76c744936..422b98bd6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -196,7 +196,7 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( options.height.value_or(*metadata.height), options.width.value_or(*metadata.width), 3}, - {torch::kUInt8})), + at::TensorOptions(options.device).dtype(torch::kUInt8))), ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {} @@ -859,13 +859,14 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( convertAVFrameToDecodedOutputOnCPU( rawOutput, output, preAllocatedOutputTensor); } else if (streamInfo.options.device.type() == torch::kCUDA) { - // TODO: handle pre-allocated output tensor + // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput. 
convertAVFrameToDecodedOutputOnCuda( streamInfo.options.device, streamInfo.options, streamInfo.codecContext.get(), rawOutput, - output); + output, + preAllocatedOutputTensor); } else { TORCH_CHECK( false, "Invalid device type: " + streamInfo.options.device.str()); diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 53b6a61c2..75ddc8202 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -57,6 +57,12 @@ def seek(self, pts: float): seek_to_pts(self.decoder, pts) +# Asserts that at most percentage of the elements are different by more than abs_tolerance. +def assert_tensor_nearly_equal(frame1, frame2, percentage=0.3, abs_tolerance=20): + diff = (frame2.float() - frame1.float()).abs() + assert (diff > abs_tolerance).float().mean() <= percentage / 100.0 + + class TestOps: def test_seek_and_next(self): decoder = create_from_file(str(NASA_VIDEO.path)) @@ -137,6 +143,24 @@ def test_get_frames_at_indices(self): assert_tensor_equal(frames0and180[0], reference_frame0) assert_tensor_equal(frames0and180[1], reference_frame180) + @needs_cuda + def test_get_frames_at_indices_with_cuda(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + scan_all_streams_to_update_metadata(decoder) + add_video_stream(decoder, device="cuda") + frames0and180, *_ = get_frames_at_indices( + decoder, stream_index=3, frame_indices=[0, 180] + ) + reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) + reference_frame180 = NASA_VIDEO.get_frame_data_by_index( + INDEX_OF_FRAME_AT_6_SECONDS + ) + assert frames0and180.device.type == "cuda" + assert_tensor_nearly_equal(frames0and180[0].to("cpu"), reference_frame0) + assert_tensor_nearly_equal( + frames0and180[1].to("cpu"), reference_frame180, 0.3, 30 + ) + def test_get_frames_at_indices_unsorted_indices(self): decoder = create_from_file(str(NASA_VIDEO.path)) _add_video_stream(decoder) @@ -657,8 +681,8 @@ def test_cuda_decoder(self): assert frame0.device.type == "cuda" frame0_cpu = frame0.to("cpu") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - # GPU decode is not bit-accurate. In the following assertion we ensure - # not more than 0.3% of values have a difference greater than 20. + # GPU decode is not bit-accurate. So we allow some tolerance. + assert_tensor_nearly_equal(frame0_cpu, reference_frame0) diff = (reference_frame0.float() - frame0_cpu.float()).abs() assert (diff > 20).float().mean() <= 0.003 assert pts == torch.tensor([0]) From 328ea073c94c7eacc19564d6b1d0064fccbdf70d Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 29 Oct 2024 12:01:46 -0700 Subject: [PATCH 2/8] . --- test/decoders/test_video_decoder_ops.py | 14 +++++++------- test/utils.py | 6 ++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 75ddc8202..f323aef19 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -36,7 +36,13 @@ seek_to_pts, ) -from ..utils import assert_tensor_equal, NASA_AUDIO, NASA_VIDEO, needs_cuda +from ..utils import ( + assert_tensor_equal, + assert_tensor_nearly_equal, + NASA_AUDIO, + NASA_VIDEO, + needs_cuda, +) torch._dynamo.config.capture_dynamic_output_shape_ops = True @@ -57,12 +63,6 @@ def seek(self, pts: float): seek_to_pts(self.decoder, pts) -# Asserts that at most percentage of the elements are different by more than abs_tolerance. 
-def assert_tensor_nearly_equal(frame1, frame2, percentage=0.3, abs_tolerance=20): - diff = (frame2.float() - frame1.float()).abs() - assert (diff > abs_tolerance).float().mean() <= percentage / 100.0 - - class TestOps: def test_seek_and_next(self): decoder = create_from_file(str(NASA_VIDEO.path)) diff --git a/test/utils.py b/test/utils.py index 573970731..5464aa339 100644 --- a/test/utils.py +++ b/test/utils.py @@ -33,6 +33,12 @@ def assert_tensor_equal(*args, **kwargs): torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0) +# Asserts that at most percentage of the elements are different by more than abs_tolerance. +def assert_tensor_nearly_equal(frame1, frame2, percentage=0.3, abs_tolerance=20): + diff = (frame2.float() - frame1.float()).abs() + assert (diff > abs_tolerance).float().mean() <= percentage / 100.0 + + # For use with floating point metadata, or in other instances where we are not confident # that reference and test tensors can be exactly equal. This is true for pts and duration # in seconds, as the reference values are from ffprobe's JSON output. In that case, it is From 2b47104ba9e21c7ce9b3a1c96507a00df1908cd4 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 29 Oct 2024 12:16:35 -0700 Subject: [PATCH 3/8] . --- test/decoders/test_video_decoder_ops.py | 34 +++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index f323aef19..33e004c18 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -222,6 +222,40 @@ def test_get_frames_by_pts(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + # TODO: Figure out how to parameterize this test to run on both CPU and CUDA.abs + # The question is how to have the @needs_cuda decorator with the pytest.mark.parametrize + # decorator on the same test. + @needs_cuda + def test_get_frames_by_pts_with_cuda(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + _add_video_stream(decoder, device="cuda") + scan_all_streams_to_update_metadata(decoder) + stream_index = 3 + + # Note: 13.01 should give the last video frame for the NASA video + timestamps = [2, 0, 1, 0 + 1e-3, 13.01, 2 + 1e-3] + + expected_frames = [ + get_frame_at_pts(decoder, seconds=pts)[0] for pts in timestamps + ] + + frames, *_ = get_frames_by_pts( + decoder, + stream_index=stream_index, + timestamps=timestamps, + ) + for frame, expected_frame in zip(frames, expected_frames): + assert_tensor_equal(frame, expected_frame) + + # first and last frame should be equal, at pts=2 [+ eps]. We then modify + # the first frame and assert that it's now different from the last + # frame. This ensures a copy was properly made during the de-duplication + # logic. + assert_tensor_equal(frames[0], frames[-1]) + frames[0] += 20 + with pytest.raises(AssertionError): + assert_tensor_equal(frames[0], frames[-1]) + def test_pts_apis_against_index_ref(self): # Non-regression test for https://github.com/pytorch/torchcodec/pull/287 # Get all frames in the video, then query all frames with all time-based From 20889340264197c94b39b8420e19a4f865f09356 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 29 Oct 2024 13:10:07 -0700 Subject: [PATCH 4/8] . 
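This commit threads a `device` argument through the Python `VideoDecoder` and the
sampler benchmark. A rough usage sketch (illustrative only: the import paths are
assumed from the package layout and a CUDA-enabled build is required):

    from torchcodec.decoders import VideoDecoder
    from torchcodec.samplers import clips_at_random_indices

    # Decode and sample directly on the GPU via the new `device` argument.
    decoder = VideoDecoder("test/resources/nasa_13013.mp4", device="cuda")
    clips = clips_at_random_indices(decoder, num_clips=1, num_frames_per_clip=10)

The benchmark itself now takes the device from the command line, e.g.
`python benchmarks/samplers/benchmark_samplers.py --device cuda`.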
--- benchmarks/samplers/benchmark_samplers.py | 104 +++++++++++------- .../decoders/_core/CPUOnlyDevice.cpp | 3 +- src/torchcodec/decoders/_video_decoder.py | 6 +- 3 files changed, 70 insertions(+), 43 deletions(-) diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py index 1ea363ed3..59d1d79eb 100644 --- a/benchmarks/samplers/benchmark_samplers.py +++ b/benchmarks/samplers/benchmark_samplers.py @@ -1,3 +1,4 @@ +import argparse from pathlib import Path from time import perf_counter_ns @@ -45,8 +46,7 @@ def report_stats(times, num_frames, unit="ms"): return med, fps -def sample(sampler, **kwargs): - decoder = VideoDecoder(VIDEO_PATH) +def sample(decoder, sampler, **kwargs): return sampler( decoder, num_frames_per_clip=10, @@ -54,42 +54,64 @@ def sample(sampler, **kwargs): ) -VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4" -NUM_EXP = 30 - -for num_clips in (1, 50): - print("-" * 10) - print(f"{num_clips = }") - - print("clips_at_random_indices ", end="") - times, num_frames = bench( - sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2 - ) - report_stats(times, num_frames, unit="ms") - - print("clips_at_regular_indices ", end="") - times, num_frames = bench( - sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2 - ) - report_stats(times, num_frames, unit="ms") - - print("clips_at_random_timestamps ", end="") - times, num_frames = bench( - sample, - clips_at_random_timestamps, - num_clips=num_clips, - num_exp=NUM_EXP, - warmup=2, - ) - report_stats(times, num_frames, unit="ms") - - print("clips_at_regular_timestamps ", end="") - seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long - times, num_frames = bench( - sample, - clips_at_regular_timestamps, - seconds_between_clip_starts=seconds_between_clip_starts, - num_exp=NUM_EXP, - warmup=2, - ) - report_stats(times, num_frames, unit="ms") +def main(device, video): + NUM_EXP = 30 + + for num_clips in (1, 50): + print("-" * 10) + print(f"{num_clips = }") + + print("clips_at_random_indices ", end="") + decoder = VideoDecoder(video, device=device) + times, num_frames = bench( + sample, + decoder, + clips_at_random_indices, + num_clips=num_clips, + num_exp=NUM_EXP, + warmup=2, + ) + report_stats(times, num_frames, unit="ms") + + print("clips_at_regular_indices ", end="") + times, num_frames = bench( + sample, + decoder, + clips_at_regular_indices, + num_clips=num_clips, + num_exp=NUM_EXP, + warmup=2, + ) + report_stats(times, num_frames, unit="ms") + + print("clips_at_random_timestamps ", end="") + times, num_frames = bench( + sample, + decoder, + clips_at_random_timestamps, + num_clips=num_clips, + num_exp=NUM_EXP, + warmup=2, + ) + report_stats(times, num_frames, unit="ms") + + print("clips_at_regular_timestamps ", end="") + seconds_between_clip_starts = 13 / num_clips # approximate. 
video is 13s long + times, num_frames = bench( + sample, + decoder, + clips_at_regular_timestamps, + seconds_between_clip_starts=seconds_between_clip_starts, + num_exp=NUM_EXP, + warmup=2, + ) + report_stats(times, num_frames, unit="ms") + + +if __name__ == "__main__": + DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4" + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH)) + args = parser.parse_args() + main(args.device, args.video) diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index 910ed4a6b..20a4e3803 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -19,7 +19,8 @@ void convertAVFrameToDecodedOutputOnCuda( const VideoDecoder::VideoStreamDecoderOptions& options, AVCodecContext* codecContext, VideoDecoder::RawDecodedOutput& rawOutput, - VideoDecoder::DecodedOutput& output) { + VideoDecoder::DecodedOutput& output, + std::optional preAllocatedOutputTensor) { throwUnsupportedDeviceError(device); } diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index e401dd457..952e7d6cb 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Literal, Optional, Tuple, Union -from torch import Tensor +from torch import device, Tensor from torchcodec import Frame, FrameBatch from torchcodec.decoders import _core as core @@ -41,6 +41,8 @@ class VideoDecoder: instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded decoding which is best if you are running a single instance of ``VideoDecoder``. Default: 1. + device (str or torch.device, optional): The device to use for decoding. + .. note:: @@ -64,6 +66,7 @@ def __init__( stream_index: Optional[int] = None, dimension_order: Literal["NCHW", "NHWC"] = "NCHW", num_ffmpeg_threads: int = 1, + device: Optional[Union[str, device]] = "cpu", ): if isinstance(source, str): self._decoder = core.create_from_file(source) @@ -92,6 +95,7 @@ def __init__( stream_index=stream_index, dimension_order=dimension_order, num_threads=num_ffmpeg_threads, + device=device, ) self.metadata, self.stream_index = _get_and_validate_stream_metadata( From 85721ddaaa2eefb464548bc11e46500a79894a3a Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 30 Oct 2024 06:57:35 -0700 Subject: [PATCH 5/8] . --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 2 +- src/torchcodec/decoders/_video_decoder.py | 15 +++++++-------- test/decoders/test_video_decoder_ops.py | 8 ++++---- test/utils.py | 7 ++++--- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 422b98bd6..c6213a165 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -855,11 +855,11 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.duration = getDuration(frame); output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); + // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput. 
if (streamInfo.options.device.type() == torch::kCPU) { convertAVFrameToDecodedOutputOnCPU( rawOutput, output, preAllocatedOutputTensor); } else if (streamInfo.options.device.type() == torch::kCUDA) { - // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput. convertAVFrameToDecodedOutputOnCuda( streamInfo.options.device, streamInfo.options, diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 952e7d6cb..bd1022de2 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -36,14 +36,6 @@ class VideoDecoder: This can be either "NCHW" (default) or "NHWC", where N is the batch size, C is the number of channels, H is the height, and W is the width of the frames. - num_ffmpeg_threads (int, optional): The number of threads to use for decoding. - Use 1 for single-threaded decoding which may be best if you are running multiple - instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded - decoding which is best if you are running a single instance of ``VideoDecoder``. - Default: 1. - device (str or torch.device, optional): The device to use for decoding. - - .. note:: Frames are natively decoded in NHWC format by the underlying @@ -51,6 +43,13 @@ class VideoDecoder: cheap no-copy operation that allows these frames to be transformed using the `torchvision transforms `_. + num_ffmpeg_threads (int, optional): The number of threads to use for decoding. + Use 1 for single-threaded decoding which may be best if you are running multiple + instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded + decoding which is best if you are running a single instance of ``VideoDecoder``. + Default: 1. + device (str or torch.device, optional): The device to use for decoding. Default: "cpu". + Attributes: metadata (VideoStreamMetadata): Metadata of the video stream. diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 33e004c18..825405b74 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -37,8 +37,8 @@ ) from ..utils import ( + assert_tensor_close_on_at_least, assert_tensor_equal, - assert_tensor_nearly_equal, NASA_AUDIO, NASA_VIDEO, needs_cuda, @@ -156,8 +156,8 @@ def test_get_frames_at_indices_with_cuda(self): INDEX_OF_FRAME_AT_6_SECONDS ) assert frames0and180.device.type == "cuda" - assert_tensor_nearly_equal(frames0and180[0].to("cpu"), reference_frame0) - assert_tensor_nearly_equal( + assert_tensor_close_on_at_least(frames0and180[0].to("cpu"), reference_frame0) + assert_tensor_close_on_at_least( frames0and180[1].to("cpu"), reference_frame180, 0.3, 30 ) @@ -716,7 +716,7 @@ def test_cuda_decoder(self): frame0_cpu = frame0.to("cpu") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) # GPU decode is not bit-accurate. So we allow some tolerance. - assert_tensor_nearly_equal(frame0_cpu, reference_frame0) + assert_tensor_close_on_at_least(frame0_cpu, reference_frame0) diff = (reference_frame0.float() - frame0_cpu.float()).abs() assert (diff > 20).float().mean() <= 0.003 assert pts == torch.tensor([0]) diff --git a/test/utils.py b/test/utils.py index 5464aa339..e4b50260a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -33,10 +33,11 @@ def assert_tensor_equal(*args, **kwargs): torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0) -# Asserts that at most percentage of the elements are different by more than abs_tolerance. 
-def assert_tensor_nearly_equal(frame1, frame2, percentage=0.3, abs_tolerance=20): +# Asserts that at least `percentage`% of the values are within the absolute tolerance. +def assert_tensor_close_on_at_least(frame1, frame2, percentage=99.7, abs_tolerance=20): diff = (frame2.float() - frame1.float()).abs() - assert (diff > abs_tolerance).float().mean() <= percentage / 100.0 + diff_percentage = 100.0 - percentage + assert (diff > abs_tolerance).float().mean() <= diff_percentage / 100.0 # For use with floating point metadata, or in other instances where we are not confident From 5597543b09a589a9da5c0c9b6f204ce55cd31eee Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 30 Oct 2024 07:59:26 -0700 Subject: [PATCH 6/8] . --- benchmarks/samplers/benchmark_samplers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py index 59d1d79eb..b9b27bc2c 100644 --- a/benchmarks/samplers/benchmark_samplers.py +++ b/benchmarks/samplers/benchmark_samplers.py @@ -54,7 +54,7 @@ def sample(decoder, sampler, **kwargs): ) -def main(device, video): +def run_sampler_benchmarks(device, video): NUM_EXP = 30 for num_clips in (1, 50): @@ -108,10 +108,14 @@ def main(device, video): report_stats(times, num_frames, unit="ms") -if __name__ == "__main__": +def main(): DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4" parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH)) args = parser.parse_args() - main(args.device, args.video) + run_sampler_benchmarks(args.device, args.video) + + +if __name__ == "__main__": + main() From 74a1d24834a2bb6fefe90dafecb7163f0024beb7 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 30 Oct 2024 08:17:36 -0700 Subject: [PATCH 7/8] . --- .github/workflows/linux_cuda_wheel.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml index 842013ec6..b4c5b69ad 100644 --- a/.github/workflows/linux_cuda_wheel.yaml +++ b/.github/workflows/linux_cuda_wheel.yaml @@ -89,7 +89,7 @@ jobs: # For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does. # So we use the latter convention for libnpp. # We install conda packages at the start because otherwise conda may have conflicts with dependencies. - default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" + default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }} conda-forge::xorg-libxau" - name: Check env run: | ${CONDA_RUN} env From f717df6cc6301963f79e52cb7431e0c40aeede53 Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Wed, 30 Oct 2024 08:36:06 -0700 Subject: [PATCH 8/8] . 
--- .github/workflows/linux_cuda_wheel.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux_cuda_wheel.yaml b/.github/workflows/linux_cuda_wheel.yaml index b4c5b69ad..842013ec6 100644 --- a/.github/workflows/linux_cuda_wheel.yaml +++ b/.github/workflows/linux_cuda_wheel.yaml @@ -89,7 +89,7 @@ jobs: # For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does. # So we use the latter convention for libnpp. # We install conda packages at the start because otherwise conda may have conflicts with dependencies. - default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }} conda-forge::xorg-libxau" + default-packages: "nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }} conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}" - name: Check env run: | ${CONDA_RUN} env
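
To tie the series together, here is a minimal end-to-end sketch of the GPU batch
path exercised by the new test (a sketch only: it assumes a CUDA-enabled build,
the nasa_13013.mp4 asset at the path used by the test suite, and the core-op
import path used by the tests):

    from torchcodec.decoders._core import (
        add_video_stream,
        create_from_file,
        get_frames_at_indices,
        scan_all_streams_to_update_metadata,
    )

    # Decode two frames of the video stream (index 3) on the GPU.
    gpu_decoder = create_from_file("test/resources/nasa_13013.mp4")
    scan_all_streams_to_update_metadata(gpu_decoder)
    add_video_stream(gpu_decoder, device="cuda")
    gpu_frames, *_ = get_frames_at_indices(
        gpu_decoder, stream_index=3, frame_indices=[0, 180]
    )
    assert gpu_frames.device.type == "cuda"

    # Decode the same frames on the CPU for reference.
    cpu_decoder = create_from_file("test/resources/nasa_13013.mp4")
    scan_all_streams_to_update_metadata(cpu_decoder)
    add_video_stream(cpu_decoder)
    cpu_frames, *_ = get_frames_at_indices(
        cpu_decoder, stream_index=3, frame_indices=[0, 180]
    )

    # GPU decode is not bit-accurate, so compare with the percentage-based
    # tolerance used by the tests: at most 0.3% of values may differ by more
    # than 20.
    diff = (gpu_frames[0].to("cpu").float() - cpu_frames[0].float()).abs()
    assert (diff > 20).float().mean() <= 0.003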