diff --git a/cli/dstack/_internal/core/repo/remote.py b/cli/dstack/_internal/core/repo/remote.py index 291fdc11a..28ae07765 100644 --- a/cli/dstack/_internal/core/repo/remote.py +++ b/cli/dstack/_internal/core/repo/remote.py @@ -1,11 +1,14 @@ +import io import os +import subprocess import tempfile +import time from pathlib import Path from typing import Any, BinaryIO, Dict, Optional import git import giturlparse -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Literal from dstack._internal.core.repo import RepoProtocol @@ -38,7 +41,7 @@ class RemoteRepoData(RepoData, RemoteRepoInfo): repo_type: Literal["remote"] = "remote" repo_branch: Optional[str] = None repo_hash: Optional[str] = None - repo_diff: Optional[str] = None + repo_diff: Optional[str] = Field(None, exclude=True) repo_config_name: Optional[str] = None repo_config_email: Optional[str] = None @@ -113,13 +116,9 @@ def __init__( repo_data = RemoteRepoData.from_url(self.repo_url, parse_ssh_config=True) repo_data.repo_branch = tracking_branch.remote_head repo_data.repo_hash = tracking_branch.commit.hexsha - repo_data.repo_diff = repo.git.diff(repo_data.repo_hash) repo_data.repo_config_name = repo.config_reader().get_value("user", "name") repo_data.repo_config_email = repo.config_reader().get_value("user", "email") - diffs = [repo_data.repo_diff] - for filename in repo.untracked_files: - diffs.append(_add_patch(local_repo_dir, filename)) - repo_data.repo_diff = "\n".join([d for d in diffs if d]) + repo_data.repo_diff = _repo_diff_verbose(repo, repo_data.repo_hash) elif self.repo_url is not None: repo_data = RemoteRepoData.from_url(self.repo_url, parse_ssh_config=True) elif repo_data is None: @@ -165,7 +164,63 @@ def _clone_remote_repo( # todo checkout branch/hash -def _add_patch(repo_dir: PathLike, filename: str) -> str: - return git.cmd.Git(repo_dir).diff( - "/dev/null", filename, no_index=True, binary=True, with_exceptions=False - ) +class _DiffCollector: + def __init__(self, warning_time: float, delay: float = 5): + self.warning_time = warning_time + self.delay = delay + self.warned = False + self.start_time = time.monotonic() + self.buffer = io.StringIO() + + def timeout(self): + now = time.monotonic() + if not self.warned and now > self.start_time + self.warning_time: + print( + "Provisioning is taking longer than usual, possibly because of having too many or large local " + "files that haven’t been pushed to Git. Tip: Exclude unnecessary files from provisioning " + "by using the `.gitignore` file." + ) + self.warned = True + return ( + self.delay + if self.warned + else min(self.delay, self.start_time + self.warning_time - now) + ) + + def write(self, v: bytes): + self.buffer.write(v.decode()) + + def get(self) -> str: + if self.warned: + print() + return self.buffer.getvalue() + + +def _interactive_git_proc( + proc: git.Git.AutoInterrupt, collector: _DiffCollector, ignore_status: bool = False +): + while True: + try: + stdout, stderr = proc.communicate(timeout=collector.timeout()) + if not ignore_status and proc.poll() != 0: + raise git.GitCommandError(proc.args, proc.poll(), stderr) + collector.write(stdout) + return + except subprocess.TimeoutExpired: + continue + + +def _repo_diff_verbose(repo: git.Repo, repo_hash: str, warning_time: float = 5) -> str: + collector = _DiffCollector(warning_time) + try: + _interactive_git_proc(repo.git.diff(repo_hash, as_process=True), collector) + for filename in repo.untracked_files: + _interactive_git_proc( + repo.git.diff("/dev/null", filename, no_index=True, binary=True, as_process=True), + collector, + ignore_status=True, + ) + return collector.get() + except KeyboardInterrupt: + print("\nAborted") + exit(1)