diff --git a/scripts/codex_package/archive.py b/scripts/codex_package/archive.py index fe09c0a4f746..08944a650f0e 100644 --- a/scripts/codex_package/archive.py +++ b/scripts/codex_package/archive.py @@ -5,8 +5,14 @@ import tarfile import tempfile import zipfile +from collections.abc import Callable from pathlib import Path +from .targets import REPO_ROOT + + +ZSTD_DOTSLASH = REPO_ROOT / ".github" / "workflows" / "zstd" + def write_archive(package_dir: Path, archive_path: Path, *, force: bool) -> None: if is_relative_to(archive_path, package_dir): @@ -63,14 +69,33 @@ def write_tar_archive(package_dir: Path, archive_path: Path, *, mode: str) -> No def write_tar_zst_archive(package_dir: Path, archive_path: Path) -> None: - zstd = shutil.which("zstd") - if zstd is None: - raise RuntimeError("zstd is required to write .tar.zst archives.") + zstd_command = resolve_zstd_command() with tempfile.TemporaryDirectory(prefix="codex-package-archive-") as temp_dir_str: tar_path = Path(temp_dir_str) / "package.tar" write_tar_archive(package_dir, tar_path, mode="w") - subprocess.check_call([zstd, "-T0", "-19", "-f", str(tar_path), "-o", str(archive_path)]) + subprocess.check_call( + [*zstd_command, "-T0", "-19", "-f", str(tar_path), "-o", str(archive_path)] + ) + + +def resolve_zstd_command( + *, + dotslash_manifest: Path = ZSTD_DOTSLASH, + which: Callable[[str], str | None] = shutil.which, +) -> list[str]: + zstd = which("zstd") + if zstd is not None: + return [zstd] + + dotslash = which("dotslash") + if dotslash is not None and dotslash_manifest.is_file(): + return [dotslash, str(dotslash_manifest)] + + raise RuntimeError( + "zstd is required to write .tar.zst archives. Install zstd, or install " + f"DotSlash so the repository wrapper can run: {dotslash_manifest}" + ) def write_zip_archive(package_dir: Path, archive_path: Path) -> None: diff --git a/scripts/codex_package/test_archive.py b/scripts/codex_package/test_archive.py new file mode 100644 index 000000000000..cade7d12f676 --- /dev/null +++ b/scripts/codex_package/test_archive.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +from pathlib import Path +import sys +import tempfile +import unittest + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from codex_package.archive import resolve_zstd_command + + +class ResolveZstdCommandTest(unittest.TestCase): + def test_prefers_zstd_from_path(self) -> None: + def which(name: str) -> str | None: + return {"zstd": "/usr/bin/zstd", "dotslash": "/usr/bin/dotslash"}.get(name) + + self.assertEqual(resolve_zstd_command(which=which), ["/usr/bin/zstd"]) + + def test_falls_back_to_dotslash_manifest(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + manifest = Path(temp_dir) / "zstd" + manifest.write_text("#!/usr/bin/env dotslash\n{}\n", encoding="utf-8") + + def which(name: str) -> str | None: + return {"dotslash": "/usr/bin/dotslash"}.get(name) + + self.assertEqual( + resolve_zstd_command(dotslash_manifest=manifest, which=which), + ["/usr/bin/dotslash", str(manifest)], + ) + + def test_errors_when_no_zstd_or_dotslash_manifest_is_available(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + missing_manifest = Path(temp_dir) / "zstd" + + with self.assertRaisesRegex(RuntimeError, "zstd is required"): + resolve_zstd_command( + dotslash_manifest=missing_manifest, + which=lambda _name: None, + ) + + +if __name__ == "__main__": + unittest.main()