diff --git a/benchmarks/decoders/BenchmarkDecodersMain.cpp b/benchmarks/decoders/BenchmarkDecodersMain.cpp
index c1b15bafb..5be64e6d6 100644
--- a/benchmarks/decoders/BenchmarkDecodersMain.cpp
+++ b/benchmarks/decoders/BenchmarkDecodersMain.cpp
@@ -63,7 +63,7 @@ void runNDecodeIterations(
     decoder->addVideoStreamDecoder(-1);
     for (double pts : ptsList) {
       decoder->setCursorPtsInSeconds(pts);
-      torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+      torch::Tensor tensor = decoder->getNextDecodedOutput().frame;
     }
     if (i + 1 == warmupIterations) {
       start = std::chrono::high_resolution_clock::now();
@@ -95,7 +95,7 @@ void runNdecodeIterationsGrabbingConsecutiveFrames(
         VideoDecoder::createFromFilePath(videoPath);
     decoder->addVideoStreamDecoder(-1);
     for (int j = 0; j < consecutiveFrameCount; ++j) {
-      torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+      torch::Tensor tensor = decoder->getNextDecodedOutput().frame;
     }
     if (i + 1 == warmupIterations) {
       start = std::chrono::high_resolution_clock::now();
diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py
index 4fd7c8ad6..a19c1d431 100644
--- a/benchmarks/decoders/gpu_benchmark.py
+++ b/benchmarks/decoders/gpu_benchmark.py
@@ -5,59 +5,43 @@ import torch.utils.benchmark as benchmark
 
 import torchcodec
-import torchvision.transforms.v2.functional as F
+from torchvision.transforms import Resize
 
-RESIZED_WIDTH = 256
-RESIZED_HEIGHT = 256
-
-def transfer_and_resize_frame(frame, resize_device_string):
-    # This should be a no-op if the frame is already on the target device.
-    frame = frame.to(resize_device_string)
-    frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH))
+def transfer_and_resize_frame(frame, device):
+    # This should be a no-op if the frame is already on the device.
+    frame = frame.to(device)
+    frame = Resize((256, 256))(frame)
     return frame
 
-def decode_full_video(video_path, decode_device_string, resize_device_string):
-    # We use the core API instead of SimpleVideoDecoder because the core API
-    # allows us to natively resize as part of the decode step.
-    print(f"{decode_device_string=} {resize_device_string=}")
+def decode_full_video(video_path, decode_device):
     decoder = torchcodec.decoders._core.create_from_file(video_path)
     num_threads = None
-    if "cuda" in decode_device_string:
+    if "cuda" in decode_device:
         num_threads = 1
-    width = None
-    height = None
-    if "native" in resize_device_string:
-        width = RESIZED_WIDTH
-        height = RESIZED_HEIGHT
     torchcodec.decoders._core.add_video_stream(
-        decoder,
-        stream_index=-1,
-        device_string=decode_device_string,
-        num_threads=num_threads,
-        width=width,
-        height=height,
+        decoder, stream_index=0, device_string=decode_device, num_threads=num_threads
     )
-
     start_time = time.time()
     frame_count = 0
     while True:
         try:
             frame, *_ = torchcodec.decoders._core.get_next_frame(decoder)
-            if resize_device_string != "none" and "native" not in resize_device_string:
-                frame = transfer_and_resize_frame(frame, resize_device_string)
+            # You can do a resize to simulate extra preproc work that happens
+            # on the GPU by uncommenting the following line:
+            # frame = transfer_and_resize_frame(frame, decode_device)
             frame_count += 1
         except Exception as e:
             print("EXCEPTION", e)
             break
-
+    # print(f"current {frame_count=}", flush=True)
     end_time = time.time()
     elapsed = end_time - start_time
     fps = frame_count / (end_time - start_time)
     print(
-        f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}"
+        f"****** DECODED full video {decode_device=} {frame_count=} {elapsed=} {fps=}"
     )
     return frame_count, end_time - start_time
 
@@ -70,12 +54,6 @@ def main():
         type=str,
         help="Comma-separated devices to test decoding on.",
     )
-    parser.add_argument(
-        "--resize_devices",
-        default="cuda:0,cpu,native,none",
-        type=str,
-        help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.",
-    )
     parser.add_argument(
         "--video",
         type=str,
@@ -100,44 +78,24 @@ def main():
             decode_full_video(video_path, device)
         return
 
-    resize_devices = args.resize_devices.split(",")
-    resize_devices = [d for d in resize_devices if d != ""]
-    if len(resize_devices) == 0:
-        resize_devices.append("none")
-
-    label = "Decode+Resize Time"
-    results = []
-    for decode_device_string in args.devices.split(","):
-        for resize_device_string in resize_devices:
-            decode_label = decode_device_string
-            if "cuda" in decode_label:
-                # Shorten "cuda:0" to "cuda"
-                decode_label = "cuda"
-            resize_label = resize_device_string
-            if "cuda" in resize_device_string:
-                # Shorten "cuda:0" to "cuda"
-                resize_label = "cuda"
-            print("decode_device", decode_device_string)
-            print("resize_device", resize_device_string)
-            t = benchmark.Timer(
-                stmt="decode_full_video(video_path, decode_device_string, resize_device_string)",
-                globals={
-                    "decode_device_string": decode_device_string,
-                    "video_path": video_path,
-                    "decode_full_video": decode_full_video,
-                    "resize_device_string": resize_device_string,
-                },
-                label=label,
-                description=f"video={os.path.basename(video_path)}",
-                sub_label=f"D={decode_label} R={resize_label}",
-            ).blocked_autorange()
-            results.append(t)
+    results = []
+    for device in args.devices.split(","):
+        print("device", device)
+        t = benchmark.Timer(
+            stmt="decode_full_video(video_path, device)",
+            globals={
+                "device": device,
+                "video_path": video_path,
+                "decode_full_video": decode_full_video,
+            },
+            label="Decode+Resize Time",
+            sub_label=f"video={os.path.basename(video_path)}",
+            description=f"decode_device={device}",
+        ).blocked_autorange()
+        results.append(t)
 
     compare = benchmark.Compare(results)
     compare.print()
-    print("Key: D=Decode, R=Resize")
-    print("Native resize is done as part of the decode step")
-    print("none resize means there is no resize step -- native or otherwise")
 
 
 if __name__ == "__main__":
diff --git a/examples/basic_example.py b/examples/basic_example.py
index 693c8c47d..abbc1b469 100644
--- a/examples/basic_example.py
+++ b/examples/basic_example.py
@@ -171,14 +171,3 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
 # %%
 plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds")
 plot(first_two_seconds.data, "Frames displayed during [0, 2) seconds")
-
-# %%
-# Using a CUDA GPU to accelerate decoding
-# ---------------------------------------
-#
-# If you have a CUDA GPU that has NVDEC, you can decode on the GPU.
-if torch.cuda.is_available():
-    cuda_decoder = SimpleVideoDecoder(raw_video_bytes, device="cuda:0")
-    cuda_frame = cuda_decoder.get_frame_displayed_at(seconds=2)
-    print(cuda_frame.data.device)  # should be cuda:0
-    plot(cuda_frame.data.to("cpu"), "Frame displayed at 2 seconds on CUDA")
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 4c9d00122..53a7abdc3 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -759,7 +759,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter(
   if (activeStreamIndices_.size() == 0) {
     throw std::runtime_error("No active streams configured.");
   }
-  VLOG(9) << "Starting getNextDecodedOutputNoDemux()";
+  VLOG(9) << "Starting getNextDecodedOutput()";
   resetDecodeStats();
   if (maybeDesiredPts_.has_value()) {
     VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_;
@@ -920,7 +920,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   return output;
 }
 
-VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux(
+VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
     double seconds) {
   for (auto& [streamIndex, stream] : streams_) {
     double frameStartTime = ptsToSeconds(stream.currentPts, stream.timeBase);
@@ -985,7 +985,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
 
   int64_t pts = stream.allFrames[frameIndex].pts;
   setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-  return getNextDecodedOutputNoDemux();
+  return getNextDecodedOutput();
 }
 
 VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
@@ -1138,7 +1138,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
   return output;
 }
 
-VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
+VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutput() {
   return getDecodedOutputWithFilter(
       [this](int frameStreamIndex, AVFrame* frame) {
         StreamInfo& activeStream = streams_[frameStreamIndex];
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index a84d7da56..17a4f8a7d 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -153,8 +153,8 @@ class VideoDecoder {
 
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
-  // Calling getNextDecodedOutputNoDemux() will return the first frame at or
-  // after this position.
+  // Calling getNextDecodedOutput() will return the first frame at or after
+  // this position.
   void setCursorPtsInSeconds(double seconds);
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
@@ -180,14 +180,13 @@ class VideoDecoder {
   };
   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
-  DecodedOutput getNextDecodedOutputNoDemux();
-  // Decodes the first frame in any added stream that is visible at a given
-  // timestamp. Frames in the video have a presentation timestamp and a
-  // duration. For example, if a frame has presentation timestamp of 5.0s and a
-  // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
-  // i.e. it will be returned when this function is called with seconds=5.0 or
-  // seconds=5.999, etc.
-  DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
+  DecodedOutput getNextDecodedOutput();
+  // Decodes the frame that is visible at a given timestamp. Frames in the video
+  // have a presentation timestamp and a duration. For example, if a frame has
+  // presentation timestamp of 5.0s and a duration of 1.0s, it will be visible
+  // in the timestamp range [5.0, 6.0). i.e. it will be returned when this
+  // function is called with seconds=5.0 or seconds=5.999, etc.
+  DecodedOutput getFrameDisplayedAtTimestamp(double seconds);
   DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
   struct BatchDecodedOutput {
     torch::Tensor frames;
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 073ef658c..9ea81839d 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -147,7 +147,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   VideoDecoder::DecodedOutput result;
   try {
-    result = videoDecoder->getNextDecodedOutputNoDemux();
+    result = videoDecoder->getNextDecodedOutput();
   } catch (const VideoDecoder::EndOfFileException& e) {
     throw pybind11::stop_iteration(e.what());
   }
@@ -161,7 +161,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
 
 OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  auto result = videoDecoder->getFrameDisplayedAtTimestampNoDemux(seconds);
+  auto result = videoDecoder->getFrameDisplayedAtTimestamp(seconds);
   return makeOpsDecodedOutput(result);
 }
 
diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py
index 3e773c565..c201962b7 100644
--- a/src/torchcodec/decoders/_simple_video_decoder.py
+++ b/src/torchcodec/decoders/_simple_video_decoder.py
@@ -9,7 +9,7 @@ from pathlib import Path
 from typing import Iterable, Iterator, Literal, Tuple, Union
 
-from torch import device as torch_device, Tensor
+from torch import Tensor
 
 from torchcodec.decoders import _core as core
 
@@ -89,14 +89,6 @@ class SimpleVideoDecoder:
             This can be either "NCHW" (default) or "NHWC", where N is the batch
            size, C is the number of channels, H is the height, and W is the
            width of the frames.
-        device (torch.device, optional): The device to use for decoding.
-            Currently we only support CPU and CUDA devices. If CUDA is used,
-            we use NVDEC and CUDA to do decoding and color-conversion
-            respectively. The resulting frame is left on the GPU for further
-            processing.
-            You can either pass in a string like "cpu" or "cuda:0" or a
-            torch.device like torch.device("cuda:0").
-            Default: ``torch.device("cpu")``.
 
     .. note::
 
@@ -114,7 +106,6 @@ def __init__(
         self,
         source: Union[str, Path, bytes, Tensor],
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
-        device: Union[str, torch_device] = torch_device("cpu"),
     ):
         if isinstance(source, str):
             self._decoder = core.create_from_file(source)
@@ -138,20 +129,7 @@ def __init__(
             )
 
         core.scan_all_streams_to_update_metadata(self._decoder)
-        num_threads = None
-        if isinstance(device, str):
-            device = torch_device(device)
-        if device.type == "cuda":
-            # Using multiple CPU threads seems to slow down decoding on CUDA.
-            # CUDA internally uses dedicated hardware to do decoding so we
-            # don't need CPU software threads here.
-            num_threads = 1
-        core.add_video_stream(
-            self._decoder,
-            dimension_order=dimension_order,
-            device_string=str(device),
-            num_threads=num_threads,
-        )
+        core.add_video_stream(self._decoder, dimension_order=dimension_order)
 
         self.metadata, self._stream_index = _get_and_validate_stream_metadata(
             self._decoder
diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp
index 04cbed0a6..1fe19316a 100644
--- a/test/decoders/VideoDecoderTest.cpp
+++ b/test/decoders/VideoDecoderTest.cpp
@@ -152,7 +152,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) {
   streamOptions.width = 100;
   streamOptions.height = 120;
   decoder->addVideoStreamDecoder(-1, streamOptions);
-  torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+  torch::Tensor tensor = decoder->getNextDecodedOutput().frame;
   EXPECT_EQ(tensor.sizes(), std::vector<long>({3, 120, 100}));
 }
 
@@ -163,7 +163,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) {
   VideoDecoder::VideoStreamDecoderOptions streamOptions;
   streamOptions.dimensionOrder = "NHWC";
   decoder->addVideoStreamDecoder(-1, streamOptions);
-  torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+  torch::Tensor tensor = decoder->getNextDecodedOutput().frame;
   EXPECT_EQ(tensor.sizes(), std::vector<long>({270, 480, 3}));
 }
 
@@ -172,12 +172,12 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
   std::unique_ptr<VideoDecoder> ourDecoder =
      createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor0FromOurDecoder = output.frame;
   EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 0.0);
   EXPECT_EQ(output.pts, 0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor1FromOurDecoder = output.frame;
   EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
@@ -219,12 +219,12 @@ TEST(GPUVideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
   ASSERT_TRUE(streamOptions.device.is_cuda());
   ASSERT_EQ(streamOptions.device.type(), torch::DeviceType::CUDA);
   ourDecoder->addVideoStreamDecoder(-1, streamOptions);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor1FromOurDecoder = output.frame;
   EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 0.0);
   EXPECT_EQ(output.pts, 0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor2FromOurDecoder = output.frame;
   EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
@@ -306,11 +306,11 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) {
      createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
   ourDecoder->setCursorPtsInSeconds(388388. / 30'000);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextDecodedOutput();
   EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
-  EXPECT_THROW(ourDecoder->getNextDecodedOutputNoDemux(), std::exception);
+  EXPECT_THROW(ourDecoder->getNextDecodedOutput(), std::exception);
 }
 
 TEST_P(VideoDecoderTest, GetsFrameDisplayedAtTimestamp) {
@@ -318,19 +318,18 @@ TEST_P(VideoDecoderTest, GetsFrameDisplayedAtTimestamp) {
   std::unique_ptr<VideoDecoder> ourDecoder =
      createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
-  auto output = ourDecoder->getFrameDisplayedAtTimestampNoDemux(6.006);
+  auto output = ourDecoder->getFrameDisplayedAtTimestamp(6.006);
   EXPECT_EQ(output.ptsSeconds, 6.006);
   // The frame's duration is 0.033367 according to ffprobe,
   // so the next frame is displayed at timestamp=6.039367.
   const double kNextFramePts = 6.039366666666667;
   // The frame that is displayed a microsecond before the next frame is still
   // the previous frame.
-  output =
-      ourDecoder->getFrameDisplayedAtTimestampNoDemux(kNextFramePts - 1e-6);
+  output = ourDecoder->getFrameDisplayedAtTimestamp(kNextFramePts - 1e-6);
   EXPECT_EQ(output.ptsSeconds, 6.006);
   // The frame that is displayed at the exact pts of the frame is the next
   // frame.
-  output = ourDecoder->getFrameDisplayedAtTimestampNoDemux(kNextFramePts);
+  output = ourDecoder->getFrameDisplayedAtTimestamp(kNextFramePts);
   EXPECT_EQ(output.ptsSeconds, kNextFramePts);
 
   // This is the timestamp of the last frame in this video.
@@ -340,7 +339,7 @@ TEST_P(VideoDecoderTest, GetsFrameDisplayedAtTimestamp) {
      kPtsOfLastFrameInVideoStream + kDurationOfLastFrameInVideoStream;
   // Sanity check: make sure duration is strictly positive.
   EXPECT_GT(kPtsPlusDurationOfLastFrame, kPtsOfLastFrameInVideoStream);
-  output = ourDecoder->getFrameDisplayedAtTimestampNoDemux(
+  output = ourDecoder->getFrameDisplayedAtTimestamp(
      kPtsPlusDurationOfLastFrame - 1e-6);
   EXPECT_EQ(output.ptsSeconds, kPtsOfLastFrameInVideoStream);
 }
@@ -351,7 +350,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
      createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
   ourDecoder->setCursorPtsInSeconds(6.0);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor6FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
   torch::Tensor tensor6FromFFMPEG =
@@ -367,7 +366,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 180);
 
   ourDecoder->setCursorPtsInSeconds(6.1);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor61FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 183'183. / 30'000);
   torch::Tensor tensor61FromFFMPEG =
@@ -387,7 +386,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   EXPECT_LT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 10);
 
   ourDecoder->setCursorPtsInSeconds(10.0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor10FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000);
   torch::Tensor tensor10FromFFMPEG =
@@ -404,7 +403,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 60);
 
   ourDecoder->setCursorPtsInSeconds(6.0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   tensor6FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
   EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG));
@@ -419,7 +418,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9
 
   ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextDecodedOutput();
   torch::Tensor tensor7FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
   torch::Tensor tensor7FromFFMPEG =
diff --git a/test/decoders/test_simple_video_decoder.py b/test/decoders/test_simple_video_decoder.py
index e7d4d8915..424a16873 100644
--- a/test/decoders/test_simple_video_decoder.py
+++ b/test/decoders/test_simple_video_decoder.py
@@ -45,34 +45,6 @@ def test_create_fails(self):
         with pytest.raises(TypeError, match="Unknown source type"):
             decoder = SimpleVideoDecoder(123)  # noqa
 
-    def test_can_accept_devices(self):
-        # You can pass a CPU device as a string...
-        decoder = SimpleVideoDecoder(NASA_VIDEO.path, device="cpu")
-        assert_tensor_equal(decoder[0], NASA_VIDEO.get_frame_data_by_index(0))
-
-        # ...or as a torch.device.
-        decoder = SimpleVideoDecoder(NASA_VIDEO.path, device=torch.device("cpu"))
-        assert_tensor_equal(decoder[0], NASA_VIDEO.get_frame_data_by_index(0))
-
-        if torch.cuda.is_available():
-            # You can pass a CUDA device as a string...
-            decoder = SimpleVideoDecoder(NASA_VIDEO.path, device="cuda")
-            frame = decoder[0]
-            assert frame.device.type == "cuda"
-            assert frame.shape == torch.Size(
-                [NASA_VIDEO.num_color_channels, NASA_VIDEO.height, NASA_VIDEO.width]
-            )
-
-            # ...or as a torch.device.
-            decoder = SimpleVideoDecoder(NASA_VIDEO.path, device=torch.device("cuda"))
-            frame = decoder[0]
-            assert frame.device.type == "cuda"
-            assert frame.shape == torch.Size(
-                [NASA_VIDEO.num_color_channels, NASA_VIDEO.height, NASA_VIDEO.width]
-            )
-            # TODO: compare tensor values too. We don't compare values because
-            # the exact values are hardware-dependent.
-
     def test_getitem_int(self):
         decoder = SimpleVideoDecoder(NASA_VIDEO.path)
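
A minimal usage sketch of the renamed single-frame C++ API, tying together the seek and visibility semantics documented in VideoDecoder.h above. This is illustrative only and not part of the patch: "video.mp4" is a placeholder file name, and error handling is omitted.

    // Illustrative sketch only -- not part of this patch. "video.mp4" is a
    // placeholder path. Uses only calls that appear in this diff.
    std::unique_ptr<VideoDecoder> decoder =
        VideoDecoder::createFromFilePath("video.mp4");
    decoder->addVideoStreamDecoder(-1);

    // Seeking places the cursor at the first frame at or after 6.0 seconds;
    // getNextDecodedOutput() decodes that frame and advances the cursor.
    decoder->setCursorPtsInSeconds(6.0);
    VideoDecoder::DecodedOutput out = decoder->getNextDecodedOutput();
    torch::Tensor frame = out.frame; // the decoded frame as a Tensor

    // A frame with pts=5.0s and duration=1.0s is visible on [5.0, 6.0), so
    // both of these calls return that same frame.
    auto frameA = decoder->getFrameDisplayedAtTimestamp(5.0);
    auto frameB = decoder->getFrameDisplayedAtTimestamp(5.999);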