Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions tests/test_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ def test_get_pip_commands_valid():
assert result == expected


def test_get_pip_commands_with_uv():
cmds = [["package1"], ["package2", "--upgrade"]]
expected = [
["uv", "pip", "install", "package1"],
["uv", "pip", "install", "package2", "--upgrade"],
]

result = get_pip_commands(cmds, use_uv=True)
assert result == expected


def test_get_pip_commands_none_input():
cmds = [["package1"], None]
with pytest.raises(AssertionError):
Expand Down
17 changes: 13 additions & 4 deletions torchruntime/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def print_usage(entry_command: str):

Examples:
{entry_command} install
{entry_command} install --uv
{entry_command} install torch==2.2.0 torchvision==0.17.0
{entry_command} install torch>=2.0.0 torchaudio
{entry_command} install --uv torch>=2.0.0 torchaudio
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0

{entry_command} test # Runs all tests (import, devices, math, functions)
Expand All @@ -31,6 +32,9 @@ def print_usage(entry_command: str):
If no packages are specified, the latest available versions
of torch, torchaudio and torchvision will be installed.

Options:
--uv Use uv instead of pip for installation

Version specification formats (follows pip format):
package==2.1.0 Exact version
package>=2.0.0 Minimum version
Expand All @@ -56,16 +60,21 @@ def main():
command = sys.argv[1]

if command == "install":
package_versions = sys.argv[2:] if len(sys.argv) > 2 else None
install(package_versions)
args = sys.argv[2:] if len(sys.argv) > 2 else []
use_uv = "--uv" in args
# Remove --uv from args to get package list
package_versions = [arg for arg in args if arg != "--uv"] if args else None
install(package_versions, use_uv=use_uv)
elif command == "test":
subcommand = sys.argv[2] if len(sys.argv) > 2 else "all"
test(subcommand)
elif command == "info":
info()
else:
print(f"Unknown command: {command}")
print_usage()
entry_path = sys.argv[0]
cli = "python -m torchruntime" if "__main__.py" in entry_path else "torchruntime"
print_usage(cli)


if __name__ == "__main__":
Expand Down
13 changes: 9 additions & 4 deletions torchruntime/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ def get_install_commands(torch_platform, packages):
raise ValueError(f"Unsupported platform: {torch_platform}")


def get_pip_commands(cmds):
def get_pip_commands(cmds, use_uv=False):
assert not any(cmd is None for cmd in cmds)
return [PIP_PREFIX + cmd for cmd in cmds]
if use_uv:
pip_prefix = ["uv", "pip", "install"]
else:
pip_prefix = [sys.executable, "-m", "pip", "install"]
return [pip_prefix + cmd for cmd in cmds]


def run_commands(cmds):
Expand All @@ -87,13 +91,14 @@ def run_commands(cmds):
subprocess.run(cmd)


def install(packages=[]):
def install(packages=[], use_uv=False):
"""
packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"].
use_uv: bool, whether to use uv for installation. Defaults to False.
"""

gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
cmds = get_install_commands(torch_platform, packages)
cmds = get_pip_commands(cmds)
cmds = get_pip_commands(cmds, use_uv=use_uv)
run_commands(cmds)