Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions scripts/codex_package/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
import tarfile
import tempfile
import zipfile
from collections.abc import Callable
from pathlib import Path

from .targets import REPO_ROOT


ZSTD_DOTSLASH = REPO_ROOT / ".github" / "workflows" / "zstd"


def write_archive(package_dir: Path, archive_path: Path, *, force: bool) -> None:
if is_relative_to(archive_path, package_dir):
Expand Down Expand Up @@ -63,14 +69,33 @@ def write_tar_archive(package_dir: Path, archive_path: Path, *, mode: str) -> No


def write_tar_zst_archive(package_dir: Path, archive_path: Path) -> None:
zstd = shutil.which("zstd")
if zstd is None:
raise RuntimeError("zstd is required to write .tar.zst archives.")
zstd_command = resolve_zstd_command()

with tempfile.TemporaryDirectory(prefix="codex-package-archive-") as temp_dir_str:
tar_path = Path(temp_dir_str) / "package.tar"
write_tar_archive(package_dir, tar_path, mode="w")
subprocess.check_call([zstd, "-T0", "-19", "-f", str(tar_path), "-o", str(archive_path)])
subprocess.check_call(
[*zstd_command, "-T0", "-19", "-f", str(tar_path), "-o", str(archive_path)]
)


def resolve_zstd_command(
*,
dotslash_manifest: Path = ZSTD_DOTSLASH,
which: Callable[[str], str | None] = shutil.which,
) -> list[str]:
zstd = which("zstd")
if zstd is not None:
return [zstd]

dotslash = which("dotslash")
if dotslash is not None and dotslash_manifest.is_file():
return [dotslash, str(dotslash_manifest)]

raise RuntimeError(
"zstd is required to write .tar.zst archives. Install zstd, or install "
f"DotSlash so the repository wrapper can run: {dotslash_manifest}"
)


def write_zip_archive(package_dir: Path, archive_path: Path) -> None:
Expand Down
47 changes: 47 additions & 0 deletions scripts/codex_package/test_archive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3

from __future__ import annotations

from pathlib import Path
import sys
import tempfile
import unittest

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from codex_package.archive import resolve_zstd_command


class ResolveZstdCommandTest(unittest.TestCase):
def test_prefers_zstd_from_path(self) -> None:
def which(name: str) -> str | None:
return {"zstd": "/usr/bin/zstd", "dotslash": "/usr/bin/dotslash"}.get(name)

self.assertEqual(resolve_zstd_command(which=which), ["/usr/bin/zstd"])

def test_falls_back_to_dotslash_manifest(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
manifest = Path(temp_dir) / "zstd"
manifest.write_text("#!/usr/bin/env dotslash\n{}\n", encoding="utf-8")

def which(name: str) -> str | None:
return {"dotslash": "/usr/bin/dotslash"}.get(name)

self.assertEqual(
resolve_zstd_command(dotslash_manifest=manifest, which=which),
["/usr/bin/dotslash", str(manifest)],
)

def test_errors_when_no_zstd_or_dotslash_manifest_is_available(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
missing_manifest = Path(temp_dir) / "zstd"

with self.assertRaisesRegex(RuntimeError, "zstd is required"):
resolve_zstd_command(
dotslash_manifest=missing_manifest,
which=lambda _name: None,
)


if __name__ == "__main__":
unittest.main()
Loading