Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/dstack/_internal/core/models/repos/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class RemoteRepoInfo(
class RemoteRunRepoData(RemoteRepoInfo):
repo_branch: Optional[str] = None
repo_hash: Optional[str] = None
repo_diff: Annotated[Optional[str], Field(exclude=True)] = None
repo_diff: Annotated[Optional[bytes], Field(exclude=True)] = None
repo_config_name: Optional[str] = None
repo_config_email: Optional[str] = None

Expand Down Expand Up @@ -183,13 +183,15 @@ def __init__(
def has_code_to_write(self) -> bool:
# repo_diff is:
# * None for RemoteRepo.from_url()
# * an empty string for RemoteRepo.from_dir() if there are no changes ("clean" state)
# * a non-empty string for RemoteRepo.from_dir() if there are changes ("dirty" state)
# * empty bytes for RemoteRepo.from_dir() if there are no changes ("clean" state)
# and untracked files
# * non-empty bytes for RemoteRepo.from_dir() if there are changes ("dirty" state)
# and/or untracked files
return bool(self.run_repo_data.repo_diff)

def write_code_file(self, fp: BinaryIO) -> str:
if self.run_repo_data.repo_diff is not None:
fp.write(self.run_repo_data.repo_diff.encode())
fp.write(self.run_repo_data.repo_diff)
return get_sha256(fp)

def get_repo_info(self) -> RemoteRepoInfo:
Expand Down Expand Up @@ -238,7 +240,7 @@ def __init__(self, warning_time: float, delay: float = 5):
self.delay = delay
self.warned = False
self.start_time = time.monotonic()
self.buffer = io.StringIO()
self.buffer = io.BytesIO()

def timeout(self):
now = time.monotonic()
Expand All @@ -256,9 +258,9 @@ def timeout(self):
)

def write(self, v: bytes):
self.buffer.write(v.decode())
self.buffer.write(v)

def get(self) -> str:
def get(self) -> bytes:
if self.warned:
print()
return self.buffer.getvalue()
Expand Down Expand Up @@ -366,10 +368,10 @@ def _interactive_git_proc(
continue


def _repo_diff_verbose(repo: git.Repo, repo_hash: str, warning_time: float = 5) -> str:
def _repo_diff_verbose(repo: git.Repo, repo_hash: str, warning_time: float = 5) -> bytes:
collector = _DiffCollector(warning_time)
try:
_interactive_git_proc(repo.git.diff(repo_hash, as_process=True), collector)
_interactive_git_proc(repo.git.diff(repo_hash, binary=True, as_process=True), collector)
for filename in repo.untracked_files:
_interactive_git_proc(
repo.git.diff("/dev/null", filename, no_index=True, binary=True, as_process=True),
Expand Down
Loading