Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/dstack/_internal/server/services/runs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,13 @@ def check_can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec):
f"Failed to update fields {changed_spec_fields}."
f" Can only update {updatable_spec_fields}."
)
_check_can_update_configuration(current_run_spec.configuration, new_run_spec.configuration)
# We don't allow update if the order of archives has been changed, as even if the archives
# are the same (the same id => hash => content and the same container path), the order of
# unpacking matters when one path is a subpath of another.
ignore_files = current_run_spec.file_archives == new_run_spec.file_archives
_check_can_update_configuration(
current_run_spec.configuration, new_run_spec.configuration, ignore_files
)


def can_update_run_spec(current_run_spec: RunSpec, new_run_spec: RunSpec) -> bool:
Expand Down Expand Up @@ -159,7 +165,7 @@ def check_run_spec_requires_instance_mounts(run_spec: RunSpec) -> bool:


def _check_can_update_configuration(
current: AnyRunConfiguration, new: AnyRunConfiguration
current: AnyRunConfiguration, new: AnyRunConfiguration, ignore_files: bool
) -> None:
if current.type != new.type:
raise ServerClientError(
Expand All @@ -168,6 +174,13 @@ def _check_can_update_configuration(
updatable_fields = _CONF_UPDATABLE_FIELDS + _TYPE_SPECIFIC_CONF_UPDATABLE_FIELDS.get(
new.type, []
)
if ignore_files:
# We ignore files diff if the file archives are the same. It allows the user to move
# local files/dirs as long as their name(*), content, and the container path stay the same.
# (*) We could also ignore local name changes if the names didn't change in the tarballs.
# Currently, the client preserves the original file/dir name it the tarball, but it could
# use some generic names like "file"/"directory" instead.
updatable_fields.append("files")
diff = diff_models(current, new)
changed_fields = list(diff.keys())
for key in changed_fields:
Expand Down
29 changes: 13 additions & 16 deletions src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,18 @@ def get_run_plan(
if repo_dir is None and configuration.repos:
repo_dir = configuration.repos[0].path

self._validate_configuration_files(configuration, configuration_path)
file_archives: list[FileArchiveMapping] = []
for file_mapping in configuration.files:
with tempfile.TemporaryFile("w+b") as fp:
try:
archive_hash = create_file_archive(file_mapping.local_path, fp)
except OSError as e:
raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e
fp.seek(0)
archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp)
file_archives.append(FileArchiveMapping(id=archive.id, path=file_mapping.path))

if ssh_identity_file:
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
else:
Expand All @@ -513,6 +525,7 @@ def get_run_plan(
repo_data=repo.run_repo_data,
repo_code_hash=repo_code_hash,
repo_dir=repo_dir,
file_archives=file_archives,
# Server doesn't use this field since 0.19.27, but we still send it for compatibility
# with older servers
working_dir=configuration.working_dir,
Expand Down Expand Up @@ -549,22 +562,6 @@ def apply_plan(
# TODO handle multiple jobs
ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec)

run_spec = run_plan.run_spec
configuration = run_spec.configuration

self._validate_configuration_files(configuration, run_spec.configuration_path)
for file_mapping in configuration.files:
with tempfile.TemporaryFile("w+b") as fp:
try:
archive_hash = create_file_archive(file_mapping.local_path, fp)
except OSError as e:
raise ClientError(f"failed to archive '{file_mapping.local_path}': {e}") from e
fp.seek(0)
archive = self._api_client.files.upload_archive(hash=archive_hash, fp=fp)
run_spec.file_archives.append(
FileArchiveMapping(id=archive.id, path=file_mapping.path)
)

if repo is None:
repo = VirtualRepo()
else:
Expand Down