From 745e22e501223dfddd9edd86928390de4fe847bb Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Wed, 3 Sep 2025 15:21:31 +0000 Subject: [PATCH] Improve UX with private repos * Fetch and check credentials stored on the server side if no credentials provided via the command line (otherwise, check the provided credentials as usual) * Detect the default branch using the provided or stored credentials Closes: https://github.com/dstackai/dstack/issues/3061 --- src/dstack/_internal/cli/commands/init.py | 4 +- .../cli/services/configurators/run.py | 72 ++++++-- src/dstack/_internal/cli/services/repos.py | 19 +- src/dstack/_internal/core/services/repos.py | 165 +++++++++--------- src/dstack/_internal/utils/ssh.py | 7 + src/dstack/api/_public/repos.py | 47 ++++- 6 files changed, 193 insertions(+), 121 deletions(-) diff --git a/src/dstack/_internal/cli/commands/init.py b/src/dstack/_internal/cli/commands/init.py index 1aab72d40..7df156bb7 100644 --- a/src/dstack/_internal/cli/commands/init.py +++ b/src/dstack/_internal/cli/commands/init.py @@ -6,12 +6,12 @@ from dstack._internal.cli.commands import BaseCommand from dstack._internal.cli.services.repos import ( get_repo_from_dir, - get_repo_from_url, is_git_repo_url, register_init_repo_args, ) from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.models.repos.remote import RemoteRepo from dstack._internal.core.services.configs import ConfigManager from dstack.api import Client @@ -101,7 +101,7 @@ def _command(self, args: argparse.Namespace): if repo_url is not None: # Dummy repo branch to avoid autodetection that fails on private repos. # We don't need branch/hash for repo_id anyway. 
- repo = get_repo_from_url(repo_url, repo_branch="master") + repo = RemoteRepo.from_url(repo_url, repo_branch="master") elif repo_path is not None: repo = get_repo_from_dir(repo_path, local=local) else: diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 3c7c03be9..c51042b57 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -17,7 +17,6 @@ from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args from dstack._internal.cli.services.repos import ( get_repo_from_dir, - get_repo_from_url, init_default_virtual_repo, is_git_repo_url, register_init_repo_args, @@ -43,13 +42,19 @@ ServiceConfiguration, TaskConfiguration, ) +from dstack._internal.core.models.repos import RepoHeadWithCreds from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.repos.local import LocalRepo +from dstack._internal.core.models.repos.remote import RemoteRepo, RemoteRepoCreds from dstack._internal.core.models.resources import CPUSpec from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.diff import diff_models -from dstack._internal.core.services.repos import load_repo +from dstack._internal.core.services.repos import ( + InvalidRepoCredentialsError, + get_repo_creds_and_default_branch, + load_repo, +) from dstack._internal.utils.common import local_time from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger @@ -524,15 +529,17 @@ def get_repo( return init_default_virtual_repo(api=self.api) repo: Optional[Repo] = None + repo_head: Optional[RepoHeadWithCreds] = None repo_branch: Optional[str] = configurator_args.repo_branch repo_hash: Optional[str] = 
configurator_args.repo_hash + repo_creds: Optional[RemoteRepoCreds] = None + git_identity_file: Optional[str] = configurator_args.git_identity_file + git_private_key: Optional[str] = None + oauth_token: Optional[str] = configurator_args.gh_token # Should we (re)initialize the repo? # If any Git credentials provided, we reinitialize the repo, as the user may have provided # updated credentials. - init = ( - configurator_args.git_identity_file is not None - or configurator_args.gh_token is not None - ) + init = git_identity_file is not None or oauth_token is not None url: Optional[str] = None local_path: Optional[Path] = None @@ -565,15 +572,15 @@ def get_repo( local_path = Path.cwd() legacy_local_path = True if url: - repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash) - if not self.api.repos.is_initialized(repo, by_user=True): - init = True + # "master" is a dummy value, we'll fetch the actual default branch later + repo = RemoteRepo.from_url(repo_url=url, repo_branch="master") + repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True) elif local_path: if legacy_local_path: if repo_config := config_manager.get_repo_config(local_path): repo = load_repo(repo_config) - # allow users with legacy configurations use shared repo creds - if self.api.repos.is_initialized(repo, by_user=False): + repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True) + if repo_head is not None: warn( "The repo is not specified but found and will be used in the run\n" "Future versions will not load repos automatically\n" @@ -600,20 +607,55 @@ def get_repo( ) local: bool = configurator_args.local repo = get_repo_from_dir(local_path, local=local) - if not self.api.repos.is_initialized(repo, by_user=True): - init = True + repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True) + if isinstance(repo, RemoteRepo): + repo_branch = repo.run_repo_data.repo_branch + repo_hash = repo.run_repo_data.repo_hash else: assert False, 
"should not reach here" if repo is None: return init_default_virtual_repo(api=self.api) + if isinstance(repo, RemoteRepo): + assert repo.repo_url is not None + + if repo_head is not None and repo_head.repo_creds is not None: + if git_identity_file is None and oauth_token is None: + git_private_key = repo_head.repo_creds.private_key + oauth_token = repo_head.repo_creds.oauth_token + else: + init = True + + try: + repo_creds, default_repo_branch = get_repo_creds_and_default_branch( + repo_url=repo.repo_url, + identity_file=git_identity_file, + private_key=git_private_key, + oauth_token=oauth_token, + ) + except InvalidRepoCredentialsError as e: + raise CLIError(*e.args) from e + + if repo_branch is None and repo_hash is None: + repo_branch = default_repo_branch + if repo_branch is None: + raise CLIError( + "Failed to automatically detect remote repo branch." + " Specify branch or hash." + ) + repo = RemoteRepo.from_url( + repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash + ) + if init: self.api.repos.init( repo=repo, - git_identity_file=configurator_args.git_identity_file, - oauth_token=configurator_args.gh_token, + git_identity_file=git_identity_file, + oauth_token=oauth_token, + creds=repo_creds, ) + if isinstance(repo, LocalRepo): warn( f"{repo.repo_dir} is a local repo\n" diff --git a/src/dstack/_internal/cli/services/repos.py b/src/dstack/_internal/cli/services/repos.py index 0abc9d9ef..9f74aca8a 100644 --- a/src/dstack/_internal/cli/services/repos.py +++ b/src/dstack/_internal/cli/services/repos.py @@ -1,5 +1,5 @@ import argparse -from typing import Literal, Optional, Union, overload +from typing import Literal, Union, overload import git @@ -8,7 +8,6 @@ from dstack._internal.core.models.repos.local import LocalRepo from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError from dstack._internal.core.models.repos.virtual import VirtualRepo -from dstack._internal.core.services.repos import get_default_branch from 
def get_repo_creds_and_default_branch(
    repo_url: str,
    identity_file: Optional[PathLike] = None,
    private_key: Optional[str] = None,
    oauth_token: Optional[str] = None,
) -> tuple[RemoteRepoCreds, Optional[str]]:
    """
    Finds working credentials for a remote Git repo and detects its default branch.

    Credentials are tried in this order, and the first that works wins:

    1. anonymous HTTPS access;
    2. the explicitly passed SSH key (`identity_file`, otherwise `private_key`);
    3. the explicitly passed `oauth_token`;
    4. the first `identityfile` from the SSH config for the repo host;
    5. the token stored in the `gh` CLI config for the repo host;
    6. the default user SSH key (`~/.ssh/id_rsa`).

    Explicitly passed credentials (steps 2 and 3) are checked strictly: if they
    do not work, the error is raised and no further fallbacks are tried.
    Fallbacks 4-6 are best-effort and are silently skipped on failure.

    Args:
        repo_url: The remote repo URL (HTTPS or SSH form).
        identity_file: Path to a private SSH key. Takes precedence over `private_key`.
        private_key: Private SSH key contents (used when no `identity_file` is given).
        oauth_token: OAuth token for HTTPS access.

    Returns:
        A `(creds, default_branch)` tuple; `default_branch` is `None`
        if it could not be detected.

    Raises:
        InvalidRepoCredentialsError: If explicitly passed credentials do not work
            or no working credentials were found at all.
    """
    url = GitRepoURL.parse(repo_url, get_ssh_config=get_host_config)

    # no auth
    with suppress(InvalidRepoCredentialsError):
        return _get_repo_creds_and_default_branch_https(url)

    # ssh key provided by the user or pulled from the server
    if identity_file is not None:
        # Expand `~` so that the path validated/read here is the same one passed to ssh.
        identity_file = os.path.expanduser(identity_file)
        private_key = _read_private_key(identity_file)
        return _get_repo_creds_and_default_branch_ssh(url, identity_file, private_key)
    if private_key is not None:
        # ssh needs the key as a file. The second positional argument of
        # NamedTemporaryFile is `buffering`, not a permission mode, so no mode is
        # passed here; the file is created via mkstemp and is private to the user.
        with NamedTemporaryFile("w+") as key_file:
            key_file.write(private_key)
            key_file.flush()
            # The file must still exist while git/ssh runs, so call inside the `with`.
            return _get_repo_creds_and_default_branch_ssh(url, key_file.name, private_key)

    # oauth token provided by the user or pulled from the server
    if oauth_token is not None:
        return _get_repo_creds_and_default_branch_https(url, oauth_token)

    # key from ssh config
    identities = get_host_config(url.original_host).get("identityfile")
    if identities:
        config_identity_file = identities[0]
        with suppress(InvalidRepoCredentialsError):
            config_private_key = _read_private_key(config_identity_file)
            return _get_repo_creds_and_default_branch_ssh(
                url, config_identity_file, config_private_key
            )

    # token from gh config
    if os.path.exists(gh_config_path):
        with open(gh_config_path, "r") as f:
            gh_hosts = yaml.load(f, Loader=yaml.FullLoader)
        # An empty file is parsed as `None` and a host entry may be empty as well,
        # so only descend into real mappings.
        host_entry = gh_hosts.get(url.host) if isinstance(gh_hosts, dict) else None
        gh_token = host_entry.get("oauth_token") if isinstance(host_entry, dict) else None
        if gh_token is not None:
            with suppress(InvalidRepoCredentialsError):
                return _get_repo_creds_and_default_branch_https(url, gh_token)

    # default user key
    if os.path.exists(default_ssh_key):
        with suppress(InvalidRepoCredentialsError):
            default_private_key = _read_private_key(default_ssh_key)
            return _get_repo_creds_and_default_branch_ssh(
                url, default_ssh_key, default_private_key
            )

    raise InvalidRepoCredentialsError(
        "No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
    )
tuple[RemoteRepoCreds, Optional[str]]: + _url = url.as_https() + try: + default_branch = _get_repo_default_branch(url.as_https(oauth_token), make_git_env()) + except GitCommandError as e: + message = f"Cannot access `{_url}`" + if oauth_token is not None: + masked_token = len(oauth_token[:-4]) * "*" + oauth_token[-4:] + message = f"{message} using the `{masked_token}` token" + raise InvalidRepoCredentialsError(message) from e + creds = RemoteRepoCreds( + clone_url=_url, private_key=None, + oauth_token=oauth_token, ) + return creds, default_branch -def check_remote_repo_credentials_ssh(url: GitRepoURL, identity_file: PathLike) -> RemoteRepoCreds: +def _get_repo_default_branch(url: str, env: dict[str, str]) -> Optional[str]: + # output example: "ref: refs/heads/dev\tHEAD\n545344f77c0df78367085952a97fc3a058eb4c65\tHEAD" + output: str = git.cmd.Git().ls_remote("--symref", url, "HEAD", env=env) + for line in output.splitlines(): + # line format: ` TAB LF` + oid, _, ref = line.partition("\t") + if oid.startswith("ref:") and ref == "HEAD": + return oid.rsplit("/", maxsplit=1)[-1] + return None + + +def _read_private_key(identity_file: PathLike) -> str: + identity_file = Path(identity_file).expanduser().resolve() if not Path(identity_file).exists(): - raise InvalidRepoCredentialsError(f"The {identity_file} private SSH key doesn't exist") + raise InvalidRepoCredentialsError(f"The `{identity_file}` private SSH key doesn't exist") if not os.access(identity_file, os.R_OK): - raise InvalidRepoCredentialsError(f"Can't access the {identity_file} private SSH key") + raise InvalidRepoCredentialsError(f"Cannot access the `{identity_file}` private SSH key") if not try_ssh_key_passphrase(identity_file): raise InvalidRepoCredentialsError( f"Cannot use the `{identity_file}` private SSH key. 
" "Ensure that it is valid and passphrase-free" ) - with open(identity_file, "r") as f: - private_key = f.read() - - try: - git.cmd.Git().ls_remote( - url.as_ssh(), env=dict(GIT_SSH_COMMAND=make_ssh_command_for_git(identity_file)) - ) - except GitCommandError: - raise InvalidRepoCredentialsError( - f"Can't access `{url.as_ssh()}` using the `{identity_file}` private SSH key" - ) - - return RemoteRepoCreds( - clone_url=url.as_ssh(), - private_key=private_key, - oauth_token=None, - ) - - -def get_default_branch(remote_url: str) -> Optional[str]: - """ - Get the default branch of a remote Git repository. - """ - try: - output = git.cmd.Git().ls_remote("--symref", remote_url, "HEAD", env=no_prompt_env) - for line in output.splitlines(): - if line.startswith("ref:"): - return line.split()[1].split("/")[-1] - except Exception as e: - logger.debug("Failed to get remote repo default branch: %s", repr(e)) - return None + with open(identity_file, "r") as file: + return file.read() +# Used for `config.yml` only, remove it with `repos` in `config.yml` def load_repo(config: RepoConfig) -> Union[RemoteRepo, LocalRepo]: if config.repo_type == "remote": return RemoteRepo(repo_id=config.repo_id, local_repo_dir=config.path) diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py index 79cfa5ffd..f0dafc7f7 100644 --- a/src/dstack/_internal/utils/ssh.py +++ b/src/dstack/_internal/utils/ssh.py @@ -50,6 +50,13 @@ def make_ssh_command_for_git(identity_file: PathLike) -> str: ) +def make_git_env(*, identity_file: Optional[PathLike] = None) -> dict[str, str]: + env: dict[str, str] = {"GIT_TERMINAL_PROMPT": "0"} + if identity_file is not None: + env["GIT_SSH_COMMAND"] = make_ssh_command_for_git(identity_file) + return env + + def try_ssh_key_passphrase(identity_file: PathLike, passphrase: str = "") -> bool: ssh_keygen = find_ssh_util("ssh-keygen") if ssh_keygen is None: diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index 
c05f80de7..d7519bbcf 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -1,15 +1,21 @@ from pathlib import Path -from typing import Optional, Union +from typing import Literal, Optional, Union, overload from git import InvalidGitRepositoryError from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError -from dstack._internal.core.models.repos import LocalRepo, RemoteRepo +from dstack._internal.core.models.repos import ( + LocalRepo, + RemoteRepo, + RemoteRepoCreds, + RepoHead, + RepoHeadWithCreds, +) from dstack._internal.core.models.repos.base import Repo, RepoType from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.repos import ( InvalidRepoCredentialsError, - get_local_repo_credentials, + get_repo_creds_and_default_branch, load_repo, ) from dstack._internal.utils.crypto import generate_rsa_key_pair @@ -34,6 +40,7 @@ def init( repo: Repo, git_identity_file: Optional[PathLike] = None, oauth_token: Optional[str] = None, + creds: Optional[RemoteRepoCreds] = None, ): """ Initializes the repo and configures its credentials in the project. @@ -65,12 +72,13 @@ def init( repo: The repo to initialize. git_identity_file: The private SSH key path for accessing the remote repo. oauth_token: The GitHub OAuth token to access the remote repo. + creds: Optional prepared repo credentials. If specified, both `git_identity_file` + and `oauth_token` are ignored. """ - creds = None - if isinstance(repo, RemoteRepo): + if creds is None and isinstance(repo, RemoteRepo): assert repo.repo_url is not None try: - creds = get_local_repo_credentials( + creds, _ = get_repo_creds_and_default_branch( repo_url=repo.repo_url, identity_file=git_identity_file, oauth_token=oauth_token, @@ -175,6 +183,33 @@ def _is_initialized_by_user(self, repo: RemoteRepo) -> bool: # TODO: add an API method with the same logic returning a bool value? 
@overload
def get(self, repo_id: str, *, with_creds: Literal[False] = False) -> Optional[RepoHead]: ...


@overload
def get(self, repo_id: str, *, with_creds: Literal[True]) -> Optional[RepoHeadWithCreds]: ...


def get(
    self, repo_id: str, *, with_creds: bool = False
) -> Optional[Union[RepoHead, RepoHeadWithCreds]]:
    """
    Looks up a repo of the current project by its ID.

    Args:
        repo_id: The repo ID.
        with_creds: If `True`, the repo is requested together with its credentials.

    Returns:
        The repo head (with credentials if `with_creds` is set)
        or `None` if the server reports that the repo does not exist.
    """
    repos_api = self._api_client.repos
    fetch = repos_api.get_with_creds if with_creds else repos_api.get
    try:
        return fetch(self._project, repo_id)
    except ResourceNotExistsError:
        return None