Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 66 additions & 11 deletions cli/dstack/_internal/core/repo/remote.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import io
import os
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Any, BinaryIO, Dict, Optional

import git
import giturlparse
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Literal

from dstack._internal.core.repo import RepoProtocol
Expand Down Expand Up @@ -38,7 +41,7 @@ class RemoteRepoData(RepoData, RemoteRepoInfo):
repo_type: Literal["remote"] = "remote"
repo_branch: Optional[str] = None
repo_hash: Optional[str] = None
repo_diff: Optional[str] = None
repo_diff: Optional[str] = Field(None, exclude=True)
repo_config_name: Optional[str] = None
repo_config_email: Optional[str] = None

Expand Down Expand Up @@ -113,13 +116,9 @@ def __init__(
repo_data = RemoteRepoData.from_url(self.repo_url, parse_ssh_config=True)
repo_data.repo_branch = tracking_branch.remote_head
repo_data.repo_hash = tracking_branch.commit.hexsha
repo_data.repo_diff = repo.git.diff(repo_data.repo_hash)
repo_data.repo_config_name = repo.config_reader().get_value("user", "name")
repo_data.repo_config_email = repo.config_reader().get_value("user", "email")
diffs = [repo_data.repo_diff]
for filename in repo.untracked_files:
diffs.append(_add_patch(local_repo_dir, filename))
repo_data.repo_diff = "\n".join([d for d in diffs if d])
repo_data.repo_diff = _repo_diff_verbose(repo, repo_data.repo_hash)
elif self.repo_url is not None:
repo_data = RemoteRepoData.from_url(self.repo_url, parse_ssh_config=True)
elif repo_data is None:
Expand Down Expand Up @@ -165,7 +164,63 @@ def _clone_remote_repo(
# todo checkout branch/hash


def _add_patch(repo_dir: PathLike, filename: str) -> str:
return git.cmd.Git(repo_dir).diff(
"/dev/null", filename, no_index=True, binary=True, with_exceptions=False
)
class _DiffCollector:
def __init__(self, warning_time: float, delay: float = 5):
self.warning_time = warning_time
self.delay = delay
self.warned = False
self.start_time = time.monotonic()
self.buffer = io.StringIO()

def timeout(self):
now = time.monotonic()
if not self.warned and now > self.start_time + self.warning_time:
print(
"Provisioning is taking longer than usual, possibly because of having too many or large local "
"files that haven’t been pushed to Git. Tip: Exclude unnecessary files from provisioning "
"by using the `.gitignore` file."
)
self.warned = True
return (
self.delay
if self.warned
else min(self.delay, self.start_time + self.warning_time - now)
)

def write(self, v: bytes):
self.buffer.write(v.decode())

def get(self) -> str:
if self.warned:
print()
return self.buffer.getvalue()


def _interactive_git_proc(
proc: git.Git.AutoInterrupt, collector: _DiffCollector, ignore_status: bool = False
):
while True:
try:
stdout, stderr = proc.communicate(timeout=collector.timeout())
if not ignore_status and proc.poll() != 0:
raise git.GitCommandError(proc.args, proc.poll(), stderr)
collector.write(stdout)
return
except subprocess.TimeoutExpired:
continue


def _repo_diff_verbose(repo: git.Repo, repo_hash: str, warning_time: float = 5) -> str:
collector = _DiffCollector(warning_time)
try:
_interactive_git_proc(repo.git.diff(repo_hash, as_process=True), collector)
for filename in repo.untracked_files:
_interactive_git_proc(
repo.git.diff("/dev/null", filename, no_index=True, binary=True, as_process=True),
collector,
ignore_status=True,
)
return collector.get()
except KeyboardInterrupt:
print("\nAborted")
exit(1)