Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dstack/_internal/cli/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from dstack._internal.cli.commands import BaseCommand
from dstack._internal.cli.services.repos import (
get_repo_from_dir,
get_repo_from_url,
is_git_repo_url,
register_init_repo_args,
)
from dstack._internal.cli.utils.common import configure_logging, confirm_ask, console, warn
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.repos.remote import RemoteRepo
from dstack._internal.core.services.configs import ConfigManager
from dstack.api import Client

Expand Down Expand Up @@ -101,7 +101,7 @@ def _command(self, args: argparse.Namespace):
if repo_url is not None:
# Dummy repo branch to avoid autodetection that fails on private repos.
# We don't need branch/hash for repo_id anyway.
repo = get_repo_from_url(repo_url, repo_branch="master")
repo = RemoteRepo.from_url(repo_url, repo_branch="master")
elif repo_path is not None:
repo = get_repo_from_dir(repo_path, local=local)
else:
Expand Down
72 changes: 57 additions & 15 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dstack._internal.cli.services.profile import apply_profile_args, register_profile_args
from dstack._internal.cli.services.repos import (
get_repo_from_dir,
get_repo_from_url,
init_default_virtual_repo,
is_git_repo_url,
register_init_repo_args,
Expand All @@ -43,13 +42,19 @@
ServiceConfiguration,
TaskConfiguration,
)
from dstack._internal.core.models.repos import RepoHeadWithCreds
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.repos.local import LocalRepo
from dstack._internal.core.models.repos.remote import RemoteRepo, RemoteRepoCreds
from dstack._internal.core.models.resources import CPUSpec
from dstack._internal.core.models.runs import JobStatus, JobSubmission, RunSpec, RunStatus
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.diff import diff_models
from dstack._internal.core.services.repos import load_repo
from dstack._internal.core.services.repos import (
InvalidRepoCredentialsError,
get_repo_creds_and_default_branch,
load_repo,
)
from dstack._internal.utils.common import local_time
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -524,15 +529,17 @@ def get_repo(
return init_default_virtual_repo(api=self.api)

repo: Optional[Repo] = None
repo_head: Optional[RepoHeadWithCreds] = None
repo_branch: Optional[str] = configurator_args.repo_branch
repo_hash: Optional[str] = configurator_args.repo_hash
repo_creds: Optional[RemoteRepoCreds] = None
git_identity_file: Optional[str] = configurator_args.git_identity_file
git_private_key: Optional[str] = None
oauth_token: Optional[str] = configurator_args.gh_token
# Should we (re)initialize the repo?
# If any Git credentials provided, we reinitialize the repo, as the user may have provided
# updated credentials.
init = (
configurator_args.git_identity_file is not None
or configurator_args.gh_token is not None
)
init = git_identity_file is not None or oauth_token is not None

url: Optional[str] = None
local_path: Optional[Path] = None
Expand Down Expand Up @@ -565,15 +572,15 @@ def get_repo(
local_path = Path.cwd()
legacy_local_path = True
if url:
repo = get_repo_from_url(repo_url=url, repo_branch=repo_branch, repo_hash=repo_hash)
if not self.api.repos.is_initialized(repo, by_user=True):
init = True
# "master" is a dummy value, we'll fetch the actual default branch later
repo = RemoteRepo.from_url(repo_url=url, repo_branch="master")
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
elif local_path:
if legacy_local_path:
if repo_config := config_manager.get_repo_config(local_path):
repo = load_repo(repo_config)
# allow users with legacy configurations use shared repo creds
if self.api.repos.is_initialized(repo, by_user=False):
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
if repo_head is not None:
warn(
"The repo is not specified but found and will be used in the run\n"
"Future versions will not load repos automatically\n"
Expand All @@ -600,20 +607,55 @@ def get_repo(
)
local: bool = configurator_args.local
repo = get_repo_from_dir(local_path, local=local)
if not self.api.repos.is_initialized(repo, by_user=True):
init = True
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
if isinstance(repo, RemoteRepo):
repo_branch = repo.run_repo_data.repo_branch
repo_hash = repo.run_repo_data.repo_hash
else:
assert False, "should not reach here"

if repo is None:
return init_default_virtual_repo(api=self.api)

if isinstance(repo, RemoteRepo):
assert repo.repo_url is not None

if repo_head is not None and repo_head.repo_creds is not None:
if git_identity_file is None and oauth_token is None:
git_private_key = repo_head.repo_creds.private_key
oauth_token = repo_head.repo_creds.oauth_token
else:
init = True

try:
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
repo_url=repo.repo_url,
identity_file=git_identity_file,
private_key=git_private_key,
oauth_token=oauth_token,
)
except InvalidRepoCredentialsError as e:
raise CLIError(*e.args) from e

if repo_branch is None and repo_hash is None:
repo_branch = default_repo_branch
if repo_branch is None:
raise CLIError(
"Failed to automatically detect remote repo branch."
" Specify branch or hash."
)
repo = RemoteRepo.from_url(
repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash
)

if init:
self.api.repos.init(
repo=repo,
git_identity_file=configurator_args.git_identity_file,
oauth_token=configurator_args.gh_token,
git_identity_file=git_identity_file,
oauth_token=oauth_token,
creds=repo_creds,
)

if isinstance(repo, LocalRepo):
warn(
f"{repo.repo_dir} is a local repo\n"
Expand Down
19 changes: 1 addition & 18 deletions src/dstack/_internal/cli/services/repos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse
from typing import Literal, Optional, Union, overload
from typing import Literal, Union, overload

import git

Expand All @@ -8,7 +8,6 @@
from dstack._internal.core.models.repos.local import LocalRepo
from dstack._internal.core.models.repos.remote import GitRepoURL, RemoteRepo, RepoError
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.core.services.repos import get_default_branch
from dstack._internal.utils.path import PathLike
from dstack.api._public import Client

Expand Down Expand Up @@ -43,22 +42,6 @@ def init_default_virtual_repo(api: Client) -> VirtualRepo:
return repo


def get_repo_from_url(
repo_url: str, repo_branch: Optional[str] = None, repo_hash: Optional[str] = None
) -> RemoteRepo:
if repo_branch is None and repo_hash is None:
repo_branch = get_default_branch(repo_url)
if repo_branch is None:
raise CLIError(
"Failed to automatically detect remote repo branch. Specify branch or hash."
)
return RemoteRepo.from_url(
repo_url=repo_url,
repo_branch=repo_branch,
repo_hash=repo_hash,
)


@overload
def get_repo_from_dir(repo_dir: PathLike, local: Literal[False] = False) -> RemoteRepo: ...

Expand Down
Loading
Loading