diff --git a/tests/test_installer.py b/tests/test_installer.py index 70e7d21..ade0358 100644 --- a/tests/test_installer.py +++ b/tests/test_installer.py @@ -92,6 +92,17 @@ def test_get_pip_commands_valid(): assert result == expected +def test_get_pip_commands_with_uv(): + cmds = [["package1"], ["package2", "--upgrade"]] + expected = [ + ["uv", "pip", "install", "package1"], + ["uv", "pip", "install", "package2", "--upgrade"], + ] + + result = get_pip_commands(cmds, use_uv=True) + assert result == expected + + def test_get_pip_commands_none_input(): cmds = [["package1"], None] with pytest.raises(AssertionError): diff --git a/torchruntime/__main__.py b/torchruntime/__main__.py index 8079d11..97d1acd 100644 --- a/torchruntime/__main__.py +++ b/torchruntime/__main__.py @@ -15,8 +15,9 @@ def print_usage(entry_command: str): Examples: {entry_command} install + {entry_command} install --uv {entry_command} install torch==2.2.0 torchvision==0.17.0 - {entry_command} install torch>=2.0.0 torchaudio + {entry_command} install --uv torch>=2.0.0 torchaudio {entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0 {entry_command} test # Runs all tests (import, devices, math, functions) @@ -31,6 +32,9 @@ def print_usage(entry_command: str): If no packages are specified, the latest available versions of torch, torchaudio and torchvision will be installed. +Options: + --uv Use uv instead of pip for installation + Version specification formats (follows pip format): package==2.1.0 Exact version package>=2.0.0 Minimum version @@ -56,8 +60,11 @@ def main(): command = sys.argv[1] if command == "install": - package_versions = sys.argv[2:] if len(sys.argv) > 2 else None - install(package_versions) + args = sys.argv[2:] if len(sys.argv) > 2 else [] + use_uv = "--uv" in args + # Remove --uv from args to get package list + package_versions = [arg for arg in args if arg != "--uv"] if args else None + install(package_versions, use_uv=use_uv) elif command == "test": subcommand = sys.argv[2] if len(sys.argv) > 2 else "all" test(subcommand) @@ -65,7 +72,9 @@ def main(): info() else: print(f"Unknown command: {command}") - print_usage() + entry_path = sys.argv[0] + cli = "python -m torchruntime" if "__main__.py" in entry_path else "torchruntime" + print_usage(cli) if __name__ == "__main__": diff --git a/torchruntime/installer.py b/torchruntime/installer.py index 4c5f248..475244f 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -76,9 +76,13 @@ def get_install_commands(torch_platform, packages): raise ValueError(f"Unsupported platform: {torch_platform}") -def get_pip_commands(cmds): +def get_pip_commands(cmds, use_uv=False): assert not any(cmd is None for cmd in cmds) - return [PIP_PREFIX + cmd for cmd in cmds] + if use_uv: + pip_prefix = ["uv", "pip", "install"] + else: + pip_prefix = [sys.executable, "-m", "pip", "install"] + return [pip_prefix + cmd for cmd in cmds] def run_commands(cmds): @@ -87,13 +91,14 @@ def run_commands(cmds): subprocess.run(cmd) -def install(packages=[]): +def install(packages=[], use_uv=False): """ packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"]. + use_uv: bool, whether to use uv for installation. Defaults to False. """ gpu_infos = get_gpus() torch_platform = get_torch_platform(gpu_infos) cmds = get_install_commands(torch_platform, packages) - cmds = get_pip_commands(cmds) + cmds = get_pip_commands(cmds, use_uv=use_uv) run_commands(cmds)