Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/decoders/benchmark_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
import os
import timeit
from pathlib import Path

import torch
import torch.utils.benchmark as benchmark
Expand Down Expand Up @@ -303,9 +304,8 @@ def get_test_resource_path(filename: str) -> str:
resource = importlib.resources.files(__package__).joinpath(filename)
with importlib.resources.as_file(resource) as path:
return os.fspath(path)
return os.path.join(
os.path.dirname(__file__), "..", "..", "test", "resources", filename
)

return str(Path(__file__).parent / f"../../test/resources/{filename}")


def create_torchcodec_decoder_from_file(video_file):
Expand Down
11 changes: 4 additions & 7 deletions benchmarks/decoders/gpu_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import torch

Expand Down Expand Up @@ -102,9 +101,7 @@ def main():
parser.add_argument(
"--video",
type=str,
default=str(
pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
),
default=str(Path(__file__).parent / "../../test/resources/nasa_13013.mp4"),
)
parser.add_argument(
"--use_torch_benchmark",
Expand Down Expand Up @@ -177,7 +174,7 @@ def main():
"use_multiple_gpus": args.use_multiple_gpus,
},
label=label,
description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}",
description=f"threads={args.num_threads} work={args.num_videos} video={Path(video_path).name}",
sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}",
).blocked_autorange()
results.append(t)
Expand All @@ -191,7 +188,7 @@ def main():
"resize_device_string": resize_device_string,
},
label=label,
description=f"video={os.path.basename(video_path)}",
description=f"video={Path(video_path).name}",
sub_label=f"D={decode_label} R={resize_label}",
).blocked_autorange()
results.append(t)
Expand Down
7 changes: 4 additions & 3 deletions src/torchcodec/_internally_replaced_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
# LICENSE file in the root directory of this source tree.

import importlib
import os
import sys
from pathlib import Path


# Copy pasted from torchvision
# https://github.com/pytorch/vision/blob/947ae1dc71867f28021d5bc0ff3a19c249236e2a/torchvision/_internally_replaced_utils.py#L25
def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
extension_suffixes = []
if sys.platform == "linux":
extension_suffixes = importlib.machinery.EXTENSION_SUFFIXES
Expand All @@ -27,7 +26,9 @@ def _get_extension_path(lib_name):
extension_suffixes,
)

extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
extfinder = importlib.machinery.FileFinder(
str(Path(__file__).parent), loader_details
)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError
Expand Down
4 changes: 2 additions & 2 deletions test/decoders/manual_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path

import torchcodec
from torchvision.io.image import write_png

decoder = torchcodec.decoders._core.create_from_file(
os.path.dirname(__file__) + "/../resources/nasa_13013.mp4"
str(Path(__file__).parent / "../resources/nasa_13013.mp4")
)
torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder)
torchcodec.decoders._core.add_video_stream(decoder, stream_index=3)
Expand Down
Loading