From f41615111fd488c8065c8283a922a381af32ddef Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 14:21:02 +0200 Subject: [PATCH 1/8] Fix auth from env overlapping for all subcommands --- README.md | 19 ++-- contree_cli/agent.md | 9 +- contree_cli/arguments.py | 20 ++--- contree_cli/cli/auth.py | 90 +++++++++++++------ contree_cli/cli/images.py | 51 +++++++++-- contree_cli/cli/ps.py | 144 ++++++++++++++++--------------- contree_cli/client.py | 43 ++++++++- contree_cli/config.py | 31 +++---- contree_cli/manual.md | 18 ++-- contree_cli/types.py | 1 + docs/commands/auth.md | 37 +++++--- docs/tutorial/configuration.md | 30 +++++-- docs/tutorial/installation.md | 27 +++--- tests/test_auth.py | 153 +++++++++++++++++++++++++++++++-- tests/test_client.py | 72 ++++++++++++++++ tests/test_config.py | 12 +-- tests/test_cp.py | 5 ++ tests/test_file_cmd.py | 3 + tests/test_images.py | 110 ++++++++++++++++++++++++ tests/test_ps.py | 118 +++++++++++++++++++++---- 20 files changed, 782 insertions(+), 211 deletions(-) diff --git a/README.md b/README.md index 8bfcf6a..77ff228 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ contree auth You'll be prompted to enter your API token and project ID. The CLI verifies the token and saves credentials to `~/.config/contree-cli/config.ini`. -If `NEBIUS_API_KEY` and `NEBIUS_AI_PROJECT` environment variables are set and no CLI flags are passed, they are picked up automatically instead of prompting. +If `--token`/`--url`/`--project` flags are omitted, `contree auth` reads `CONTREE_TOKEN` (or `NEBIUS_API_KEY`), `CONTREE_URL`, and `CONTREE_PROJECT` (or `NEBIUS_AI_PROJECT`) from the environment instead of prompting. These variables are read only during registration; runtime commands use the saved profile only. ### 2. 
Install agent skills (optional) @@ -285,17 +285,24 @@ contree auth switch staging # switch active profile ### Environment variables +Read at runtime (any command): + | Variable | Purpose | |---|---| | `CONTREE_HOME` | Data directory (default `~/.config/contree-cli`) | -| `CONTREE_TOKEN` | API bearer token (overrides config) | -| `CONTREE_URL` | API base URL (overrides config) | -| `CONTREE_PROJECT` | Project ID (overrides config) | -| `CONTREE_PROFILE` | Active profile name | +| `CONTREE_PROFILE` | Active profile name (selects which profile commands use) | | `CONTREE_SESSION` | Explicit session key (for multi-terminal workflows) | | `CONTREE_SESSION_DB` | Path to session SQLite database | -Environment variables take precedence over the config file. `--token` and `--url` flags override everything. +Read only by `contree auth` (registration-time fallbacks for omitted flags): + +| Variable | Used for | +|---|---| +| `CONTREE_TOKEN` / `NEBIUS_API_KEY` | `--token` | +| `CONTREE_URL` | `--url` | +| `CONTREE_PROJECT` / `NEBIUS_AI_PROJECT` | `--project` | + +Credentials come strictly from the saved profile at runtime. `--token`, `--url`, `--project` CLI flags override profile fields for a single invocation. 
## Zero Dependencies diff --git a/contree_cli/agent.md b/contree_cli/agent.md index 3b82a18..5539b2d 100644 --- a/contree_cli/agent.md +++ b/contree_cli/agent.md @@ -335,11 +335,14 @@ Data directory: ~/.config/contree-cli/ Environment variables: CONTREE_HOME data directory override - CONTREE_TOKEN API token (overrides config) - CONTREE_URL API URL (overrides config) - CONTREE_PROFILE active profile (overrides config) + CONTREE_PROFILE active profile (selects which profile commands use) CONTREE_SESSION explicit session key +Read only by `contree auth` (registration-time fallbacks): + CONTREE_TOKEN / NEBIUS_API_KEY token when --token is omitted + CONTREE_URL URL when --url is omitted + CONTREE_PROJECT / NEBIUS_AI_PROJECT project ID when --project is omitted + More: contree auth --help All commands diff --git a/contree_cli/arguments.py b/contree_cli/arguments.py index acb7417..9dfb750 100644 --- a/contree_cli/arguments.py +++ b/contree_cli/arguments.py @@ -71,20 +71,18 @@ contree use IMAGE | run -- CMD | file edit PATH | file cp SRC DEST contree tag UUID TAG | kill UUID | cd PATH | session checkout BRANCH -environment variables (advanced overrides; most users can ignore): - CONTREE_TOKEN API bearer token (overrides config file) - CONTREE_URL API base URL (overrides config file) - CONTREE_PROJECT Project ID for IAM auth (overrides config file) - CONTREE_PROFILE Active config profile (overrides config file) +environment variables: + CONTREE_PROFILE Active config profile (selects which profile to use) CONTREE_SESSION Explicit session name (for multi-terminal workflows). If unset, contree auto-generates +<8hex> (derived from profile+ppid+tty); export your own for stable reuse. You can also pass -S/--session instead of exporting env. 
CONTREE_SESSION_DB Path to session SQLite database -nebius shortcuts (used by `contree auth` as fallback when flags are omitted): - NEBIUS_API_KEY Fallback token for auth registration - NEBIUS_AI_PROJECT Fallback project ID for IAM auth registration +registration-time fallbacks (only read by `contree auth`, not at runtime): + CONTREE_TOKEN / NEBIUS_API_KEY Token used when --token is omitted + CONTREE_URL URL used when --url is omitted + CONTREE_PROJECT / NEBIUS_AI_PROJECT Project ID used when --project is omitted """ DESCRIPTION = """\ @@ -121,7 +119,7 @@ parser.add_argument( *FLAGS["token"], default=None, - help="API token (overrides config and env)", + help="API token (overrides profile for this invocation)", ) @@ -133,12 +131,12 @@ def _strip_trailing_slashes(value: str) -> str: *FLAGS["url"], default=None, type=_strip_trailing_slashes, - help="API base URL (overrides config and env)", + help="API base URL (overrides profile for this invocation)", ) parser.add_argument( *FLAGS["project"], default=None, - help="Project ID (overrides config and env)", + help="Project ID (overrides profile for this invocation)", ) parser.add_argument( *FLAGS["config"], diff --git a/contree_cli/cli/auth.py b/contree_cli/cli/auth.py index e94a814..0aa17ce 100644 --- a/contree_cli/cli/auth.py +++ b/contree_cli/cli/auth.py @@ -8,9 +8,13 @@ iam (default) — bearer token + project ID, default URL provided jwt (legacy) — bearer token only, URL must be specified -Nebius environment variable shortcuts: - NEBIUS_API_KEY used as token fallback during registration - NEBIUS_AI_PROJECT used as project fallback during IAM registration +Environment variable fallbacks during registration: + CONTREE_TOKEN / NEBIUS_API_KEY used when --token is omitted + CONTREE_URL used when --url is omitted + CONTREE_PROJECT / NEBIUS_AI_PROJECT used when --project is omitted (IAM) + +Other commands ignore these variables; only ``contree auth`` reads +them. ``CONTREE_PROFILE`` selects the profile for any command. 
Subcommands: profiles List saved profiles (* marks active) @@ -22,6 +26,7 @@ import argparse import getpass import hashlib +import json import logging import os from dataclasses import dataclass @@ -35,6 +40,7 @@ logger = logging.getLogger(__name__) PROFILE_CHECK_TIMEOUT = 2.0 PROFILE_CHECK_CONCURRENCY = 4 +REQUIRED_PERMISSION = "list" EPILOG = """\ for coding agents: @@ -169,6 +175,22 @@ def setup_parser(p: argparse.ArgumentParser) -> SetupResult: return cmd_auth, AuthArgs +def _env_fallback(names: tuple[str, ...], *, what: str) -> str | None: + for name in names: + value = os.environ.get(name) + if value: + logger.info("Using %s from %s", what, name) + return value + return None + + +def check_permission(payload: dict[str, object], permission: str) -> bool: + perms = payload.get("permissions") + if not isinstance(perms, dict): + return False + return bool(perms.get(permission)) + + def cmd_auth(args: AuthArgs) -> int | None: cfg = Config() exists = args.profile in cfg @@ -188,18 +210,16 @@ def cmd_auth(args: AuthArgs) -> int | None: print("Aborted.") return 1 - # Token: --token > NEBIUS_API_KEY > interactive prompt - token = args.token + # Token: --token > CONTREE_TOKEN > NEBIUS_API_KEY > interactive prompt + token = args.token or _env_fallback( + ("CONTREE_TOKEN", "NEBIUS_API_KEY"), + what="token", + ) if token is None: - nebius_key = os.environ.get("NEBIUS_API_KEY") - if nebius_key: - logger.info("Using token from NEBIUS_API_KEY") - token = nebius_key - else: - token = getpass.getpass("Token: ") + token = getpass.getpass("Token: ") - # URL: --url > type-specific default > interactive prompt - url = args.url + # URL: --url > CONTREE_URL > type-specific default > interactive prompt + url = args.url or _env_fallback(("CONTREE_URL",), what="URL") if url is None: if args.auth_type == AuthType.IAM: url = Config.DEFAULT_IAM_URL @@ -209,17 +229,15 @@ def cmd_auth(args: AuthArgs) -> int | None: logger.error("URL is required for JWT auth") return 1 - # Project (IAM 
only): --project > NEBIUS_AI_PROJECT > interactive prompt + # Project (IAM only): --project > CONTREE_PROJECT > NEBIUS_AI_PROJECT > prompt project: str | None = None if args.auth_type == AuthType.IAM: - project = args.project + project = args.project or _env_fallback( + ("CONTREE_PROJECT", "NEBIUS_AI_PROJECT"), + what="project", + ) if project is None: - nebius_project = os.environ.get("NEBIUS_AI_PROJECT") - if nebius_project: - logger.info("Using project from NEBIUS_AI_PROJECT") - project = nebius_project - else: - project = input("Project ID: ").strip() + project = input("Project ID: ").strip() profile = ConfigProfile( name=args.profile, token=token, @@ -235,15 +253,33 @@ def cmd_auth(args: AuthArgs) -> int | None: return 1 try: - client.get("/v1/whoami") + resp = client.get("/v1/whoami") + whoami = json.loads(resp.read() or b"{}") except ApiError as exc: # Logs the API error message, not the token itself. # nosemgrep: python-logger-credential-disclosure logger.error("Token verification failed: %s. Profile not changed.", exc) return 1 + except ValueError as exc: + logger.error("Could not parse /v1/whoami response: %s", exc) + return 1 + + if not check_permission(whoami, REQUIRED_PERMISSION): + project_label = profile.project or profile.url + logger.warning( + "Warning: token is valid but sandboxes are disabled on %s" + " (no %r permission). 
The profile will be saved but no commands" + " will work until the service is enabled.", + project_label, + REQUIRED_PERMISSION, + ) cfg[args.profile] = profile - logger.info("Token verified and saved to profile %r", args.profile) + logger.info( + "auth accepted, profile %r saved to -> %s", + args.profile, + cfg.path, + ) return None @@ -284,11 +320,17 @@ def check_status( try: resp = client.get("/v1/whoami") - resp.read() + payload = resp.read() except TimeoutError: return profile, "timeout" except Exception: return profile, "error" + try: + whoami = json.loads(payload or b"{}") + except ValueError: + return profile, "error" + if not check_permission(whoami, REQUIRED_PERMISSION): + return profile, "inactive" return profile, "ok" formatter = FORMATTER.get() diff --git a/contree_cli/cli/images.py b/contree_cli/cli/images.py index 707f67d..ed6df5c 100644 --- a/contree_cli/cli/images.py +++ b/contree_cli/cli/images.py @@ -29,7 +29,8 @@ logger = logging.getLogger(__name__) -PAGE_SIZE = 100 +PAGE_SIZE = 500 +LIMIT_DEFAULT = 2000 TERMINAL_STATUSES = frozenset({"SUCCESS", "FAILED", "CANCELLED"}) DOCKER_HUB = "docker.io" @@ -77,6 +78,7 @@ class ImagesArgs(ArgumentsProtocol): all_images: bool = False since: datetime | None = None until: datetime | None = None + limit: int = LIMIT_DEFAULT @classmethod def from_args(cls, ns: argparse.Namespace) -> ImagesArgs: @@ -86,6 +88,7 @@ def from_args(cls, ns: argparse.Namespace) -> ImagesArgs: all_images=getattr(ns, "all_images", False), since=getattr(ns, "since", None), until=getattr(ns, "until", None), + limit=getattr(ns, "limit", LIMIT_DEFAULT), ) @@ -189,6 +192,12 @@ def _add_list_args(p: argparse.ArgumentParser) -> None: type=parse_interval, help="Show images before. 
" + str(parse_interval.__doc__), ) + p.add_argument( + *FLAGS["limit"], + type=int, + default=LIMIT_DEFAULT, + help="Stop after this many images and warn if more are available", + ) def setup_parser(p: argparse.ArgumentParser) -> SetupResult: @@ -249,7 +258,7 @@ def cmd_images(args: ImagesArgs) -> None: client = CLIENT.get() formatter = FORMATTER.get() - base_params: dict[str, str] = {"limit": str(PAGE_SIZE)} + base_params: dict[str, str] = {} if args.prefix is not None: base_params["tag"] = args.prefix if args.uuid is not None: @@ -262,13 +271,19 @@ def cmd_images(args: ImagesArgs) -> None: base_params["until"] = isoformat_datetime(args.until) offset = 0 - while True: - params = {**base_params, "offset": str(offset)} + emitted = 0 + while emitted < args.limit: + page_size = min(PAGE_SIZE, args.limit - emitted) + params = { + **base_params, + "offset": str(offset), + "limit": str(page_size), + } resp = client.get("/v1/images", params=params) data = json.loads(resp.read()) images = data["images"] if not images: - break + return for image in images: created_at = parse_datetime(image["created_at"]) formatter( @@ -276,9 +291,31 @@ def cmd_images(args: ImagesArgs) -> None: created_at=created_at, tag=image.get("tag") or "", ) - if len(images) < PAGE_SIZE: - break + emitted += len(images) + if len(images) < page_size: + return offset += len(images) + if emitted < args.limit: + logger.info( + "Fetched %d images so far... (press Ctrl+C to break)", + emitted, + ) + + # Hit the limit. Probe one extra record (offset=emitted, limit=1) to + # detect truncation without re-fetching a full page. + probe_params = {**base_params, "offset": str(offset), "limit": "1"} + resp = client.get("/v1/images", params=probe_params) + data = json.loads(resp.read()) + if data.get("images"): + # Flush buffered output (e.g. TableFormatter) before the warning + # so the truncation note appears AFTER the listing on screen. 
+ formatter.flush() + logger.warning( + "Output truncated at --limit=%d images; more results are" + " available. Raise --limit or narrow with" + " --prefix/--since/--until.", + args.limit, + ) def _parse_explicit_tag(ref: str) -> tuple[str, str | None]: diff --git a/contree_cli/cli/ps.py b/contree_cli/cli/ps.py index 2d593a2..cc8bbd6 100644 --- a/contree_cli/cli/ps.py +++ b/contree_cli/cli/ps.py @@ -13,7 +13,6 @@ import itertools import json import logging -from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any @@ -114,64 +113,26 @@ def setup_parser(p: argparse.ArgumentParser) -> SetupResult: return cmd_ps, PsArgs -def operations_iterator( - status: str | None = None, - kind: str | None = None, - show_max: int | None = None, - page_size: int = PAGE_SIZE, - since: str | None = None, - until: str | None = None, -) -> Iterator[dict[str, Any]]: - client = CLIENT.get() - assert client is not None, "Client not initialized" - - offset = 0 - counter = 1 - while True: - params: dict[str, str] = { - "limit": str(page_size), - "offset": str(offset), - } - if status: - params["status"] = status - if kind: - params["kind"] = kind - if since: - params["since"] = since - if until: - params["until"] = until - resp = client.get("/v1/operations", params=params) - operations = json.loads(resp.read()) - - if not operations: - break - - offset += len(operations) - for op in operations: - if show_max is not None and counter >= show_max: - logger.warning( - "Reached show_max limit of %d, for see more increase --show-max", - show_max, - ) - return - counter += 1 - yield dict( - uuid=op["uuid"], - status=op["status"], - kind=op["kind"], - created_at=parse_datetime(op["created_at"]), - duration=timedelta(seconds=op["duration"]) - if op.get("duration") is not None - else None, - error=op.get("error") or "", - ) - - if len(operations) < page_size: - break +def _emit_op(formatter: OutputFormatter, op: dict[str, 
Any], *, quiet: bool) -> None: + row = dict( + uuid=op["uuid"], + status=op["status"], + kind=op["kind"], + created_at=parse_datetime(op["created_at"]), + duration=timedelta(seconds=op["duration"]) + if op.get("duration") is not None + else None, + error=op.get("error") or "", + ) + if quiet: + print(row["uuid"]) + else: + formatter(**row) def cmd_ps(args: PsArgs) -> None: formatter: OutputFormatter = FORMATTER.get() + client = CLIENT.get() status: str | None = None if args.status is not None: @@ -182,21 +143,62 @@ def cmd_ps(args: PsArgs) -> None: elif not args.all: status = "EXECUTING" - since: str | None = None + base_params: dict[str, str] = {} + if status: + base_params["status"] = status + if args.kind: + base_params["kind"] = args.kind if args.since is not None: - since = isoformat_datetime(args.since) - until: str | None = None + base_params["since"] = isoformat_datetime(args.since) if args.until is not None: - until = isoformat_datetime(args.until) - - for op in operations_iterator( - status=status, - kind=args.kind, - show_max=args.show_max, - since=since, - until=until, - ): - if args.quiet: - print(op["uuid"]) - else: - formatter(**op) + base_params["until"] = isoformat_datetime(args.until) + + limit = args.show_max + offset = 0 + emitted = 0 + hit_limit = False + + while limit is None or emitted < limit: + page_size = PAGE_SIZE if limit is None else min(PAGE_SIZE, limit - emitted) + params = { + **base_params, + "offset": str(offset), + "limit": str(page_size), + } + resp = client.get("/v1/operations", params=params) + operations = json.loads(resp.read()) + if not operations: + return + for op in operations: + if limit is not None and emitted >= limit: + hit_limit = True + break + _emit_op(formatter, op, quiet=args.quiet) + emitted += 1 + if hit_limit: + break + if len(operations) < page_size: + return + offset += len(operations) + if limit is None or emitted < limit: + logger.info( + "Fetched %d operations so far... 
(press Ctrl+C to break)", + emitted, + ) + + if limit is None: + return + + # Hit the limit. Probe one extra record (offset=emitted, limit=1) to + # detect truncation without re-fetching a full page. + probe_params = {**base_params, "offset": str(emitted), "limit": "1"} + resp = client.get("/v1/operations", params=probe_params) + operations = json.loads(resp.read()) + if operations: + formatter.flush() + logger.warning( + "Output truncated at --show-max=%d operations; more results" + " are available. Raise --show-max or filter with" + " --status/--kind/--since/--until.", + limit, + ) diff --git a/contree_cli/client.py b/contree_cli/client.py index 2a7e600..ef9487b 100644 --- a/contree_cli/client.py +++ b/contree_cli/client.py @@ -9,7 +9,7 @@ import sys import time from abc import ABC, abstractmethod -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from importlib.metadata import PackageNotFoundError, version from typing import IO, cast from urllib.parse import urlencode, urlsplit @@ -35,6 +35,37 @@ def _cli_version() -> str: ) +class HeaderFormatter: + """Lazy redactor for HTTP headers, formats only on emit.""" + + SENSITIVE_HEADERS = frozenset( + { + "authorization", + "proxy-authorization", + "cookie", + "set-cookie", + "x-api-key", + "x-auth-token", + } + ) + + def __init__( + self, + headers: dict[str, str] | list[tuple[str, str]], + ) -> None: + self.headers = headers + + def __str__(self) -> str: + items: Iterable[tuple[str, str]] = ( + self.headers.items() if isinstance(self.headers, dict) else self.headers + ) + redacted = { + k: "" if k.lower() in self.SENSITIVE_HEADERS else v + for k, v in items + } + return repr(redacted) + + class BodyFormatter: """Lazy %s-arg for logging HTTP bodies — formats only on emit.""" @@ -186,9 +217,10 @@ def request( attempts = len(RETRY_DELAYS) + 1 log.debug( - "%s %s body=%s", + "%s %s headers=%s body=%s", method, full_path, + HeaderFormatter(merged), BodyFormatter(body, 
content_type=merged.get("Content-Type", "")), ) @@ -220,24 +252,27 @@ def request( if 200 <= resp.status < 300: log.debug( - "%s %s -> %d %s", + "%s %s -> %d %s headers=%s", method, full_path, resp.status, resp.reason, + HeaderFormatter(list(resp.getheaders())), ) if log.isEnabledFor(logging.DEBUG): return self.log_and_buffer(method, full_path, resp) return resp + resp_headers = list(resp.getheaders()) resp_body = resp.read().decode("utf-8", errors="replace") log.debug( - "%s %s -> %d %s (%dB)", + "%s %s -> %d %s (%dB) headers=%s", method, full_path, resp.status, resp.reason, len(resp_body), + HeaderFormatter(resp_headers), ) log.debug( "%s %s response body: %s", diff --git a/contree_cli/config.py b/contree_cli/config.py index 17e91bf..fa8e097 100644 --- a/contree_cli/config.py +++ b/contree_cli/config.py @@ -94,6 +94,10 @@ def __init__(self, path: Path | None = None) -> None: self.__active: str = "default" self._load() + @property + def path(self) -> Path: + return self.__path + # -- persistence --------------------------------------------------------- def _load(self) -> None: @@ -191,32 +195,23 @@ def current(self, profile: ConfigProfile) -> None: self._save() def resolve(self, profile_override: str | None = None) -> ConfigProfile: - """Resolve the active profile with env-var overrides. + """Resolve the active profile by name. Priority: *profile_override* > ``CONTREE_PROFILE`` > config default. - Per-field: ``CONTREE_TOKEN`` / ``CONTREE_URL`` / ``CONTREE_PROJECT`` - override the stored values. + Credentials come strictly from the saved profile; runtime + commands do not read tokens, URLs, or project IDs from the + environment. To register/refresh credentials from env vars use + ``contree auth``. 
""" name = profile_override or os.environ.get("CONTREE_PROFILE") or self.__active - env_token = os.environ.get("CONTREE_TOKEN") - env_url = os.environ.get("CONTREE_URL") - env_project = os.environ.get("CONTREE_PROJECT") - if name in self.__profiles: - p = self.__profiles[name] - return ConfigProfile( - name=name, - token=env_token or p.token, - url=env_url or p.url, - auth_type=p.auth_type, - project=env_project or p.project, - ) + return self.__profiles[name] return ConfigProfile( name=name, - token=env_token, - url=env_url or "", + token=None, + url="", auth_type=AuthType.JWT, - project=env_project, + project=None, ) def switch(self, name: str) -> None: diff --git a/contree_cli/manual.md b/contree_cli/manual.md index 5790d35..94cc0fe 100644 --- a/contree_cli/manual.md +++ b/contree_cli/manual.md @@ -198,9 +198,11 @@ Profiles Each profile has its own session database. -Auth fallback: if `NEBIUS_API_KEY` and `NEBIUS_AI_PROJECT` are set -in the environment and no `--token`/`--project` flags are passed, -`contree auth` picks them up automatically (no interactive prompts). +Auth fallback: when `--token`/`--url`/`--project` flags are omitted, +`contree auth` reads `CONTREE_TOKEN` (or `NEBIUS_API_KEY`), +`CONTREE_URL`, and `CONTREE_PROJECT` (or `NEBIUS_AI_PROJECT`) from +the environment. These are used only during registration; runtime +commands read credentials strictly from the saved profile. More: contree auth --help @@ -218,12 +220,14 @@ override with $CONTREE_HOME. 
Environment variables: CONTREE_HOME data directory XDG_CONFIG_HOME XDG base config dir, used to derive default CONTREE_HOME - CONTREE_TOKEN API token (overrides config) - CONTREE_URL API URL (overrides config) - CONTREE_PROJECT project ID (overrides config) - CONTREE_PROFILE active profile + CONTREE_PROFILE active profile (selects which profile commands use) CONTREE_SESSION explicit session key +Registration-time fallbacks (read only by `contree auth`): + CONTREE_TOKEN / NEBIUS_API_KEY token when --token is omitted + CONTREE_URL URL when --url is omitted + CONTREE_PROJECT / NEBIUS_AI_PROJECT project ID when --project is omitted + More: contree --help All commands diff --git a/contree_cli/types.py b/contree_cli/types.py index 7036dfe..17c8fb3 100644 --- a/contree_cli/types.py +++ b/contree_cli/types.py @@ -62,6 +62,7 @@ "uuid": ("-i", "--uuid"), "username": ("--username",), "password": ("--password",), + "limit": ("--limit",), # use "new": ("-N", "--new"), # session diff --git a/docs/commands/auth.md b/docs/commands/auth.md index 8ba6306..dcfb3f5 100644 --- a/docs/commands/auth.md +++ b/docs/commands/auth.md @@ -59,17 +59,23 @@ When you run `contree auth`, the CLI: ### Environment variable shortcuts -When CLI flags (`--token`, `--project`) are not passed, `contree auth` -checks these environment variables before falling back to an interactive -prompt: +When CLI flags (`--token`, `--url`, `--project`) are not passed, +`contree auth` checks these environment variables before falling back +to an interactive prompt: | Variable | Fallback for | Priority | |----------|-------------|----------| -| `NEBIUS_API_KEY` | `--token` | flag > env > prompt | -| `NEBIUS_AI_PROJECT` | `--project` | flag > env > prompt | +| `CONTREE_TOKEN` | `--token` | flag > `CONTREE_TOKEN` > `NEBIUS_API_KEY` > prompt | +| `NEBIUS_API_KEY` | `--token` | (see above) | +| `CONTREE_URL` | `--url` | flag > env > type-specific default > prompt | +| `CONTREE_PROJECT` | `--project` | flag > 
`CONTREE_PROJECT` > `NEBIUS_AI_PROJECT` > prompt | +| `NEBIUS_AI_PROJECT` | `--project` | (see above) | -If both variables are set, `contree auth` runs fully non-interactively -(no prompts): +These variables are read **only** during `contree auth`. Other commands +ignore them and read credentials strictly from the saved profile. + +If the relevant variables are set, `contree auth` runs fully +non-interactively (no prompts): ```bash export NEBIUS_API_KEY=eyJ... @@ -84,7 +90,9 @@ with a 2-second timeout and adds a `status` column. Possible values: -- `ok` -- probe succeeded +- `ok` -- probe succeeded and the token has the `list` permission +- `inactive` -- probe succeeded but the token lacks the `list` + permission, meaning sandboxes are disabled on this project - `timeout` -- probe did not complete within 2 seconds - `error` -- probe failed for another reason, such as a bad token or another network/API error @@ -121,16 +129,23 @@ secure prompt instead. ## Alternative authentication -You can also authenticate without the config file: +Runtime commands always read credentials from the saved profile. +To authenticate without an interactive `auth` flow, either: ```bash -# Environment variable (avoid in shared environments) +# 1. Bootstrap the profile non-interactively from environment vars export CONTREE_TOKEN=eyJ... +export CONTREE_URL=https://api.tokenfactory.nebius.com/sandboxes +contree auth -y --type jwt +contree images -# Inline flag (per-command, visible in process listings) +# 2. Or pass the token inline per-command (visible in process listings) contree --token=eyJ... images ``` +Setting `CONTREE_TOKEN` alone (without first running `contree auth`) +will not authenticate runtime commands. 
+ ## See also - {doc}`/tutorial/installation` -- full authentication guide diff --git a/docs/tutorial/configuration.md b/docs/tutorial/configuration.md index 2198cea..6e98a32 100644 --- a/docs/tutorial/configuration.md +++ b/docs/tutorial/configuration.md @@ -241,24 +241,36 @@ The active profile is still selected by the `profile` key in ## Environment variables +Read at runtime by any command: + | Variable | Description | |----------|-------------| | `CONTREE_HOME` | Data directory (default `$XDG_CONFIG_HOME/contree`, or `~/.config/contree`) | | `XDG_CONFIG_HOME` | XDG base config dir, used to derive the default `CONTREE_HOME` | -| `CONTREE_TOKEN` | API bearer token (overrides config) | -| `CONTREE_URL` | API base URL (overrides config) | -| `CONTREE_PROJECT` | Project ID (overrides config) | -| `CONTREE_PROFILE` | Active profile name (overrides config) | +| `CONTREE_PROFILE` | Active profile name (selects which profile commands use) | | `CONTREE_SESSION` | Explicit session key (overrides auto-generated) | +Read only by `contree auth` (registration-time fallbacks for omitted flags): + +| Variable | Used for | +|----------|----------| +| `CONTREE_TOKEN` / `NEBIUS_API_KEY` | `--token` | +| `CONTREE_URL` | `--url` | +| `CONTREE_PROJECT` / `NEBIUS_AI_PROJECT` | `--project` | + ## Resolution precedence -For token, URL, and project: +For token, URL, and project at runtime: + +1. CLI flag (`--token`, `--url`, `--project`) — overrides profile for the + current invocation only +2. Saved profile field +3. Built-in default URL for IAM: `https://api.tokenfactory.nebius.com/sandboxes` -1. CLI flag (`--token`, `--url`, `--project`) -2. Environment variable (`CONTREE_TOKEN`, `CONTREE_URL`, `CONTREE_PROJECT`) -3. Config file value from the active profile -4. 
Built-in default URL: `https://api.tokenfactory.nebius.com/sandboxes` +Environment variables are not consulted at runtime; to refresh credentials +from environment variables, run `contree auth` (which reads +`CONTREE_TOKEN` / `NEBIUS_API_KEY`, `CONTREE_URL`, and `CONTREE_PROJECT` / +`NEBIUS_AI_PROJECT` as fallbacks for the corresponding flags). For profiles: diff --git a/docs/tutorial/installation.md b/docs/tutorial/installation.md index d879245..cd67bfb 100644 --- a/docs/tutorial/installation.md +++ b/docs/tutorial/installation.md @@ -74,15 +74,18 @@ The CLI verifies the token with the API and writes credentials to `~/.config/contree/auth.ini`. If a profile already exists you will be prompted to confirm; use `-y` to skip the prompt. -Resolution order for each field (first match wins): +Resolution order for each field during `contree auth` (first match wins): -1. CLI flag (`--token`, `--project`) -2. Environment variable (`NEBIUS_API_KEY`, `NEBIUS_AI_PROJECT`) +1. CLI flag (`--token`, `--url`, `--project`) +2. Environment variables, in order: + - token: `CONTREE_TOKEN`, then `NEBIUS_API_KEY` + - URL: `CONTREE_URL` + - project: `CONTREE_PROJECT`, then `NEBIUS_AI_PROJECT` 3. Interactive prompt -So if `NEBIUS_API_KEY` and `NEBIUS_AI_PROJECT` are already in your -environment and no flags are passed, `contree auth` picks them up -automatically — no interactive prompts needed: +So if these variables are already in your environment and no flags +are passed, `contree auth` picks them up automatically, no interactive +prompts needed: ```bash export NEBIUS_API_KEY=eyJ... @@ -125,18 +128,22 @@ contree images # uses personal ### Token from environment -Set `CONTREE_TOKEN` to provide the token without a config file: +`CONTREE_TOKEN` and `NEBIUS_API_KEY` are read **only** by `contree auth` +during profile registration; runtime commands always read credentials +from the saved profile. 
To bootstrap a profile entirely from environment +variables, run `auth` non-interactively: ```bash export CONTREE_TOKEN=eyJ... +export CONTREE_URL=https://api.tokenfactory.nebius.com/sandboxes +contree auth -y --type jwt # one-shot setup, no prompts contree images ``` -Environment variables always take precedence over the config file. - ### Inline token -Pass `--token` to any command to override both config and env: +Pass `--token` to any command to override the saved profile for a single +invocation: ```bash contree --token=eyJ... images diff --git a/tests/test_auth.py b/tests/test_auth.py index 56e789c..941a5b7 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -46,14 +46,26 @@ def _make_iam_args(**kwargs) -> AuthArgs: return AuthArgs(**defaults) +def _whoami_body(*, permissions: dict[str, bool] | None = None) -> bytes: + body = { + "token_uuid": "00000000-0000-0000-0000-000000000000", + "token_expiration": None, + "permissions": {"list": True} if permissions is None else permissions, + "operations_stat": {}, + } + import json as _json + + return _json.dumps(body).encode() + + @contextmanager -def _mock_whoami(status=200): +def _mock_whoami(status=200, *, body: bytes | None = None): """Patch client_from_profile to return a fresh ContreeTestClient per call.""" last_client: list[ContreeTestClient] = [] def factory(profile, timeout=None): # type: ignore[no-untyped-def] tc = ContreeTestClient() - tc.respond(status=status, body=b'{"ok":true}') + tc.respond(status=status, body=body if body is not None else _whoami_body()) last_client.clear() last_client.append(tc) return tc @@ -97,7 +109,7 @@ def test_save_with_token(self, config_dir, caplog): assert p.token == "my_token" assert p.url == "https://my.dev" assert "Setting token for profile 'default'" in caplog.text - assert "Token verified and saved to profile 'default'" in caplog.text + assert "auth accepted, profile 'default' saved to ->" in caplog.text def test_logs_updating_for_existing_profile(self, 
config_dir, caplog): with _mock_whoami(): @@ -124,7 +136,7 @@ def test_save_named_profile(self, config_dir, caplog): p = Config().resolve() assert p.token == "tok" assert p.name == "staging" - assert "Token verified and saved to profile 'staging'" in caplog.text + assert "auth accepted, profile 'staging' saved to ->" in caplog.text def test_save_jwt_stores_type(self, config_dir): with _mock_whoami(): @@ -193,11 +205,48 @@ def test_bad_token_does_not_save(self, config_dir): p = Config().resolve() assert p.token is None + def test_no_list_permission_warns_but_saves(self, config_dir, caplog): + args = _make_auth_args(token="tok") + with ( + caplog.at_level("WARNING"), + _mock_whoami(body=_whoami_body(permissions={"list": False})), + ): + rc = cmd_auth(args) + assert rc is None + assert "sandboxes are disabled" in caplog.text + assert "Warning" in caplog.text + assert Config().resolve().token == "tok" + + def test_no_list_permission_warning_includes_project(self, config_dir, caplog): + args = _make_iam_args(token="tok", project="aiproject-restricted") + with ( + caplog.at_level("WARNING"), + _mock_whoami(body=_whoami_body(permissions={"list": False})), + ): + cmd_auth(args) + assert "aiproject-restricted" in caplog.text + + def test_missing_permissions_field_warns(self, config_dir, caplog): + args = _make_auth_args(token="tok") + body = b'{"token_uuid":"x","token_expiration":null,"operations_stat":{}}' + with caplog.at_level("WARNING"), _mock_whoami(body=body): + rc = cmd_auth(args) + assert rc is None + assert "sandboxes are disabled" in caplog.text + assert Config().resolve().token == "tok" + + def test_unparseable_whoami_rejected(self, config_dir, caplog): + args = _make_auth_args(token="tok") + with caplog.at_level("ERROR"), _mock_whoami(body=b"not-json"): + rc = cmd_auth(args) + assert rc == 1 + assert Config().resolve().token is None + def test_success_logs_saved(self, config_dir, caplog): args = _make_auth_args(token="good") with caplog.at_level("INFO"), 
_mock_whoami(): cmd_auth(args) - assert "Token verified and saved" in caplog.text + assert "auth accepted" in caplog.text def test_whoami_called(self, config_dir): args = _make_auth_args(token="tok") @@ -251,6 +300,65 @@ def test_both_nebius_vars_skip_all_prompts(self, config_dir, caplog, monkeypatch assert p.project == "aiproject-auto" +class TestContreeEnvFallbacks: + def test_contree_token_used_when_token_omitted( + self, config_dir, caplog, monkeypatch + ): + monkeypatch.setenv("CONTREE_TOKEN", "ctok") + with caplog.at_level("INFO"), _mock_whoami(): + cmd_auth(AuthArgs(url="https://test.dev", auth_type=AuthType.JWT)) + p = Config().resolve() + assert p.token == "ctok" + assert "Using token from CONTREE_TOKEN" in caplog.text + + def test_contree_token_preferred_over_nebius_api_key( + self, config_dir, caplog, monkeypatch + ): + monkeypatch.setenv("CONTREE_TOKEN", "ctok") + monkeypatch.setenv("NEBIUS_API_KEY", "ntok") + with caplog.at_level("INFO"), _mock_whoami(): + cmd_auth(AuthArgs(url="https://test.dev", auth_type=AuthType.JWT)) + p = Config().resolve() + assert p.token == "ctok" + + def test_contree_url_used_when_url_omitted(self, config_dir, caplog, monkeypatch): + monkeypatch.setenv("CONTREE_URL", "https://env-url.dev") + monkeypatch.setenv("NEBIUS_API_KEY", "tok") + with caplog.at_level("INFO"), _mock_whoami(): + cmd_auth(AuthArgs(auth_type=AuthType.JWT)) + p = Config().resolve() + assert p.url == "https://env-url.dev" + + def test_contree_project_used_when_project_omitted( + self, config_dir, caplog, monkeypatch + ): + monkeypatch.setenv("CONTREE_PROJECT", "aiproject-c") + with caplog.at_level("INFO"), _mock_whoami(): + cmd_auth( + AuthArgs( + token="tok", + url="https://iam.test", + auth_type=AuthType.IAM, + ) + ) + p = Config().resolve() + assert p.project == "aiproject-c" + assert "Using project from CONTREE_PROJECT" in caplog.text + + def test_explicit_token_flag_beats_env(self, config_dir, monkeypatch): + monkeypatch.setenv("CONTREE_TOKEN", 
"from-env") + with _mock_whoami(): + cmd_auth( + AuthArgs( + token="from-flag", + url="https://test.dev", + auth_type=AuthType.JWT, + ) + ) + p = Config().resolve() + assert p.token == "from-flag" + + # --------------------------------------------------------------------------- # Switch # --------------------------------------------------------------------------- @@ -296,7 +404,7 @@ def flush(self) -> None: def fake_factory(profile, timeout=None): # type: ignore[no-untyped-def] tc = ContreeTestClient(token=profile.token) if profile.token == "tok-ok": - tc.respond(status=200, body=b'{"ok":true}') + tc.respond(status=200, body=_whoami_body()) elif profile.token == "tok-timeout": def timeout_get(path, params=None): # type: ignore[no-untyped-def] @@ -324,6 +432,39 @@ def error_get(path, params=None): # type: ignore[no-untyped-def] assert by_name["timeout"]["status"] == "timeout" assert by_name["error"]["status"] == "error" + def test_profiles_inactive_status(self, config_dir): + """Profile whose token lacks `list` permission is reported as inactive.""" + with _mock_whoami(): + cmd_auth(_make_auth_args(token="tok", profile="restricted")) + + rows: list[dict[str, object]] = [] + + class CaptureFormatter: + def __call__(self, **kwargs: object) -> None: + rows.append(kwargs) + + def flush(self) -> None: + return + + def fake_factory(profile, timeout=None): # type: ignore[no-untyped-def] + tc = ContreeTestClient(token=profile.token) + tc.respond( + status=200, + body=_whoami_body(permissions={"list": False, "spawn": True}), + ) + return tc + + FORMATTER.set(CaptureFormatter()) + ctx = copy_context() + with patch( + "contree_cli.cli.auth.client_from_profile", + side_effect=fake_factory, + ): + ctx.run(cmd_list, ProfilesArgs(offline=False)) + + by_name = {str(row["name"]): row for row in rows} + assert by_name["restricted"]["status"] == "inactive" + def test_profiles_offline_skips_probe(self, config_dir): with _mock_whoami(): cmd_auth(_make_auth_args(token="tok", 
profile="offline-test")) diff --git a/tests/test_client.py b/tests/test_client.py index 30ef44e..54f0041 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -14,6 +14,7 @@ BodyFormatter, ContreeClient, ContreeJWTClient, + HeaderFormatter, resolve_image, ) @@ -409,6 +410,77 @@ def test_octet_stream_request_body_not_dumped(self, caplog): assert "" in msgs + + def test_response_headers_logged(self, caplog): + self._enable_debug(caplog) + c = ContreeTestClient("https://contree.dev", "tok") + c.fake.responses.append( + FakeResponse( + status=200, + body=b"{}", + headers={"Content-Type": "application/json", "X-Trace-Id": "abc"}, + ) + ) + c.request("GET", "/v1/images") + msgs = "\n".join(r.getMessage() for r in caplog.records) + assert "X-Trace-Id" in msgs + assert "abc" in msgs + + def test_error_response_headers_logged(self, caplog): + self._enable_debug(caplog) + c = ContreeTestClient("https://contree.dev", "tok") + c.fake.responses.append( + FakeResponse( + status=400, + body=b"bad", + headers={"X-Trace-Id": "trace-err"}, + ) + ) + with pytest.raises(ApiError): + c.request("GET", "/v1/images") + msgs = "\n".join(r.getMessage() for r in caplog.records) + assert "trace-err" in msgs + + +class TestHeaderFormatter: + def test_redacts_authorization(self): + out = str(HeaderFormatter({"Authorization": "Bearer secret", "X-Foo": "bar"})) + assert "secret" not in out + assert "" in out + assert "bar" in out + + def test_redaction_is_case_insensitive(self): + out = str(HeaderFormatter({"AUTHORIZATION": "Bearer secret"})) + assert "secret" not in out + assert "" in out + + def test_accepts_list_of_tuples(self): + out = str( + HeaderFormatter( + [("Authorization", "Bearer secret"), ("X-Trace-Id", "abc")], + ) + ) + assert "secret" not in out + assert "abc" in out + + def test_redacts_cookie(self): + out = str(HeaderFormatter({"Cookie": "session=xyz"})) + assert "xyz" not in out + + def test_non_sensitive_passes_through(self): + out = 
str(HeaderFormatter({"User-Agent": "ua/1.0", "Project": "proj"})) + assert "ua/1.0" in out + assert "proj" in out + class TestBodyFormatter: def test_none(self): diff --git a/tests/test_config.py b/tests/test_config.py index f5e44f5..14d971a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -149,7 +149,7 @@ def test_env_profile_overrides_config(self, config_dir, monkeypatch): assert p.name == "staging" assert p.token == "tok2" - def test_env_token_overrides_config(self, config_dir, monkeypatch): + def test_env_token_does_not_override_config(self, config_dir, monkeypatch): cfg = Config() cfg["default"] = ConfigProfile( name="default", @@ -158,9 +158,9 @@ def test_env_token_overrides_config(self, config_dir, monkeypatch): ) monkeypatch.setenv("CONTREE_TOKEN", "env_token") p = Config().resolve() - assert p.token == "env_token" + assert p.token == "cfg_token" - def test_env_url_overrides_config(self, config_dir, monkeypatch): + def test_env_url_does_not_override_config(self, config_dir, monkeypatch): cfg = Config() cfg["default"] = ConfigProfile( name="default", @@ -169,7 +169,7 @@ def test_env_url_overrides_config(self, config_dir, monkeypatch): ) monkeypatch.setenv("CONTREE_URL", "https://env.dev") p = Config().resolve() - assert p.url == "https://env.dev" + assert p.url == "https://custom.dev" def test_url_falls_back_for_jwt_when_missing(self, config_dir): """JWT profile with url key removed falls back to empty string.""" @@ -268,7 +268,7 @@ def test_project_none_when_not_set(self, config_dir): p = Config().resolve() assert p.project is None - def test_env_project_overrides_config(self, config_dir, monkeypatch): + def test_env_project_does_not_override_config(self, config_dir, monkeypatch): cfg = Config() cfg["default"] = ConfigProfile( name="default", @@ -279,7 +279,7 @@ def test_env_project_overrides_config(self, config_dir, monkeypatch): ) monkeypatch.setenv("CONTREE_PROJECT", "aiproject-env") p = Config().resolve() - assert p.project == 
"aiproject-env" + assert p.project == "aiproject-cfg" def test_save_clears_project_when_none(self, config_dir): cfg = Config() diff --git a/tests/test_cp.py b/tests/test_cp.py index 83a5c46..b050810 100644 --- a/tests/test_cp.py +++ b/tests/test_cp.py @@ -38,6 +38,11 @@ def getheader(self, name: str, default: str | None = None) -> str | None: return str(self._length) return default + def getheaders(self) -> list[tuple[str, str]]: + if self._length is not None: + return [("Content-Length", str(self._length))] + return [] + def _run_cmd( tc: ContreeTestClient, diff --git a/tests/test_file_cmd.py b/tests/test_file_cmd.py index 332e152..502f6c6 100644 --- a/tests/test_file_cmd.py +++ b/tests/test_file_cmd.py @@ -36,6 +36,9 @@ def read(self, amt: int | None = None) -> bytes: def getheader(self, name: str, default: str | None = None) -> str | None: return default + def getheaders(self) -> list[tuple[str, str]]: + return [] + def _api_response(body: bytes | dict, *, status: int = 200) -> StreamResponse: data = json.dumps(body).encode() if isinstance(body, dict) else body diff --git a/tests/test_images.py b/tests/test_images.py index df838a1..99dc258 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -9,6 +9,7 @@ from contree_cli import CLIENT, FORMATTER from contree_cli.cli.images import ( + LIMIT_DEFAULT, PAGE_SIZE, ImagesArgs, ImportArgs, @@ -168,6 +169,115 @@ def test_all_images_emitted(self, contree_client, capsys): out = capsys.readouterr().out assert out.count("uuid-") == PAGE_SIZE + 5 + def test_progress_logged_per_full_page(self, contree_client, caplog): + """Each completed full page emits a progress line at INFO level.""" + import logging + + page1 = [_make_image(i) for i in range(PAGE_SIZE)] + page2 = [_make_image(i) for i in range(PAGE_SIZE, PAGE_SIZE * 2)] + page3 = [_make_image(i) for i in range(PAGE_SIZE * 2, PAGE_SIZE * 2 + 3)] + with caplog.at_level(logging.INFO, logger="contree_cli.cli.images"): + _run_cmd_pages(contree_client, [page1, page2, 
page3]) + msgs = [r.getMessage() for r in caplog.records] + assert any( + f"Fetched {PAGE_SIZE} images so far" in m and "Ctrl+C" in m for m in msgs + ) + assert any( + f"Fetched {PAGE_SIZE * 2} images so far" in m and "Ctrl+C" in m + for m in msgs + ) + assert not any(f"{PAGE_SIZE * 2 + 3}" in m for m in msgs) + + def test_default_limit_is_2000(self): + assert LIMIT_DEFAULT == 2000 + assert ImagesArgs().limit == LIMIT_DEFAULT + + def test_limit_truncates_with_warning(self, contree_client, caplog): + """Hitting --limit triggers a probe; non-empty probe -> warning.""" + import logging + + page1 = [_make_image(i) for i in range(PAGE_SIZE)] + contree_client.respond_json({"images": page1}) + contree_client.respond_json({"images": [_make_image(PAGE_SIZE)]}) + + FORMATTER.set(CSVFormatter()) + ctx = copy_context() + with caplog.at_level(logging.WARNING, logger="contree_cli.cli.images"): + ctx.run(cmd_images, ImagesArgs(limit=PAGE_SIZE)) + msgs = [r.getMessage() for r in caplog.records if r.levelname == "WARNING"] + assert any("truncated" in m and f"--limit={PAGE_SIZE}" in m for m in msgs) + assert contree_client.request_count == 2 + + def test_limit_probe_uses_skip_of_one(self, contree_client): + """Probe is a single-record request, not a full page.""" + page1 = [_make_image(i) for i in range(PAGE_SIZE)] + contree_client.respond_json({"images": page1}) + contree_client.respond_json({"images": []}) + + FORMATTER.set(CSVFormatter()) + ctx = copy_context() + ctx.run(cmd_images, ImagesArgs(limit=PAGE_SIZE)) + + probe_path = contree_client.request_paths[1] + assert "limit=1" in probe_path + assert f"offset={PAGE_SIZE}" in probe_path + + def test_limit_warning_after_table_flush(self, contree_client, caplog, capsys): + """TableFormatter buffer is flushed before the truncation warning.""" + import logging + + page1 = [_make_image(i) for i in range(PAGE_SIZE)] + contree_client.respond_json({"images": page1}) + contree_client.respond_json({"images": [_make_image(PAGE_SIZE)]}) + + 
FORMATTER.set(TableFormatter()) + ctx = copy_context() + with caplog.at_level(logging.WARNING, logger="contree_cli.cli.images"): + ctx.run(cmd_images, ImagesArgs(limit=PAGE_SIZE)) + + out = capsys.readouterr().out + # Table content must be printed (i.e. flushed) before the handler + # logs the warning. Verify the table is on stdout already. + assert "uuid-0" in out + assert f"uuid-{PAGE_SIZE - 1}" in out + + def test_limit_no_warning_when_no_more(self, contree_client, caplog): + """Empty probe response -> no warning.""" + import logging + + page1 = [_make_image(i) for i in range(PAGE_SIZE)] + contree_client.respond_json({"images": page1}) + contree_client.respond_json({"images": []}) + + FORMATTER.set(CSVFormatter()) + ctx = copy_context() + with caplog.at_level(logging.WARNING, logger="contree_cli.cli.images"): + ctx.run(cmd_images, ImagesArgs(limit=PAGE_SIZE)) + warns = [r for r in caplog.records if r.levelname == "WARNING"] + assert not any("truncated" in r.getMessage() for r in warns) + + def test_limit_request_uses_capped_page_size(self, contree_client): + """When --limit < PAGE_SIZE, the API request asks for limit items only.""" + contree_client.respond_json({"images": [_make_image(i) for i in range(3)]}) + contree_client.respond_json({"images": []}) # probe + + FORMATTER.set(CSVFormatter()) + ctx = copy_context() + ctx.run(cmd_images, ImagesArgs(limit=3)) + + assert "limit=3" in contree_client.request_paths[0] + assert "limit=1" in contree_client.request_paths[1] + assert contree_client.request_count == 2 + + def test_progress_not_logged_for_single_short_page(self, contree_client, caplog): + """Final/only partial page does not emit progress (output covers it).""" + import logging + + images = [_make_image(i) for i in range(5)] + with caplog.at_level(logging.INFO, logger="contree_cli.cli.images"): + _run_cmd(contree_client, images) + assert not any("images so far" in r.getMessage() for r in caplog.records) + class TestImagesCreatedAtFormats: """Verify 
created_at parsing with various ISO 8601 formats from the API.""" diff --git a/tests/test_ps.py b/tests/test_ps.py index 4d28204..b838c6f 100644 --- a/tests/test_ps.py +++ b/tests/test_ps.py @@ -135,11 +135,29 @@ def test_multi_page(self, contree_client, capsys): def test_offset_increments(self, contree_client): page1 = [_make_op(i) for i in range(PAGE_SIZE)] page2 = [] - _run_cmd_pages(contree_client, [page1, page2]) + _run_cmd_pages(contree_client, [page1, page2], show_max=None) paths = contree_client.request_paths assert "offset=0" in paths[0] assert f"offset={PAGE_SIZE}" in paths[1] + def test_progress_logged_per_full_page(self, contree_client, caplog): + """Each completed full page emits a progress line at INFO level.""" + import logging + + page1 = [_make_op(i) for i in range(PAGE_SIZE)] + page2 = [_make_op(i) for i in range(PAGE_SIZE, PAGE_SIZE + 3)] + with caplog.at_level(logging.INFO, logger="contree_cli.cli.ps"): + _run_cmd_pages( + contree_client, + [page1, page2], + show_max=None, + ) + msgs = [r.getMessage() for r in caplog.records] + assert any( + f"Fetched {PAGE_SIZE} operations so far" in m and "Ctrl+C" in m + for m in msgs + ) + class TestPsActiveFilter: def test_default_sends_executing_status_to_server(self, contree_client, capsys): @@ -206,17 +224,29 @@ def test_status_shortcut_expansion( class TestPsShowMax: def test_show_max_truncates_output(self, contree_client, capsys): - ops = [_make_op(i) for i in range(5)] - _run_cmd(contree_client, ops, show_max=3, all=True) + """show_max caps emitted ops; probe runs after for more.""" + page = [_make_op(i) for i in range(5)] + _run_cmd_pages( + contree_client, + [page, [_make_op(99)]], # main + probe + show_max=3, + all=True, + ) out = capsys.readouterr().out assert "op-0" in out assert "op-1" in out - assert "op-2" not in out + assert "op-2" in out + assert "op-3" not in out def test_show_max_logs_warning(self, contree_client, caplog): - ops = [_make_op(i) for i in range(5)] - _run_cmd(contree_client, 
ops, show_max=3, all=True) - assert "show_max limit of 3" in caplog.text + page = [_make_op(i) for i in range(5)] + _run_cmd_pages( + contree_client, + [page, [_make_op(99)]], # probe finds more + show_max=3, + all=True, + ) + assert "Output truncated at --show-max=3" in caplog.text def test_show_max_none_shows_all(self, contree_client, capsys): ops = [_make_op(i) for i in range(5)] @@ -239,13 +269,18 @@ def test_show_max_no_warning_when_under_limit( ): ops = [_make_op(i) for i in range(3)] _run_cmd(contree_client, ops, show_max=100, all=True) - assert "show_max" not in caplog.text + assert "Output truncated" not in caplog.text def test_show_max_stops_pagination(self, contree_client, capsys): - """show_max stops iteration mid-page, no extra page fetch.""" + """show_max stops mid-page; one probe request follows.""" ops = [_make_op(i) for i in range(10)] - _run_cmd(contree_client, ops, show_max=3, all=True) - assert contree_client.request_count == 1 + _run_cmd_pages( + contree_client, + [ops, [_make_op(99)]], # main + probe + show_max=3, + all=True, + ) + assert contree_client.request_count == 2 def test_show_max_across_pages(self, contree_client, capsys): """show_max truncates across page boundaries.""" @@ -253,20 +288,67 @@ def test_show_max_across_pages(self, contree_client, capsys): page2 = [_make_op(i) for i in range(PAGE_SIZE, PAGE_SIZE + 5)] _run_cmd_pages( contree_client, - [page1, page2], + [page1, page2, [_make_op(99)]], # main pages + probe show_max=PAGE_SIZE + 2, all=True, ) out = capsys.readouterr().out - assert f"op-{PAGE_SIZE}" in out + assert f"op-{PAGE_SIZE + 1}" in out assert f"op-{PAGE_SIZE + 2}" not in out - def test_show_max_one_shows_nothing(self, contree_client, capsys): - """show_max=1 yields 0 ops (counter starts at 1, 1>=1 is true).""" - ops = [_make_op(0)] - _run_cmd(contree_client, ops, show_max=1, all=True) + def test_show_max_one_shows_one(self, contree_client, capsys): + """show_max=1 emits exactly one op (no off-by-one).""" + ops = 
[_make_op(0), _make_op(1)] + _run_cmd_pages( + contree_client, + [ops, [_make_op(99)]], # main + probe + show_max=1, + all=True, + ) out = capsys.readouterr().out - assert "op-0" not in out + assert "op-0" in out + assert "op-1" not in out + + def test_show_max_probe_uses_skip_of_one(self, contree_client): + """Probe is a single-record request after the cap.""" + page = [_make_op(i) for i in range(5)] + _run_cmd_pages( + contree_client, + [page, []], + show_max=3, + all=True, + ) + probe_path = contree_client.request_paths[1] + assert "limit=1" in probe_path + assert "offset=3" in probe_path + + def test_show_max_no_warning_when_probe_empty(self, contree_client, caplog): + """Empty probe means we hit show_max but there's nothing more.""" + page = [_make_op(i) for i in range(3)] + _run_cmd_pages( + contree_client, + [page, []], # probe empty + show_max=3, + all=True, + ) + assert "Output truncated" not in caplog.text + + def test_show_max_warning_after_table_flush(self, contree_client, caplog, capsys): + """TableFormatter buffer is flushed before the warning is logged.""" + import logging + + page = [_make_op(i) for i in range(5)] + for response in (page, [_make_op(99)]): + contree_client.respond_json(response) + + FORMATTER.set(TableFormatter()) + ctx = copy_context() + with caplog.at_level(logging.WARNING, logger="contree_cli.cli.ps"): + ctx.run(cmd_ps, PsArgs(show_max=3, all=True)) + + out = capsys.readouterr().out + assert "op-0" in out + assert "op-2" in out class TestPsCreatedAtFormats: From fb07149c857f5db4382800c5fbadedab913bc8d5 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 14:47:35 +0200 Subject: [PATCH 2/8] Add version check --- contree_cli/__main__.py | 8 ++ contree_cli/arguments.py | 13 +-- contree_cli/client.py | 4 +- contree_cli/update_check.py | 174 ++++++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 8 deletions(-) create mode 100644 contree_cli/update_check.py diff --git a/contree_cli/__main__.py 
b/contree_cli/__main__.py index 223d4e9..0c8795d 100644 --- a/contree_cli/__main__.py +++ b/contree_cli/__main__.py @@ -2,6 +2,7 @@ import logging import sys from collections.abc import Callable +from contextlib import suppress from dataclasses import replace import contree_cli.config as config_mod @@ -12,11 +13,16 @@ from contree_cli.log import setup_logging from contree_cli.output import FORMATTERS from contree_cli.session import SessionStore, get_session_key +from contree_cli.update_check import UpdateChecker log = logging.getLogger(__name__) def main() -> None: + checker = UpdateChecker() + with suppress(Exception): + checker.refresh() + if len(sys.argv) == 1: parser.print_help() exit(0) @@ -27,6 +33,8 @@ def main() -> None: args = parser.parse_args() setup_logging(level=getattr(logging, args.log_level.upper(), logging.INFO)) + checker.check() + config_mod.CONFIG_FILE = args.config_path config_mod.CONFIG_DIR = args.config_path.parent diff --git a/contree_cli/arguments.py b/contree_cli/arguments.py index 9dfb750..1111aea 100644 --- a/contree_cli/arguments.py +++ b/contree_cli/arguments.py @@ -72,12 +72,13 @@ contree tag UUID TAG | kill UUID | cd PATH | session checkout BRANCH environment variables: - CONTREE_PROFILE Active config profile (selects which profile to use) - CONTREE_SESSION Explicit session name (for multi-terminal workflows). - If unset, contree auto-generates +<8hex> (derived from - profile+ppid+tty); export your own for stable reuse. - You can also pass -S/--session instead of exporting env. - CONTREE_SESSION_DB Path to session SQLite database + CONTREE_PROFILE Active config profile (selects which profile to use) + CONTREE_SESSION Explicit session name (for multi-terminal workflows). + If unset, contree auto-generates +<8hex> (derived + from profile+ppid+tty); export your own for stable + reuse. You can also pass -S/--session instead. 
+ CONTREE_SESSION_DB Path to session SQLite database + CONTREE_NO_UPDATE_CHECK Set to any value to disable PyPI update checks registration-time fallbacks (only read by `contree auth`, not at runtime): CONTREE_TOKEN / NEBIUS_API_KEY Token used when --token is omitted diff --git a/contree_cli/client.py b/contree_cli/client.py index ef9487b..c74d841 100644 --- a/contree_cli/client.py +++ b/contree_cli/client.py @@ -21,7 +21,7 @@ RETRY_DELAYS = (1, 2, 4, 5, 10, 10, 10) -def _cli_version() -> str: +def cli_version() -> str: try: return version("contree-cli") except PackageNotFoundError: @@ -29,7 +29,7 @@ def _cli_version() -> str: CLI_USER_AGENT = ( - f"contree-cli/{_cli_version()} " + f"contree-cli/{cli_version()} " f"Python/{'.'.join(map(str, sys.version_info))} " f"{platform.platform()} " ) diff --git a/contree_cli/update_check.py b/contree_cli/update_check.py new file mode 100644 index 0000000..2d9f0cf --- /dev/null +++ b/contree_cli/update_check.py @@ -0,0 +1,174 @@ +"""PyPI update check, rate-limited to once per day. + +State file at ``$CONTREE_HOME/cli/version_check.json``:: + + { + "last_check": "2026-05-08T12:00:00+00:00", + "latest_version": "0.5.0", + "current_version": "0.4.2" + } + +Network errors, malformed cache files, and parse failures are swallowed: +the update check must never break a user's command. +""" + +from __future__ import annotations + +import json +import logging +import os +import urllib.request +from contextlib import suppress +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from contree_cli import config +from contree_cli.client import CLI_USER_AGENT, cli_version + +log = logging.getLogger(__name__) + + +class UpdateChecker: + """Encapsulates the state file + PyPI probe + outdated-version warning. + + All side effects (filesystem, network, logging) are guarded so that a + failure in update-checking can never break the user's command. 
+ """ + + PYPI_URL = "https://pypi.org/pypi/contree-cli/json" + CHECK_INTERVAL = timedelta(days=1) + NETWORK_TIMEOUT = 2.0 + OPT_OUT_ENV = "CONTREE_NO_UPDATE_CHECK" + STATE_PATH = config.CONTREE_HOME / "cli" / "version_check.json" + + def __init__( + self, + *, + state_path: Path = STATE_PATH, + current_version: str = cli_version(), + ) -> None: + self.state_path = state_path + self.current_version = current_version + + def read_state(self) -> dict[str, str]: + if not self.state_path.exists(): + return {} + try: + data = json.loads(self.state_path.read_text()) + except (json.JSONDecodeError, OSError): + return {} + return data if isinstance(data, dict) else {} + + def write_state(self, state: dict[str, str]) -> None: + with suppress(OSError): + self.state_path.parent.mkdir(parents=True, exist_ok=True) + self.state_path.write_text(json.dumps(state, indent=2) + "\n") + + @staticmethod + def parse_version(value: str) -> tuple[int, ...]: + """Parse a PEP-440-ish version into a comparable tuple of ints. + + Strips non-numeric suffixes per component (e.g. ``0a1`` -> ``0``) + so pre-releases compare as their numeric prefix. Good enough to + decide "is X newer than Y" for normal releases. 
+ """ + parts: list[int] = [] + for component in value.split("."): + digits = "" + for ch in component: + if ch.isdigit(): + digits += ch + else: + break + if digits: + parts.append(int(digits)) + return tuple(parts) + + def fetch_latest_version(self) -> str | None: + try: + request = urllib.request.Request( + self.PYPI_URL, + headers={ + "User-Agent": CLI_USER_AGENT, + "Accept": "application/json", + }, + ) + with urllib.request.urlopen( # nosemgrep + request, timeout=self.NETWORK_TIMEOUT + ) as resp: + payload = json.loads(resp.read()) + except Exception: + return None + info = payload.get("info") if isinstance(payload, dict) else None + if not isinstance(info, dict): + return None + version = info.get("version") + return version if isinstance(version, str) else None + + def warn_if_outdated(self, latest: str) -> None: + if self.parse_version(latest) > self.parse_version(self.current_version): + log.warning( + "A new version of contree-cli is available: %s (installed: %s)." + " Upgrade with `uv tool install -U contree-cli` or" + " `pip install -U contree-cli`.", + latest, + self.current_version, + ) + + @property + def enabled(self) -> bool: + return ( + not os.environ.get(self.OPT_OUT_ENV) and self.current_version != "editable" + ) + + def is_state_fresh(self) -> bool: + """True if the cache file's mtime is within ``CHECK_INTERVAL``. + + Uses ``Path.stat()`` (works on Windows and POSIX) to avoid + reading/parsing the JSON on the hot path. Missing files, + permission errors, or any OSError is treated as "not fresh". + """ + try: + mtime = self.state_path.stat().st_mtime + except OSError: + return False + age_seconds = datetime.now(timezone.utc).timestamp() - mtime + return age_seconds < self.CHECK_INTERVAL.total_seconds() + + def refresh(self) -> None: + """Refresh cached PyPI state, rate-limited to ``CHECK_INTERVAL``. + + Decides freshness from the cache file's mtime to avoid the + JSON read+parse on the common case. 
If stale, probes PyPI and + rewrites ``state_path``. + """ + if not self.enabled: + return + if self.is_state_fresh(): + return + + latest = self.fetch_latest_version() + if latest is None: + return + + self.write_state( + { + "last_check": datetime.now(timezone.utc).isoformat(), + "latest_version": latest, + "current_version": self.current_version, + } + ) + + def check(self) -> None: + """Log a warning if cached ``latest_version`` is newer than current. + + Pure read; never touches the network or rewrites state. Pair + with :meth:`refresh` to first ensure the cache is up to date. + """ + if not self.enabled: + return + + state = self.read_state() + cached = state.get("latest_version") + if isinstance(cached, str): + self.warn_if_outdated(cached) From cdac189a76cb868c2ed33dd06c9a494b56a479ca Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 14:47:48 +0200 Subject: [PATCH 3/8] Add version check tests --- tests/test_update_check.py | 351 +++++++++++++++++++++++++++++++++++++ 1 file changed, 351 insertions(+) create mode 100644 tests/test_update_check.py diff --git a/tests/test_update_check.py b/tests/test_update_check.py new file mode 100644 index 0000000..dcdb58b --- /dev/null +++ b/tests/test_update_check.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import json +import logging +import os +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +import pytest + +from contree_cli import update_check +from contree_cli.update_check import UpdateChecker + + +def read_json(path): + return json.loads(path.read_text()) + + +def seed_state(path, payload, *, mtime: float | None = None): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload)) + if mtime is not None: + try: + os.utime(path, (mtime, mtime)) + except (OSError, NotImplementedError) as exc: + pytest.skip(f"os.utime is not supported on this platform: {exc}") + + +def fake_now(offset: timedelta): + """Patch 
update_check.datetime.now to return real-now + offset. + + Use this when tests need to simulate clock advancement on platforms + where ``os.utime`` may not work (e.g. some Windows configurations). + """ + pinned = datetime.now(timezone.utc) + offset + + class FakeDatetime(datetime): + @classmethod + def now(cls, tz=None): + return pinned if tz is not None else pinned.replace(tzinfo=None) + + return patch.object(update_check, "datetime", FakeDatetime) + + +@pytest.fixture() +def state_path(tmp_path): + return tmp_path / "version_check.json" + + +class TestParseVersion: + @pytest.mark.parametrize( + "value,expected", + [ + ("1.2.3", (1, 2, 3)), + ("0.0.1", (0, 0, 1)), + ("0.4.2a1", (0, 4, 2)), + ("1", (1,)), + ("", ()), + ("1.x.3", (1, 3)), + ], + ) + def test_cases(self, value, expected): + assert UpdateChecker.parse_version(value) == expected + + +class TestEnabled: + def test_disabled_in_editable_mode(self, state_path): + checker = UpdateChecker(state_path=state_path, current_version="editable") + assert checker.enabled is False + + def test_disabled_when_opt_out_env_set(self, state_path, monkeypatch): + monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + assert checker.enabled is False + + def test_enabled_normal(self, state_path): + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + assert checker.enabled is True + + +class TestRefresh: + def test_skips_in_editable_mode(self, state_path): + checker = UpdateChecker(state_path=state_path, current_version="editable") + with patch.object(checker, "fetch_latest_version") as fetch: + checker.refresh() + fetch.assert_not_called() + + def test_skips_when_opt_out_env_set(self, state_path, monkeypatch): + monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object(checker, "fetch_latest_version") as fetch: + checker.refresh() + 
fetch.assert_not_called() + + def test_fetches_and_writes_when_no_cache(self, state_path): + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object(checker, "fetch_latest_version", return_value="0.4.1"): + checker.refresh() + data = read_json(state_path) + assert data["latest_version"] == "0.4.1" + assert data["current_version"] == "0.4.0" + assert "last_check" in data + + def test_skips_network_within_interval(self, state_path): + recent = datetime.now(timezone.utc) - timedelta(hours=1) + seed_state( + state_path, + { + "last_check": recent.isoformat(), + "latest_version": "0.4.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object(checker, "fetch_latest_version") as fetch: + checker.refresh() + fetch.assert_not_called() + + def test_refetches_after_interval_expires(self, state_path): + old_ts = (datetime.now(timezone.utc) - timedelta(days=2)).timestamp() + seed_state( + state_path, + { + "last_check": "2026-01-01T00:00:00+00:00", + "latest_version": "0.4.0", + "current_version": "0.4.0", + }, + mtime=old_ts, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object( + checker, "fetch_latest_version", return_value="0.4.5" + ) as fetch: + checker.refresh() + fetch.assert_called_once() + assert read_json(state_path)["latest_version"] == "0.4.5" + + def test_swallows_network_failure(self, state_path): + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object(checker, "fetch_latest_version", return_value=None): + checker.refresh() + assert not state_path.exists() + + def test_corrupt_cache_file_is_recoverable_after_interval(self, state_path): + """Corrupt JSON is silently overwritten once mtime expires.""" + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text("{not json") + old_ts = (datetime.now(timezone.utc) - timedelta(days=2)).timestamp() + 
os.utime(state_path, (old_ts, old_ts)) + + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object( + checker, "fetch_latest_version", return_value="0.4.1" + ) as fetch: + checker.refresh() + fetch.assert_called_once() + assert read_json(state_path)["latest_version"] == "0.4.1" + + def test_refresh_does_not_log_warning(self, state_path, caplog): + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with ( + caplog.at_level(logging.WARNING, logger="contree_cli.update_check"), + patch.object(checker, "fetch_latest_version", return_value="0.5.0"), + ): + checker.refresh() + assert "available" not in caplog.text + + +class TestRefreshClockMock: + """Same logic as TestRefresh, but driven by mocked ``datetime.now``. + + Works on every platform — including Windows configurations where + ``os.utime`` may silently no-op or raise — because nothing depends + on the filesystem's mtime resolution or write permissions. + """ + + def test_refetches_after_simulated_interval(self, state_path): + seed_state( + state_path, + { + "last_check": "2026-01-01T00:00:00+00:00", + "latest_version": "0.4.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with ( + fake_now(timedelta(days=2)), + patch.object( + checker, "fetch_latest_version", return_value="0.4.5" + ) as fetch, + ): + checker.refresh() + fetch.assert_called_once() + assert read_json(state_path)["latest_version"] == "0.4.5" + + def test_skips_within_simulated_interval(self, state_path): + seed_state( + state_path, + { + "last_check": "x", + "latest_version": "0.4.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with ( + fake_now(timedelta(hours=1)), + patch.object(checker, "fetch_latest_version") as fetch, + ): + checker.refresh() + fetch.assert_not_called() + + +class TestIsStateFresh: + def 
test_returns_true_for_freshly_written_file(self, state_path): + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text("{}") + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + assert checker.is_state_fresh() is True + + def test_returns_false_when_missing(self, state_path): + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + assert checker.is_state_fresh() is False + + def test_returns_false_when_clock_advanced_past_interval(self, state_path): + """Cross-platform: simulate 2-day clock advance, no os.utime.""" + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text("{}") + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with fake_now(timedelta(days=2)): + assert checker.is_state_fresh() is False + + def test_returns_true_when_clock_advanced_within_interval(self, state_path): + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text("{}") + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with fake_now(timedelta(hours=23)): + assert checker.is_state_fresh() is True + + +class TestCheck: + def test_skips_in_editable_mode(self, state_path, caplog): + seed_state( + state_path, + { + "last_check": datetime.now(timezone.utc).isoformat(), + "latest_version": "9.9.9", + "current_version": "editable", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="editable") + with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): + checker.check() + assert "available" not in caplog.text + + def test_skips_when_opt_out_env_set(self, state_path, caplog, monkeypatch): + monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") + seed_state( + state_path, + { + "last_check": datetime.now(timezone.utc).isoformat(), + "latest_version": "0.5.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with caplog.at_level(logging.WARNING, 
logger="contree_cli.update_check"): + checker.check() + assert "available" not in caplog.text + + def test_warns_when_cache_indicates_outdated(self, state_path, caplog): + seed_state( + state_path, + { + "last_check": datetime.now(timezone.utc).isoformat(), + "latest_version": "0.5.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): + checker.check() + assert "0.5.0" in caplog.text + assert "0.4.0" in caplog.text + + def test_no_warning_when_up_to_date(self, state_path, caplog): + seed_state( + state_path, + { + "last_check": datetime.now(timezone.utc).isoformat(), + "latest_version": "0.5.0", + "current_version": "0.5.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.5.0") + with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): + checker.check() + assert "available" not in caplog.text + + def test_no_state_file_silently_returns(self, state_path, caplog): + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): + checker.check() + assert "available" not in caplog.text + + def test_check_does_not_touch_network(self, state_path): + seed_state( + state_path, + { + "last_check": datetime.now(timezone.utc).isoformat(), + "latest_version": "0.5.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object(checker, "fetch_latest_version") as fetch: + checker.check() + fetch.assert_not_called() + + +class TestFetchLatestVersion: + @staticmethod + def fake_response(body: bytes): + class FakeResponse: + def read(self): + return body + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + return FakeResponse() + + def test_returns_version_on_success(self, tmp_path): + checker = 
UpdateChecker(state_path=tmp_path / "v.json", current_version="0") + body = json.dumps({"info": {"version": "1.2.3"}}).encode() + with patch("urllib.request.urlopen", return_value=self.fake_response(body)): + assert checker.fetch_latest_version() == "1.2.3" + + def test_returns_none_on_exception(self, tmp_path): + checker = UpdateChecker(state_path=tmp_path / "v.json", current_version="0") + with patch("urllib.request.urlopen", side_effect=OSError("boom")): + assert checker.fetch_latest_version() is None + + def test_returns_none_on_unexpected_payload(self, tmp_path): + checker = UpdateChecker(state_path=tmp_path / "v.json", current_version="0") + with patch("urllib.request.urlopen", return_value=self.fake_response(b"[]")): + assert checker.fetch_latest_version() is None From d6e0e7a35dac9ccf13d1f452ac3daac02ea13017 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 15:04:28 +0200 Subject: [PATCH 4/8] simplify checker --- contree_cli/__main__.py | 9 +- contree_cli/update_check.py | 112 +++++++-------- tests/test_update_check.py | 265 +++++++++++++++--------------------- 3 files changed, 165 insertions(+), 221 deletions(-) diff --git a/contree_cli/__main__.py b/contree_cli/__main__.py index 0c8795d..8d6e741 100644 --- a/contree_cli/__main__.py +++ b/contree_cli/__main__.py @@ -33,7 +33,14 @@ def main() -> None: args = parser.parse_args() setup_logging(level=getattr(logging, args.log_level.upper(), logging.INFO)) - checker.check() + if not checker.is_latest(): + log.warning( + "A new version of contree-cli is available: %s (installed: %s)." 
+ " Upgrade with `uv tool install -U contree-cli` or" + " `pip install -U contree-cli`.", + checker.latest_version, + checker.current_version, + ) config_mod.CONFIG_FILE = args.config_path config_mod.CONFIG_DIR = args.config_path.parent diff --git a/contree_cli/update_check.py b/contree_cli/update_check.py index 2d9f0cf..137fbdc 100644 --- a/contree_cli/update_check.py +++ b/contree_cli/update_check.py @@ -15,8 +15,8 @@ from __future__ import annotations import json -import logging import os +import re import urllib.request from contextlib import suppress from datetime import datetime, timedelta, timezone @@ -25,8 +25,6 @@ from contree_cli import config from contree_cli.client import CLI_USER_AGENT, cli_version -log = logging.getLogger(__name__) - class UpdateChecker: """Encapsulates the state file + PyPI probe + outdated-version warning. @@ -40,6 +38,7 @@ class UpdateChecker: NETWORK_TIMEOUT = 2.0 OPT_OUT_ENV = "CONTREE_NO_UPDATE_CHECK" STATE_PATH = config.CONTREE_HOME / "cli" / "version_check.json" + VERSION_REGEX = re.compile(r"[^\d.]") def __init__( self, @@ -49,40 +48,26 @@ def __init__( ) -> None: self.state_path = state_path self.current_version = current_version + self.latest_version: str | None = None def read_state(self) -> dict[str, str]: - if not self.state_path.exists(): - return {} try: - data = json.loads(self.state_path.read_text()) - except (json.JSONDecodeError, OSError): + with self.state_path.open() as f: + data = json.load(f) + assert isinstance(data, dict) + except Exception: return {} - return data if isinstance(data, dict) else {} + return data def write_state(self, state: dict[str, str]) -> None: with suppress(OSError): self.state_path.parent.mkdir(parents=True, exist_ok=True) - self.state_path.write_text(json.dumps(state, indent=2) + "\n") + self.state_path.write_text(json.dumps(state, indent=1)) - @staticmethod - def parse_version(value: str) -> tuple[int, ...]: - """Parse a PEP-440-ish version into a comparable tuple of ints. 
- - Strips non-numeric suffixes per component (e.g. ``0a1`` -> ``0``) - so pre-releases compare as their numeric prefix. Good enough to - decide "is X newer than Y" for normal releases. - """ - parts: list[int] = [] - for component in value.split("."): - digits = "" - for ch in component: - if ch.isdigit(): - digits += ch - else: - break - if digits: - parts.append(int(digits)) - return tuple(parts) + def parse_version(self, value: str) -> tuple[int, ...]: + return tuple( + map(int, filter(None, self.VERSION_REGEX.sub("", value).split("."))) + ) def fetch_latest_version(self) -> str | None: try: @@ -105,52 +90,48 @@ def fetch_latest_version(self) -> str | None: version = info.get("version") return version if isinstance(version, str) else None - def warn_if_outdated(self, latest: str) -> None: - if self.parse_version(latest) > self.parse_version(self.current_version): - log.warning( - "A new version of contree-cli is available: %s (installed: %s)." - " Upgrade with `uv tool install -U contree-cli` or" - " `pip install -U contree-cli`.", - latest, - self.current_version, - ) - @property def enabled(self) -> bool: return ( not os.environ.get(self.OPT_OUT_ENV) and self.current_version != "editable" ) - def is_state_fresh(self) -> bool: - """True if the cache file's mtime is within ``CHECK_INTERVAL``. - - Uses ``Path.stat()`` (works on Windows and POSIX) to avoid - reading/parsing the JSON on the hot path. Missing files, - permission errors, or any OSError is treated as "not fresh". 
- """ + def is_cache_fresh(self, state: dict[str, str]) -> bool: + """True if ``state['last_check']`` is within ``CHECK_INTERVAL``.""" + last_check_str = state.get("last_check") + if not isinstance(last_check_str, str): + return False try: - mtime = self.state_path.stat().st_mtime - except OSError: + last_check = datetime.fromisoformat(last_check_str) + except ValueError: return False - age_seconds = datetime.now(timezone.utc).timestamp() - mtime - return age_seconds < self.CHECK_INTERVAL.total_seconds() + return datetime.now(timezone.utc) - last_check < self.CHECK_INTERVAL def refresh(self) -> None: - """Refresh cached PyPI state, rate-limited to ``CHECK_INTERVAL``. + """Read the cache once, refetch from PyPI if stale. - Decides freshness from the cache file's mtime to avoid the - JSON read+parse on the common case. If stale, probes PyPI and - rewrites ``state_path``. + Populates ``self.latest_version`` with whatever we know after + this call (cached value, freshly fetched value, or ``None``). + :meth:`check` then decides whether to log based purely on + in-memory state — no further file IO. """ if not self.enabled: return - if self.is_state_fresh(): + + state = self.read_state() + cached = state.get("latest_version") + if isinstance(cached, str): + self.latest_version = cached + + if self.is_cache_fresh(state): return latest = self.fetch_latest_version() if latest is None: + # Network failed; keep whatever was cached. return + self.latest_version = latest self.write_state( { "last_check": datetime.now(timezone.utc).isoformat(), @@ -159,16 +140,17 @@ def refresh(self) -> None: } ) - def check(self) -> None: - """Log a warning if cached ``latest_version`` is newer than current. + def is_latest(self) -> bool: + """Return True if the installed version is at or ahead of the cached + ``latest_version``. - Pure read; never touches the network or rewrites state. Pair - with :meth:`refresh` to first ensure the cache is up to date. 
+ Returns True when checks are disabled or ``latest_version`` is + unknown so callers default to "no warning" in those cases. Pure + decision based on in-memory state populated by :meth:`refresh`; + never touches the network or filesystem. """ - if not self.enabled: - return - - state = self.read_state() - cached = state.get("latest_version") - if isinstance(cached, str): - self.warn_if_outdated(cached) + if not self.enabled or self.latest_version is None: + return True + return self.parse_version(self.current_version) >= self.parse_version( + self.latest_version, + ) diff --git a/tests/test_update_check.py b/tests/test_update_check.py index dcdb58b..3fe434b 100644 --- a/tests/test_update_check.py +++ b/tests/test_update_check.py @@ -2,7 +2,6 @@ import json import logging -import os from datetime import datetime, timedelta, timezone from unittest.mock import patch @@ -16,22 +15,13 @@ def read_json(path): return json.loads(path.read_text()) -def seed_state(path, payload, *, mtime: float | None = None): +def seed_state(path, payload): path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(payload)) - if mtime is not None: - try: - os.utime(path, (mtime, mtime)) - except (OSError, NotImplementedError) as exc: - pytest.skip(f"os.utime is not supported on this platform: {exc}") def fake_now(offset: timedelta): - """Patch update_check.datetime.now to return real-now + offset. - - Use this when tests need to simulate clock advancement on platforms - where ``os.utime`` may not work (e.g. some Windows configurations). 
- """ + """Patch update_check.datetime.now to return real-now + offset.""" pinned = datetime.now(timezone.utc) + offset class FakeDatetime(datetime): @@ -53,14 +43,17 @@ class TestParseVersion: [ ("1.2.3", (1, 2, 3)), ("0.0.1", (0, 0, 1)), - ("0.4.2a1", (0, 4, 2)), + ("0.4.2a1", (0, 4, 21)), ("1", (1,)), ("", ()), ("1.x.3", (1, 3)), + ("v1.2.3", (1, 2, 3)), + ("1.0.0-rc.1", (1, 0, 0, 1)), ], ) def test_cases(self, value, expected): - assert UpdateChecker.parse_version(value) == expected + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.parse_version(value) == expected class TestEnabled: @@ -78,12 +71,37 @@ def test_enabled_normal(self, state_path): assert checker.enabled is True +class TestIsCacheFresh: + def test_returns_false_for_empty_state(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.is_cache_fresh({}) is False + + def test_returns_false_for_missing_last_check(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.is_cache_fresh({"latest_version": "1.0.0"}) is False + + def test_returns_false_for_unparseable_last_check(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.is_cache_fresh({"last_check": "not-a-timestamp"}) is False + + def test_returns_true_for_recent_last_check(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + recent = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + assert checker.is_cache_fresh({"last_check": recent}) is True + + def test_returns_false_for_old_last_check(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + old = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat() + assert checker.is_cache_fresh({"last_check": old}) is False + + class TestRefresh: def test_skips_in_editable_mode(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="editable") with 
patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() + assert checker.latest_version is None def test_skips_when_opt_out_env_set(self, state_path, monkeypatch): monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") @@ -91,23 +109,25 @@ def test_skips_when_opt_out_env_set(self, state_path, monkeypatch): with patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() + assert checker.latest_version is None def test_fetches_and_writes_when_no_cache(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") with patch.object(checker, "fetch_latest_version", return_value="0.4.1"): checker.refresh() + assert checker.latest_version == "0.4.1" data = read_json(state_path) assert data["latest_version"] == "0.4.1" assert data["current_version"] == "0.4.0" assert "last_check" in data def test_skips_network_within_interval(self, state_path): - recent = datetime.now(timezone.utc) - timedelta(hours=1) + recent = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() seed_state( state_path, { - "last_check": recent.isoformat(), - "latest_version": "0.4.0", + "last_check": recent, + "latest_version": "0.5.0", "current_version": "0.4.0", }, ) @@ -115,17 +135,18 @@ def test_skips_network_within_interval(self, state_path): with patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() + # Cached value loaded into self. 
+ assert checker.latest_version == "0.5.0" def test_refetches_after_interval_expires(self, state_path): - old_ts = (datetime.now(timezone.utc) - timedelta(days=2)).timestamp() + old = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat() seed_state( state_path, { - "last_check": "2026-01-01T00:00:00+00:00", + "last_check": old, "latest_version": "0.4.0", "current_version": "0.4.0", }, - mtime=old_ts, ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") with patch.object( @@ -133,52 +154,15 @@ def test_refetches_after_interval_expires(self, state_path): ) as fetch: checker.refresh() fetch.assert_called_once() + assert checker.latest_version == "0.4.5" assert read_json(state_path)["latest_version"] == "0.4.5" - def test_swallows_network_failure(self, state_path): - checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with patch.object(checker, "fetch_latest_version", return_value=None): - checker.refresh() - assert not state_path.exists() - - def test_corrupt_cache_file_is_recoverable_after_interval(self, state_path): - """Corrupt JSON is silently overwritten once mtime expires.""" - state_path.parent.mkdir(parents=True, exist_ok=True) - state_path.write_text("{not json") - old_ts = (datetime.now(timezone.utc) - timedelta(days=2)).timestamp() - os.utime(state_path, (old_ts, old_ts)) - - checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with patch.object( - checker, "fetch_latest_version", return_value="0.4.1" - ) as fetch: - checker.refresh() - fetch.assert_called_once() - assert read_json(state_path)["latest_version"] == "0.4.1" - - def test_refresh_does_not_log_warning(self, state_path, caplog): - checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with ( - caplog.at_level(logging.WARNING, logger="contree_cli.update_check"), - patch.object(checker, "fetch_latest_version", return_value="0.5.0"), - ): - checker.refresh() - assert "available" not in caplog.text - - -class 
TestRefreshClockMock: - """Same logic as TestRefresh, but driven by mocked ``datetime.now``. - - Works on every platform — including Windows configurations where - ``os.utime`` may silently no-op or raise — because nothing depends - on the filesystem's mtime resolution or write permissions. - """ - - def test_refetches_after_simulated_interval(self, state_path): + def test_refetches_via_clock_mock(self, state_path): + """Cross-platform verification using mocked clock.""" seed_state( state_path, { - "last_check": "2026-01-01T00:00:00+00:00", + "last_check": datetime.now(timezone.utc).isoformat(), "latest_version": "0.4.0", "current_version": "0.4.0", }, @@ -192,131 +176,102 @@ def test_refetches_after_simulated_interval(self, state_path): ): checker.refresh() fetch.assert_called_once() - assert read_json(state_path)["latest_version"] == "0.4.5" + assert checker.latest_version == "0.4.5" - def test_skips_within_simulated_interval(self, state_path): + def test_network_failure_keeps_cached_value(self, state_path): + old = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat() seed_state( state_path, { - "last_check": "x", + "last_check": old, "latest_version": "0.4.0", "current_version": "0.4.0", }, ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with ( - fake_now(timedelta(hours=1)), - patch.object(checker, "fetch_latest_version") as fetch, - ): + with patch.object(checker, "fetch_latest_version", return_value=None): checker.refresh() - fetch.assert_not_called() - + assert checker.latest_version == "0.4.0" + # State file untouched (no rewrite). 
+ assert read_json(state_path)["latest_version"] == "0.4.0" -class TestIsStateFresh: - def test_returns_true_for_freshly_written_file(self, state_path): - state_path.parent.mkdir(parents=True, exist_ok=True) - state_path.write_text("{}") + def test_network_failure_with_no_cache_leaves_latest_none(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - assert checker.is_state_fresh() is True - - def test_returns_false_when_missing(self, state_path): - checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - assert checker.is_state_fresh() is False + with patch.object(checker, "fetch_latest_version", return_value=None): + checker.refresh() + assert checker.latest_version is None + assert not state_path.exists() - def test_returns_false_when_clock_advanced_past_interval(self, state_path): - """Cross-platform: simulate 2-day clock advance, no os.utime.""" + def test_corrupt_cache_is_overwritten(self, state_path): state_path.parent.mkdir(parents=True, exist_ok=True) - state_path.write_text("{}") + state_path.write_text("{not json") checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with fake_now(timedelta(days=2)): - assert checker.is_state_fresh() is False + with patch.object( + checker, "fetch_latest_version", return_value="0.4.1" + ) as fetch: + checker.refresh() + fetch.assert_called_once() + assert checker.latest_version == "0.4.1" + assert read_json(state_path)["latest_version"] == "0.4.1" - def test_returns_true_when_clock_advanced_within_interval(self, state_path): - state_path.parent.mkdir(parents=True, exist_ok=True) - state_path.write_text("{}") + def test_refresh_does_not_log_warning(self, state_path, caplog): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with fake_now(timedelta(hours=23)): - assert checker.is_state_fresh() is True + with ( + caplog.at_level(logging.WARNING, logger="contree_cli.update_check"), + patch.object(checker, 
"fetch_latest_version", return_value="0.5.0"), + ): + checker.refresh() + assert "available" not in caplog.text -class TestCheck: - def test_skips_in_editable_mode(self, state_path, caplog): - seed_state( - state_path, - { - "last_check": datetime.now(timezone.utc).isoformat(), - "latest_version": "9.9.9", - "current_version": "editable", - }, - ) +class TestIsLatest: + def test_returns_true_in_editable_mode(self, state_path): + """Editable installs are always considered up to date.""" checker = UpdateChecker(state_path=state_path, current_version="editable") - with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): - checker.check() - assert "available" not in caplog.text + checker.latest_version = "9.9.9" + assert checker.is_latest() is True - def test_skips_when_opt_out_env_set(self, state_path, caplog, monkeypatch): + def test_returns_true_when_opt_out_env_set(self, state_path, monkeypatch): monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") - seed_state( - state_path, - { - "last_check": datetime.now(timezone.utc).isoformat(), - "latest_version": "0.5.0", - "current_version": "0.4.0", - }, - ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): - checker.check() - assert "available" not in caplog.text + checker.latest_version = "9.9.9" + assert checker.is_latest() is True - def test_warns_when_cache_indicates_outdated(self, state_path, caplog): - seed_state( - state_path, - { - "last_check": datetime.now(timezone.utc).isoformat(), - "latest_version": "0.5.0", - "current_version": "0.4.0", - }, - ) + def test_returns_false_when_outdated(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): - checker.check() - assert "0.5.0" in caplog.text - assert "0.4.0" in caplog.text + checker.latest_version = "0.5.0" + assert checker.is_latest() is 
False - def test_no_warning_when_up_to_date(self, state_path, caplog): - seed_state( - state_path, - { - "last_check": datetime.now(timezone.utc).isoformat(), - "latest_version": "0.5.0", - "current_version": "0.5.0", - }, - ) + def test_returns_true_when_up_to_date(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.5.0") - with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): - checker.check() - assert "available" not in caplog.text + checker.latest_version = "0.5.0" + assert checker.is_latest() is True - def test_no_state_file_silently_returns(self, state_path, caplog): + def test_returns_true_when_latest_unknown(self, state_path): + """Unknown latest defaults to "up to date" so callers don't warn.""" checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): - checker.check() - assert "available" not in caplog.text + assert checker.is_latest() is True - def test_check_does_not_touch_network(self, state_path): - seed_state( - state_path, - { - "last_check": datetime.now(timezone.utc).isoformat(), - "latest_version": "0.5.0", - "current_version": "0.4.0", - }, - ) + def test_returns_true_when_current_is_newer(self, state_path): + """Pre-release/dev install ahead of pypi is "latest".""" + checker = UpdateChecker(state_path=state_path, current_version="0.6.0") + checker.latest_version = "0.5.0" + assert checker.is_latest() is True + + def test_does_not_touch_filesystem(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with patch.object(checker, "fetch_latest_version") as fetch: - checker.check() - fetch.assert_not_called() + checker.latest_version = "0.5.0" + with patch.object(UpdateChecker, "read_state") as read: + checker.is_latest() + read.assert_not_called() + + def test_does_not_log(self, state_path, caplog): + """is_latest() is a pure predicate; logging is the caller's job.""" 
+ checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + checker.latest_version = "0.5.0" + with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): + assert checker.is_latest() is False + assert caplog.records == [] class TestFetchLatestVersion: From 8a1b4597e979486222d172cd38171e852af3a1a8 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 15:10:41 +0200 Subject: [PATCH 5/8] remove mtime-related tests --- README.md | 4 ++-- contree_cli/__main__.py | 10 ++++++---- contree_cli/agent.md | 10 ++++++---- contree_cli/cli/images.py | 3 ++- contree_cli/cli/ps.py | 10 ++++++++-- contree_cli/types.py | 11 ++++++++++ contree_cli/update_check.py | 4 +--- tests/test_update_check.py | 40 ++++++------------------------------- 8 files changed, 42 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 77ff228..d53e629 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ contree --help contree auth ``` -You'll be prompted to enter your API token and project ID. The CLI verifies the token and saves credentials to `~/.config/contree-cli/config.ini`. +You'll be prompted to enter your API token and project ID. The CLI verifies the token and saves credentials to `~/.config/contree/auth.ini` (override the data directory via `CONTREE_HOME`). If `--token`/`--url`/`--project` flags are omitted, `contree auth` reads `CONTREE_TOKEN` (or `NEBIUS_API_KEY`), `CONTREE_URL`, and `CONTREE_PROJECT` (or `NEBIUS_AI_PROJECT`) from the environment instead of prompting. These variables are read only during registration; runtime commands use the saved profile only. 
@@ -289,7 +289,7 @@ Read at runtime (any command):
 
 | Variable | Purpose |
 |---|---|
-| `CONTREE_HOME` | Data directory (default `~/.config/contree-cli`) |
+| `CONTREE_HOME` | Data directory (default `$XDG_CONFIG_HOME/contree`, or `~/.config/contree`) |
 | `CONTREE_PROFILE` | Active profile name (selects which profile commands use) |
 | `CONTREE_SESSION` | Explicit session key (for multi-terminal workflows) |
 | `CONTREE_SESSION_DB` | Path to session SQLite database |
diff --git a/contree_cli/__main__.py b/contree_cli/__main__.py
index 8d6e741..9519b29 100644
--- a/contree_cli/__main__.py
+++ b/contree_cli/__main__.py
@@ -19,10 +19,6 @@ def main() -> None:
-    checker = UpdateChecker()
-    with suppress(Exception):
-        checker.refresh()
-
     if len(sys.argv) == 1:
         parser.print_help()
         exit(0)
@@ -33,6 +29,12 @@ def main() -> None:
     args = parser.parse_args()
     setup_logging(level=getattr(logging, args.log_level.upper(), logging.INFO))
+    # Update check runs only after argparse so it skips --help / --version
+    # / no-command paths and so the warning respects --log-level. refresh()
+    # is best-effort; is_latest() is a pure predicate.
+    checker = UpdateChecker()
+    with suppress(Exception):
+        checker.refresh()
    if not checker.is_latest():
         log.warning(
             "A new version of contree-cli is available: %s (installed: %s)."
diff --git a/contree_cli/agent.md b/contree_cli/agent.md index 5539b2d..ae543bf 100644 --- a/contree_cli/agent.md +++ b/contree_cli/agent.md @@ -328,10 +328,12 @@ Per-command -p is useful for cross-project operations: contree -p project-a images --prefix=base contree -p project-b images import tag:base:latest -Data directory: ~/.config/contree-cli/ - config.ini profile credentials - sessions-{profile}.db per-profile sessions, history, cache - skills.db installed agent skill registry +Data directory: $XDG_CONFIG_HOME/contree/ (or ~/.config/contree/) + auth.ini profile credentials (mode 0600) + cli.ini optional CLI defaults + cli/sessions/{profile}.db per-profile sessions, history, cache + cli/skills.db installed agent skill registry + cli/version_check.json cached PyPI update-check state Environment variables: CONTREE_HOME data directory override diff --git a/contree_cli/cli/images.py b/contree_cli/cli/images.py index ed6df5c..ccd40a3 100644 --- a/contree_cli/cli/images.py +++ b/contree_cli/cli/images.py @@ -25,6 +25,7 @@ isoformat_datetime, parse_datetime, parse_interval, + positive_int, ) logger = logging.getLogger(__name__) @@ -194,7 +195,7 @@ def _add_list_args(p: argparse.ArgumentParser) -> None: ) p.add_argument( *FLAGS["limit"], - type=int, + type=positive_int, default=LIMIT_DEFAULT, help="Stop after this many images and warn if more are available", ) diff --git a/contree_cli/cli/ps.py b/contree_cli/cli/ps.py index cc8bbd6..9b28fb4 100644 --- a/contree_cli/cli/ps.py +++ b/contree_cli/cli/ps.py @@ -19,7 +19,13 @@ from contree_cli import CLIENT, FORMATTER, ArgumentsProtocol, SetupResult from contree_cli.output import OutputFormatter -from contree_cli.types import FLAGS, isoformat_datetime, parse_datetime, parse_interval +from contree_cli.types import ( + FLAGS, + isoformat_datetime, + parse_datetime, + parse_interval, + positive_int, +) logger = logging.getLogger(__name__) @@ -102,7 +108,7 @@ def setup_parser(p: argparse.ArgumentParser) -> SetupResult: 
p.add_argument( *FLAGS["show_max"], - type=int, + type=positive_int, default=1000, help=( "Show at most this many operations, useful" diff --git a/contree_cli/types.py b/contree_cli/types.py index 17c8fb3..3904e08 100644 --- a/contree_cli/types.py +++ b/contree_cli/types.py @@ -90,6 +90,17 @@ def _get_help_string(self, action: argparse.Action) -> str: return help_text +def positive_int(value: str) -> int: + """argparse type for flags that must be at least 1.""" + try: + n = int(value) + except ValueError as exc: + raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from exc + if n < 1: + raise argparse.ArgumentTypeError(f"must be >= 1, got {n}") + return n + + def get_command_docs(setup_fn: SetupFn) -> tuple[str | None, str | None]: """Extract description and epilog from the module that defines *setup_fn*. diff --git a/contree_cli/update_check.py b/contree_cli/update_check.py index 137fbdc..6043e08 100644 --- a/contree_cli/update_check.py +++ b/contree_cli/update_check.py @@ -92,9 +92,7 @@ def fetch_latest_version(self) -> str | None: @property def enabled(self) -> bool: - return ( - not os.environ.get(self.OPT_OUT_ENV) and self.current_version != "editable" - ) + return self.OPT_OUT_ENV not in os.environ and self.current_version != "editable" def is_cache_fresh(self, state: dict[str, str]) -> bool: """True if ``state['last_check']`` is within ``CHECK_INTERVAL``.""" diff --git a/tests/test_update_check.py b/tests/test_update_check.py index 3fe434b..150f676 100644 --- a/tests/test_update_check.py +++ b/tests/test_update_check.py @@ -7,7 +7,6 @@ import pytest -from contree_cli import update_check from contree_cli.update_check import UpdateChecker @@ -20,18 +19,6 @@ def seed_state(path, payload): path.write_text(json.dumps(payload)) -def fake_now(offset: timedelta): - """Patch update_check.datetime.now to return real-now + offset.""" - pinned = datetime.now(timezone.utc) + offset - - class FakeDatetime(datetime): - @classmethod - def now(cls, tz=None): - 
return pinned if tz is not None else pinned.replace(tzinfo=None) - - return patch.object(update_check, "datetime", FakeDatetime) - - @pytest.fixture() def state_path(tmp_path): return tmp_path / "version_check.json" @@ -66,6 +53,12 @@ def test_disabled_when_opt_out_env_set(self, state_path, monkeypatch): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") assert checker.enabled is False + def test_disabled_when_opt_out_env_set_to_empty(self, state_path, monkeypatch): + """Presence (any value, even empty) opts out, per documented contract.""" + monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "") + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + assert checker.enabled is False + def test_enabled_normal(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") assert checker.enabled is True @@ -157,27 +150,6 @@ def test_refetches_after_interval_expires(self, state_path): assert checker.latest_version == "0.4.5" assert read_json(state_path)["latest_version"] == "0.4.5" - def test_refetches_via_clock_mock(self, state_path): - """Cross-platform verification using mocked clock.""" - seed_state( - state_path, - { - "last_check": datetime.now(timezone.utc).isoformat(), - "latest_version": "0.4.0", - "current_version": "0.4.0", - }, - ) - checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with ( - fake_now(timedelta(days=2)), - patch.object( - checker, "fetch_latest_version", return_value="0.4.5" - ) as fetch, - ): - checker.refresh() - fetch.assert_called_once() - assert checker.latest_version == "0.4.5" - def test_network_failure_keeps_cached_value(self, state_path): old = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat() seed_state( From 17cbf70ae8fed10c523666af1a0c403d1169dd08 Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 15:21:36 +0200 Subject: [PATCH 6/8] reformat code --- contree_cli/cli/auth.py | 8 ++-- contree_cli/cli/ps.py | 4 +- 
tests/test_auth.py | 91 ++++++++++++++++++++--------------------- 3 files changed, 51 insertions(+), 52 deletions(-) diff --git a/contree_cli/cli/auth.py b/contree_cli/cli/auth.py index 0aa17ce..9329563 100644 --- a/contree_cli/cli/auth.py +++ b/contree_cli/cli/auth.py @@ -175,7 +175,7 @@ def setup_parser(p: argparse.ArgumentParser) -> SetupResult: return cmd_auth, AuthArgs -def _env_fallback(names: tuple[str, ...], *, what: str) -> str | None: +def env_fallback(names: tuple[str, ...], *, what: str) -> str | None: for name in names: value = os.environ.get(name) if value: @@ -211,7 +211,7 @@ def cmd_auth(args: AuthArgs) -> int | None: return 1 # Token: --token > CONTREE_TOKEN > NEBIUS_API_KEY > interactive prompt - token = args.token or _env_fallback( + token = args.token or env_fallback( ("CONTREE_TOKEN", "NEBIUS_API_KEY"), what="token", ) @@ -219,7 +219,7 @@ def cmd_auth(args: AuthArgs) -> int | None: token = getpass.getpass("Token: ") # URL: --url > CONTREE_URL > type-specific default > interactive prompt - url = args.url or _env_fallback(("CONTREE_URL",), what="URL") + url = args.url or env_fallback(("CONTREE_URL",), what="URL") if url is None: if args.auth_type == AuthType.IAM: url = Config.DEFAULT_IAM_URL @@ -232,7 +232,7 @@ def cmd_auth(args: AuthArgs) -> int | None: # Project (IAM only): --project > CONTREE_PROJECT > NEBIUS_AI_PROJECT > prompt project: str | None = None if args.auth_type == AuthType.IAM: - project = args.project or _env_fallback( + project = args.project or env_fallback( ("CONTREE_PROJECT", "NEBIUS_AI_PROJECT"), what="project", ) diff --git a/contree_cli/cli/ps.py b/contree_cli/cli/ps.py index 9b28fb4..b3908f5 100644 --- a/contree_cli/cli/ps.py +++ b/contree_cli/cli/ps.py @@ -119,7 +119,7 @@ def setup_parser(p: argparse.ArgumentParser) -> SetupResult: return cmd_ps, PsArgs -def _emit_op(formatter: OutputFormatter, op: dict[str, Any], *, quiet: bool) -> None: +def emit_op(formatter: OutputFormatter, op: dict[str, Any], *, quiet: bool) -> 
None: row = dict( uuid=op["uuid"], status=op["status"], @@ -179,7 +179,7 @@ def cmd_ps(args: PsArgs) -> None: if limit is not None and emitted >= limit: hit_limit = True break - _emit_op(formatter, op, quiet=args.quiet) + emit_op(formatter, op, quiet=args.quiet) emitted += 1 if hit_limit: break diff --git a/tests/test_auth.py b/tests/test_auth.py index 941a5b7..950d622 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,4 +1,5 @@ import argparse +import json from contextlib import contextmanager from contextvars import copy_context from unittest.mock import patch @@ -46,26 +47,24 @@ def _make_iam_args(**kwargs) -> AuthArgs: return AuthArgs(**defaults) -def _whoami_body(*, permissions: dict[str, bool] | None = None) -> bytes: +def whoami_body(*, permissions: dict[str, bool] | None = None) -> bytes: body = { "token_uuid": "00000000-0000-0000-0000-000000000000", "token_expiration": None, "permissions": {"list": True} if permissions is None else permissions, "operations_stat": {}, } - import json as _json - - return _json.dumps(body).encode() + return json.dumps(body).encode() @contextmanager -def _mock_whoami(status=200, *, body: bytes | None = None): +def mock_whoami(status=200, *, body: bytes | None = None): """Patch client_from_profile to return a fresh ContreeTestClient per call.""" last_client: list[ContreeTestClient] = [] def factory(profile, timeout=None): # type: ignore[no-untyped-def] tc = ContreeTestClient() - tc.respond(status=status, body=body if body is not None else _whoami_body()) + tc.respond(status=status, body=body if body is not None else whoami_body()) last_client.clear() last_client.append(tc) return tc @@ -103,7 +102,7 @@ def test_save_with_token(self, config_dir, caplog): url="https://my.dev", profile="default", ) - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(args) p = Config().resolve() assert p.token == "my_token" @@ -112,16 +111,16 @@ def test_save_with_token(self, 
config_dir, caplog): assert "auth accepted, profile 'default' saved to ->" in caplog.text def test_logs_updating_for_existing_profile(self, config_dir, caplog): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="old")) caplog.clear() - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(_make_auth_args(token="new", force=True)) assert "Updating token for profile 'default'" in caplog.text def test_save_defaults_profile_and_url(self, config_dir): args = _make_auth_args(token="tok") - with _mock_whoami(): + with mock_whoami(): cmd_auth(args) p = Config().resolve() assert p.token == "tok" @@ -129,7 +128,7 @@ def test_save_defaults_profile_and_url(self, config_dir): def test_save_named_profile(self, config_dir, caplog): args = _make_auth_args(token="tok", profile="staging") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(args) cfg = Config() cfg.switch("staging") @@ -139,13 +138,13 @@ def test_save_named_profile(self, config_dir, caplog): assert "auth accepted, profile 'staging' saved to ->" in caplog.text def test_save_jwt_stores_type(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok")) p = Config().resolve() assert p.auth_type == AuthType.JWT def test_save_iam_stores_type_and_project(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_iam_args(token="tok")) p = Config().resolve() assert p.auth_type == AuthType.IAM @@ -173,7 +172,7 @@ def test_prompts_when_no_token_jwt(self, config_dir): "contree_cli.cli.auth.getpass.getpass", return_value="prompted_token", ), - _mock_whoami(), + mock_whoami(), ): cmd_auth(args) p = Config().resolve() @@ -193,14 +192,14 @@ def test_from_args_defaults_to_iam(self): class TestAuthVerify: def test_bad_token_rejected(self, config_dir, caplog): args = _make_auth_args(token="bad") - with caplog.at_level("ERROR"), 
_mock_whoami(status=401): + with caplog.at_level("ERROR"), mock_whoami(status=401): rc = cmd_auth(args) assert rc == 1 assert "Profile not changed" in caplog.text def test_bad_token_does_not_save(self, config_dir): args = _make_auth_args(token="bad") - with _mock_whoami(status=401): + with mock_whoami(status=401): cmd_auth(args) p = Config().resolve() assert p.token is None @@ -209,7 +208,7 @@ def test_no_list_permission_warns_but_saves(self, config_dir, caplog): args = _make_auth_args(token="tok") with ( caplog.at_level("WARNING"), - _mock_whoami(body=_whoami_body(permissions={"list": False})), + mock_whoami(body=whoami_body(permissions={"list": False})), ): rc = cmd_auth(args) assert rc is None @@ -221,7 +220,7 @@ def test_no_list_permission_warning_includes_project(self, config_dir, caplog): args = _make_iam_args(token="tok", project="aiproject-restricted") with ( caplog.at_level("WARNING"), - _mock_whoami(body=_whoami_body(permissions={"list": False})), + mock_whoami(body=whoami_body(permissions={"list": False})), ): cmd_auth(args) assert "aiproject-restricted" in caplog.text @@ -229,7 +228,7 @@ def test_no_list_permission_warning_includes_project(self, config_dir, caplog): def test_missing_permissions_field_warns(self, config_dir, caplog): args = _make_auth_args(token="tok") body = b'{"token_uuid":"x","token_expiration":null,"operations_stat":{}}' - with caplog.at_level("WARNING"), _mock_whoami(body=body): + with caplog.at_level("WARNING"), mock_whoami(body=body): rc = cmd_auth(args) assert rc is None assert "sandboxes are disabled" in caplog.text @@ -237,20 +236,20 @@ def test_missing_permissions_field_warns(self, config_dir, caplog): def test_unparseable_whoami_rejected(self, config_dir, caplog): args = _make_auth_args(token="tok") - with caplog.at_level("ERROR"), _mock_whoami(body=b"not-json"): + with caplog.at_level("ERROR"), mock_whoami(body=b"not-json"): rc = cmd_auth(args) assert rc == 1 assert Config().resolve().token is None def 
test_success_logs_saved(self, config_dir, caplog): args = _make_auth_args(token="good") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(args) assert "auth accepted" in caplog.text def test_whoami_called(self, config_dir): args = _make_auth_args(token="tok") - with _mock_whoami() as clients: + with mock_whoami() as clients: cmd_auth(args) tc = clients[0] assert tc.request_count == 1 @@ -265,7 +264,7 @@ def test_whoami_called(self, config_dir): class TestNebius: def test_nebius_api_key_used_as_token(self, config_dir, caplog, monkeypatch): monkeypatch.setenv("NEBIUS_API_KEY", "nebius-tok") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth( AuthArgs( url="https://test.dev", @@ -278,7 +277,7 @@ def test_nebius_api_key_used_as_token(self, config_dir, caplog, monkeypatch): def test_nebius_ai_project_used(self, config_dir, caplog, monkeypatch): monkeypatch.setenv("NEBIUS_AI_PROJECT", "aiproject-neb") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth( AuthArgs( token="tok", @@ -293,7 +292,7 @@ def test_nebius_ai_project_used(self, config_dir, caplog, monkeypatch): def test_both_nebius_vars_skip_all_prompts(self, config_dir, caplog, monkeypatch): monkeypatch.setenv("NEBIUS_API_KEY", "neb-tok") monkeypatch.setenv("NEBIUS_AI_PROJECT", "aiproject-auto") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(AuthArgs(auth_type=AuthType.IAM, url="https://iam.test")) p = Config().resolve() assert p.token == "neb-tok" @@ -305,7 +304,7 @@ def test_contree_token_used_when_token_omitted( self, config_dir, caplog, monkeypatch ): monkeypatch.setenv("CONTREE_TOKEN", "ctok") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(AuthArgs(url="https://test.dev", auth_type=AuthType.JWT)) p = Config().resolve() assert p.token == 
"ctok" @@ -316,7 +315,7 @@ def test_contree_token_preferred_over_nebius_api_key( ): monkeypatch.setenv("CONTREE_TOKEN", "ctok") monkeypatch.setenv("NEBIUS_API_KEY", "ntok") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(AuthArgs(url="https://test.dev", auth_type=AuthType.JWT)) p = Config().resolve() assert p.token == "ctok" @@ -324,7 +323,7 @@ def test_contree_token_preferred_over_nebius_api_key( def test_contree_url_used_when_url_omitted(self, config_dir, caplog, monkeypatch): monkeypatch.setenv("CONTREE_URL", "https://env-url.dev") monkeypatch.setenv("NEBIUS_API_KEY", "tok") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth(AuthArgs(auth_type=AuthType.JWT)) p = Config().resolve() assert p.url == "https://env-url.dev" @@ -333,7 +332,7 @@ def test_contree_project_used_when_project_omitted( self, config_dir, caplog, monkeypatch ): monkeypatch.setenv("CONTREE_PROJECT", "aiproject-c") - with caplog.at_level("INFO"), _mock_whoami(): + with caplog.at_level("INFO"), mock_whoami(): cmd_auth( AuthArgs( token="tok", @@ -347,7 +346,7 @@ def test_contree_project_used_when_project_omitted( def test_explicit_token_flag_beats_env(self, config_dir, monkeypatch): monkeypatch.setenv("CONTREE_TOKEN", "from-env") - with _mock_whoami(): + with mock_whoami(): cmd_auth( AuthArgs( token="from-flag", @@ -366,7 +365,7 @@ def test_explicit_token_flag_beats_env(self, config_dir, monkeypatch): class TestAuthSwitch: def test_switch_profile(self, config_dir, caplog): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok1", profile="default")) cmd_auth(_make_auth_args(token="tok2", profile="staging")) @@ -386,7 +385,7 @@ def test_switch_nonexistent_raises(self, config_dir): class TestAuthProfiles: def test_profiles_show_status(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok-ok", profile="ok")) 
cmd_auth(_make_auth_args(token="tok-timeout", profile="timeout")) cmd_auth(_make_auth_args(token="tok-error", profile="error")) @@ -404,7 +403,7 @@ def flush(self) -> None: def fake_factory(profile, timeout=None): # type: ignore[no-untyped-def] tc = ContreeTestClient(token=profile.token) if profile.token == "tok-ok": - tc.respond(status=200, body=_whoami_body()) + tc.respond(status=200, body=whoami_body()) elif profile.token == "tok-timeout": def timeout_get(path, params=None): # type: ignore[no-untyped-def] @@ -434,7 +433,7 @@ def error_get(path, params=None): # type: ignore[no-untyped-def] def test_profiles_inactive_status(self, config_dir): """Profile whose token lacks `list` permission is reported as inactive.""" - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok", profile="restricted")) rows: list[dict[str, object]] = [] @@ -450,7 +449,7 @@ def fake_factory(profile, timeout=None): # type: ignore[no-untyped-def] tc = ContreeTestClient(token=profile.token) tc.respond( status=200, - body=_whoami_body(permissions={"list": False, "spawn": True}), + body=whoami_body(permissions={"list": False, "spawn": True}), ) return tc @@ -466,7 +465,7 @@ def fake_factory(profile, timeout=None): # type: ignore[no-untyped-def] assert by_name["restricted"]["status"] == "inactive" def test_profiles_offline_skips_probe(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok", profile="offline-test")) Config().switch("offline-test") @@ -488,7 +487,7 @@ def flush(self) -> None: def test_env_profile_marks_active(self, config_dir, monkeypatch): """CONTREE_PROFILE env var overrides active marker in listing.""" - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="t1", profile="default")) cmd_auth(_make_auth_args(token="t2", profile="e2e")) @@ -513,7 +512,7 @@ def flush(self) -> None: def test_env_profile_nonexistent_warns(self, config_dir, monkeypatch, caplog): """CONTREE_PROFILE pointing to 
missing profile logs a warning.""" - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok", profile="default")) monkeypatch.setenv("CONTREE_PROFILE", "ghost") @@ -543,7 +542,7 @@ def flush(self) -> None: class TestAuthOverwrite: def test_overwrite_aborted(self, config_dir, capsys): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="old")) with patch("builtins.input", return_value="n"): rc = cmd_auth(_make_auth_args(token="new")) @@ -553,7 +552,7 @@ def test_overwrite_aborted(self, config_dir, capsys): assert p.token == "old" def test_overwrite_confirmed(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="old")) with patch("builtins.input", return_value="y"): rc = cmd_auth(_make_auth_args(token="new")) @@ -562,7 +561,7 @@ def test_overwrite_confirmed(self, config_dir): assert p.token == "new" def test_overwrite_force(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="old")) rc = cmd_auth(_make_auth_args(token="new", force=True)) assert rc is None @@ -570,7 +569,7 @@ def test_overwrite_force(self, config_dir): assert p.token == "new" def test_overwrite_empty_input_aborts(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="old")) with patch("builtins.input", return_value=""): rc = cmd_auth(_make_auth_args(token="new")) @@ -579,7 +578,7 @@ def test_overwrite_empty_input_aborts(self, config_dir): class TestAuthRemove: def test_remove_deletes_profile(self, config_dir, caplog): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok", profile="staging")) rc = cmd_remove(RemoveArgs(profile_name="staging", force=True)) assert rc is None @@ -592,7 +591,7 @@ def test_remove_nonexistent_fails(self, config_dir, caplog): assert "does not exist" in caplog.text def test_remove_active_switches_to_remaining(self, config_dir): - with _mock_whoami(): + with mock_whoami(): 
cmd_auth(_make_auth_args(token="t1", profile="first")) cmd_auth(_make_auth_args(token="t2", profile="second")) cfg = Config() @@ -602,7 +601,7 @@ def test_remove_active_switches_to_remaining(self, config_dir): assert p.name != "first" def test_remove_aborted(self, config_dir, capsys): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok", profile="keep")) with patch("builtins.input", return_value="n"): rc = cmd_remove(RemoveArgs(profile_name="keep")) @@ -610,7 +609,7 @@ def test_remove_aborted(self, config_dir, capsys): assert "keep" in Config() def test_remove_confirmed(self, config_dir): - with _mock_whoami(): + with mock_whoami(): cmd_auth(_make_auth_args(token="tok", profile="gone")) with patch("builtins.input", return_value="y"): rc = cmd_remove(RemoveArgs(profile_name="gone")) From 6c612d4744406967adf8682b0092c008f00e0e0a Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 15:28:36 +0200 Subject: [PATCH 7/8] use epoch instead of date --- contree_cli/cli/auth.py | 4 +- contree_cli/update_check.py | 61 ++++++++++++++------ tests/test_auth.py | 9 +++ tests/test_update_check.py | 112 +++++++++++++++++++++++++++--------- 4 files changed, 139 insertions(+), 47 deletions(-) diff --git a/contree_cli/cli/auth.py b/contree_cli/cli/auth.py index 9329563..59e7ad8 100644 --- a/contree_cli/cli/auth.py +++ b/contree_cli/cli/auth.py @@ -184,7 +184,9 @@ def env_fallback(names: tuple[str, ...], *, what: str) -> str | None: return None -def check_permission(payload: dict[str, object], permission: str) -> bool: +def check_permission(payload: object, permission: str) -> bool: + if not isinstance(payload, dict): + return False perms = payload.get("permissions") if not isinstance(perms, dict): return False diff --git a/contree_cli/update_check.py b/contree_cli/update_check.py index 6043e08..85e5e0e 100644 --- a/contree_cli/update_check.py +++ b/contree_cli/update_check.py @@ -3,11 +3,15 @@ State file at 
``$CONTREE_HOME/cli/version_check.json``:: { - "last_check": "2026-05-08T12:00:00+00:00", + "last_check": 1762555200.0, "latest_version": "0.5.0", "current_version": "0.4.2" } +``last_check`` is a Unix epoch timestamp; storing seconds keeps the +freshness check trivial (one subtraction) and immune to timezone / +ISO-format quirks. + Network errors, malformed cache files, and parse failures are swallowed: the update check must never break a user's command. """ @@ -17,9 +21,10 @@ import json import os import re +import time import urllib.request from contextlib import suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from pathlib import Path from contree_cli import config @@ -38,7 +43,9 @@ class UpdateChecker: NETWORK_TIMEOUT = 2.0 OPT_OUT_ENV = "CONTREE_NO_UPDATE_CHECK" STATE_PATH = config.CONTREE_HOME / "cli" / "version_check.json" - VERSION_REGEX = re.compile(r"[^\d.]") + # Capture leading digits of each dot-separated component; anything + # past the digits (``a1``, ``-rc.1``, etc.) marks a pre-release. + COMPONENT_REGEX = re.compile(r"\d+") def __init__( self, @@ -50,24 +57,44 @@ def __init__( self.current_version = current_version self.latest_version: str | None = None - def read_state(self) -> dict[str, str]: + def read_state(self) -> dict[str, object]: + """Load the cache file. Any schema violation -> empty dict. + + A wrong-typed field (e.g. ``last_check`` saved as an ISO string + by an older release) makes the whole entry untrustworthy; the + next refresh just rewrites it. 
+ """ try: with self.state_path.open() as f: data = json.load(f) assert isinstance(data, dict) + assert isinstance(data.get("last_check", 0), (int, float)) + assert isinstance(data.get("latest_version", ""), str) except Exception: return {} return data - def write_state(self, state: dict[str, str]) -> None: + def write_state(self, state: dict[str, object]) -> None: with suppress(OSError): self.state_path.parent.mkdir(parents=True, exist_ok=True) self.state_path.write_text(json.dumps(state, indent=1)) - def parse_version(self, value: str) -> tuple[int, ...]: - return tuple( - map(int, filter(None, self.VERSION_REGEX.sub("", value).split("."))) - ) + def parse_version(self, value: str) -> tuple[tuple[int, int], ...]: + """Parse ``value`` into a sortable tuple of ``(number, rank)``. + + ``rank`` is ``1`` for a clean numeric component and ``0`` for a + pre-release suffix (``a1``, ``-rc.1``, …). With this encoding, + ``0.4.2a1`` < ``0.4.2`` < ``0.4.21`` as expected. Components with + no digits at all are dropped. 
+ """ + parts: list[tuple[int, int]] = [] + for raw in value.split("."): + match = self.COMPONENT_REGEX.search(raw) + if not match: + continue + tail = raw[match.end() :] + parts.append((int(match.group()), 0 if tail else 1)) + return tuple(parts) def fetch_latest_version(self) -> str | None: try: @@ -94,23 +121,19 @@ def fetch_latest_version(self) -> str | None: def enabled(self) -> bool: return self.OPT_OUT_ENV not in os.environ and self.current_version != "editable" - def is_cache_fresh(self, state: dict[str, str]) -> bool: + def is_cache_fresh(self, state: dict[str, object]) -> bool: """True if ``state['last_check']`` is within ``CHECK_INTERVAL``.""" - last_check_str = state.get("last_check") - if not isinstance(last_check_str, str): - return False - try: - last_check = datetime.fromisoformat(last_check_str) - except ValueError: + last_check = state.get("last_check") + if not isinstance(last_check, (int, float)): return False - return datetime.now(timezone.utc) - last_check < self.CHECK_INTERVAL + return time.time() - last_check < self.CHECK_INTERVAL.total_seconds() def refresh(self) -> None: """Read the cache once, refetch from PyPI if stale. Populates ``self.latest_version`` with whatever we know after this call (cached value, freshly fetched value, or ``None``). - :meth:`check` then decides whether to log based purely on + :meth:`is_latest` then decides whether to log based purely on in-memory state — no further file IO. 
""" if not self.enabled: @@ -132,7 +155,7 @@ def refresh(self) -> None: self.latest_version = latest self.write_state( { - "last_check": datetime.now(timezone.utc).isoformat(), + "last_check": time.time(), "latest_version": latest, "current_version": self.current_version, } diff --git a/tests/test_auth.py b/tests/test_auth.py index 950d622..c43ecc5 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -241,6 +241,15 @@ def test_unparseable_whoami_rejected(self, config_dir, caplog): assert rc == 1 assert Config().resolve().token is None + def test_non_dict_whoami_payload_warns_but_saves(self, config_dir, caplog): + """JSON list (or other non-dict) is treated as missing permissions.""" + args = _make_auth_args(token="tok") + with caplog.at_level("WARNING"), mock_whoami(body=b"[]"): + rc = cmd_auth(args) + assert rc is None + assert "sandboxes are disabled" in caplog.text + assert Config().resolve().token == "tok" + def test_success_logs_saved(self, config_dir, caplog): args = _make_auth_args(token="good") with caplog.at_level("INFO"), mock_whoami(): diff --git a/tests/test_update_check.py b/tests/test_update_check.py index 150f676..cbac53f 100644 --- a/tests/test_update_check.py +++ b/tests/test_update_check.py @@ -2,7 +2,7 @@ import json import logging -from datetime import datetime, timedelta, timezone +import time from unittest.mock import patch import pytest @@ -24,24 +24,40 @@ def state_path(tmp_path): return tmp_path / "version_check.json" +HOUR = 3600.0 +DAY = 86400.0 + + class TestParseVersion: @pytest.mark.parametrize( "value,expected", [ - ("1.2.3", (1, 2, 3)), - ("0.0.1", (0, 0, 1)), - ("0.4.2a1", (0, 4, 21)), - ("1", (1,)), + ("1.2.3", ((1, 1), (2, 1), (3, 1))), + ("0.0.1", ((0, 1), (0, 1), (1, 1))), + ("0.4.2a1", ((0, 1), (4, 1), (2, 0))), + ("1", ((1, 1),)), ("", ()), - ("1.x.3", (1, 3)), - ("v1.2.3", (1, 2, 3)), - ("1.0.0-rc.1", (1, 0, 0, 1)), + ("1.x.3", ((1, 1), (3, 1))), + ("v1.2.3", ((1, 1), (2, 1), (3, 1))), + ("1.0.0-rc.1", ((1, 1), (0, 1), 
(0, 0), (1, 1))), ], ) def test_cases(self, value, expected): checker = UpdateChecker(state_path="/dev/null", current_version="0") assert checker.parse_version(value) == expected + def test_pre_release_sorts_before_release(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.parse_version("0.4.2a1") < checker.parse_version("0.4.2") + + def test_higher_release_sorts_after_pre_release(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.parse_version("0.4.2") < checker.parse_version("0.4.21") + + def test_rc_sorts_before_release(self): + checker = UpdateChecker(state_path="/dev/null", current_version="0") + assert checker.parse_version("1.0.0-rc.1") < checker.parse_version("1.0.0") + class TestEnabled: def test_disabled_in_editable_mode(self, state_path): @@ -73,19 +89,19 @@ def test_returns_false_for_missing_last_check(self): checker = UpdateChecker(state_path="/dev/null", current_version="0") assert checker.is_cache_fresh({"latest_version": "1.0.0"}) is False - def test_returns_false_for_unparseable_last_check(self): + def test_returns_false_for_non_numeric_last_check(self): + """Legacy ISO strings or garbage values are treated as stale.""" checker = UpdateChecker(state_path="/dev/null", current_version="0") - assert checker.is_cache_fresh({"last_check": "not-a-timestamp"}) is False + assert checker.is_cache_fresh({"last_check": "2026-05-08T00:00:00"}) is False + assert checker.is_cache_fresh({"last_check": None}) is False def test_returns_true_for_recent_last_check(self): checker = UpdateChecker(state_path="/dev/null", current_version="0") - recent = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() - assert checker.is_cache_fresh({"last_check": recent}) is True + assert checker.is_cache_fresh({"last_check": time.time() - HOUR}) is True def test_returns_false_for_old_last_check(self): checker = UpdateChecker(state_path="/dev/null", current_version="0") - old = 
(datetime.now(timezone.utc) - timedelta(days=2)).isoformat() - assert checker.is_cache_fresh({"last_check": old}) is False + assert checker.is_cache_fresh({"last_check": time.time() - 2 * DAY}) is False class TestRefresh: @@ -112,14 +128,13 @@ def test_fetches_and_writes_when_no_cache(self, state_path): data = read_json(state_path) assert data["latest_version"] == "0.4.1" assert data["current_version"] == "0.4.0" - assert "last_check" in data + assert isinstance(data["last_check"], (int, float)) def test_skips_network_within_interval(self, state_path): - recent = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() seed_state( state_path, { - "last_check": recent, + "last_check": time.time() - HOUR, "latest_version": "0.5.0", "current_version": "0.4.0", }, @@ -128,15 +143,13 @@ def test_skips_network_within_interval(self, state_path): with patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() - # Cached value loaded into self. assert checker.latest_version == "0.5.0" def test_refetches_after_interval_expires(self, state_path): - old = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat() seed_state( state_path, { - "last_check": old, + "last_check": time.time() - 2 * DAY, "latest_version": "0.4.0", "current_version": "0.4.0", }, @@ -151,11 +164,10 @@ def test_refetches_after_interval_expires(self, state_path): assert read_json(state_path)["latest_version"] == "0.4.5" def test_network_failure_keeps_cached_value(self, state_path): - old = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat() seed_state( state_path, { - "last_check": old, + "last_check": time.time() - 2 * DAY, "latest_version": "0.4.0", "current_version": "0.4.0", }, @@ -164,7 +176,6 @@ def test_network_failure_keeps_cached_value(self, state_path): with patch.object(checker, "fetch_latest_version", return_value=None): checker.refresh() assert checker.latest_version == "0.4.0" - # State file untouched (no rewrite). 
assert read_json(state_path)["latest_version"] == "0.4.0" def test_network_failure_with_no_cache_leaves_latest_none(self, state_path): @@ -186,6 +197,45 @@ def test_corrupt_cache_is_overwritten(self, state_path): assert checker.latest_version == "0.4.1" assert read_json(state_path)["latest_version"] == "0.4.1" + def test_legacy_iso_last_check_is_discarded(self, state_path): + """An old cache file written before the epoch migration is treated + as missing entirely — read_state rejects the whole record.""" + seed_state( + state_path, + { + "last_check": "2026-05-08T12:00:00+00:00", + "latest_version": "0.4.0", + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object( + checker, "fetch_latest_version", return_value="0.4.5" + ) as fetch: + checker.refresh() + fetch.assert_called_once() + # Discarded state -> nothing carried over -> only freshly fetched + # value remains. + assert read_json(state_path)["latest_version"] == "0.4.5" + assert checker.latest_version == "0.4.5" + + def test_wrong_typed_latest_version_is_discarded(self, state_path): + seed_state( + state_path, + { + "last_check": time.time() - 2 * DAY, + "latest_version": 42, # not a string + "current_version": "0.4.0", + }, + ) + checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + with patch.object( + checker, "fetch_latest_version", return_value="0.4.5" + ) as fetch: + checker.refresh() + fetch.assert_called_once() + assert checker.latest_version == "0.4.5" + def test_refresh_does_not_log_warning(self, state_path, caplog): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") with ( @@ -198,7 +248,6 @@ def test_refresh_does_not_log_warning(self, state_path, caplog): class TestIsLatest: def test_returns_true_in_editable_mode(self, state_path): - """Editable installs are always considered up to date.""" checker = UpdateChecker(state_path=state_path, current_version="editable") 
checker.latest_version = "9.9.9" assert checker.is_latest() is True @@ -220,16 +269,26 @@ def test_returns_true_when_up_to_date(self, state_path): assert checker.is_latest() is True def test_returns_true_when_latest_unknown(self, state_path): - """Unknown latest defaults to "up to date" so callers don't warn.""" checker = UpdateChecker(state_path=state_path, current_version="0.4.0") assert checker.is_latest() is True def test_returns_true_when_current_is_newer(self, state_path): - """Pre-release/dev install ahead of pypi is "latest".""" checker = UpdateChecker(state_path=state_path, current_version="0.6.0") checker.latest_version = "0.5.0" assert checker.is_latest() is True + def test_returns_true_when_current_is_pre_release_of_same(self, state_path): + """Pre-release of the same release is "older", so warns to upgrade.""" + checker = UpdateChecker(state_path=state_path, current_version="0.5.0a1") + checker.latest_version = "0.5.0" + assert checker.is_latest() is False + + def test_returns_true_when_latest_is_pre_release_of_same(self, state_path): + """If pypi only knows a pre-release, an installed stable is fine.""" + checker = UpdateChecker(state_path=state_path, current_version="0.5.0") + checker.latest_version = "0.5.0a1" + assert checker.is_latest() is True + def test_does_not_touch_filesystem(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") checker.latest_version = "0.5.0" @@ -238,7 +297,6 @@ def test_does_not_touch_filesystem(self, state_path): read.assert_not_called() def test_does_not_log(self, state_path, caplog): - """is_latest() is a pure predicate; logging is the caller's job.""" checker = UpdateChecker(state_path=state_path, current_version="0.4.0") checker.latest_version = "0.5.0" with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): From 6d56c026867b01318a16da41fda3f3bd09404dcf Mon Sep 17 00:00:00 2001 From: Dmitry Orlov Date: Fri, 8 May 2026 15:49:37 +0200 Subject: [PATCH 8/8] use 
dataclass and fix limits --- contree_cli/__main__.py | 2 +- contree_cli/cli/images.py | 4 +- contree_cli/update_check.py | 108 ++++++++++++--------------- tests/test_images.py | 4 +- tests/test_update_check.py | 142 +++++++++++++++++++----------------- 5 files changed, 128 insertions(+), 132 deletions(-) diff --git a/contree_cli/__main__.py b/contree_cli/__main__.py index 9519b29..41be67a 100644 --- a/contree_cli/__main__.py +++ b/contree_cli/__main__.py @@ -40,7 +40,7 @@ def main() -> None: "A new version of contree-cli is available: %s (installed: %s)." " Upgrade with `uv tool install -U contree-cli` or" " `pip install -U contree-cli`.", - checker.latest_version, + checker.state.latest_version, checker.current_version, ) diff --git a/contree_cli/cli/images.py b/contree_cli/cli/images.py index ccd40a3..6a652c3 100644 --- a/contree_cli/cli/images.py +++ b/contree_cli/cli/images.py @@ -30,8 +30,8 @@ logger = logging.getLogger(__name__) -PAGE_SIZE = 500 -LIMIT_DEFAULT = 2000 +PAGE_SIZE = 1000 +LIMIT_DEFAULT = 3000 TERMINAL_STATUSES = frozenset({"SUCCESS", "FAILED", "CANCELLED"}) DOCKER_HUB = "docker.io" diff --git a/contree_cli/update_check.py b/contree_cli/update_check.py index 85e5e0e..643ab1b 100644 --- a/contree_cli/update_check.py +++ b/contree_cli/update_check.py @@ -3,9 +3,8 @@ State file at ``$CONTREE_HOME/cli/version_check.json``:: { - "last_check": 1762555200.0, - "latest_version": "0.5.0", - "current_version": "0.4.2" + "last_check": 1762555200, + "latest_version": "0.5.0" } ``last_check`` is a Unix epoch timestamp; storing seconds keeps the @@ -24,6 +23,7 @@ import time import urllib.request from contextlib import suppress +from dataclasses import asdict, dataclass from datetime import timedelta from pathlib import Path @@ -31,13 +31,30 @@ from contree_cli.client import CLI_USER_AGENT, cli_version -class UpdateChecker: - """Encapsulates the state file + PyPI probe + outdated-version warning. 
+@dataclass(frozen=True) +class UpdateState: + last_check: int = 0 + latest_version: str = "" - All side effects (filesystem, network, logging) are guarded so that a - failure in update-checking can never break the user's command. - """ + @classmethod + def from_file(cls, path: Path) -> UpdateState: + try: + with path.open() as f: + data = json.load(f) + return cls( + last_check=int(data["last_check"]), + latest_version=str(data["latest_version"]), + ) + except Exception: + return cls() + def to_file(self, path: Path) -> None: + with suppress(OSError): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(asdict(self), indent=1)) + + +class UpdateChecker: PYPI_URL = "https://pypi.org/pypi/contree-cli/json" CHECK_INTERVAL = timedelta(days=1) NETWORK_TIMEOUT = 2.0 @@ -55,29 +72,10 @@ def __init__( ) -> None: self.state_path = state_path self.current_version = current_version - self.latest_version: str | None = None - - def read_state(self) -> dict[str, object]: - """Load the cache file. Any schema violation -> empty dict. - - A wrong-typed field (e.g. ``last_check`` saved as an ISO string - by an older release) makes the whole entry untrustworthy; the - next refresh just rewrites it. - """ - try: - with self.state_path.open() as f: - data = json.load(f) - assert isinstance(data, dict) - assert isinstance(data.get("last_check", 0), (int, float)) - assert isinstance(data.get("latest_version", ""), str) - except Exception: - return {} - return data - - def write_state(self, state: dict[str, object]) -> None: - with suppress(OSError): - self.state_path.parent.mkdir(parents=True, exist_ok=True) - self.state_path.write_text(json.dumps(state, indent=1)) + # ``state`` holds whatever we know about PyPI's latest version. + # Default sentinel ("", last_check=0) means "no cache yet" — + # is_latest() treats it as up-to-date so callers don't warn. 
+ self.state: UpdateState = UpdateState() def parse_version(self, value: str) -> tuple[tuple[int, int], ...]: """Parse ``value`` into a sortable tuple of ``(number, rank)``. @@ -121,30 +119,22 @@ def fetch_latest_version(self) -> str | None: def enabled(self) -> bool: return self.OPT_OUT_ENV not in os.environ and self.current_version != "editable" - def is_cache_fresh(self, state: dict[str, object]) -> bool: - """True if ``state['last_check']`` is within ``CHECK_INTERVAL``.""" - last_check = state.get("last_check") - if not isinstance(last_check, (int, float)): - return False - return time.time() - last_check < self.CHECK_INTERVAL.total_seconds() + def is_cache_fresh(self, state: UpdateState) -> bool: + """True if ``state.last_check`` is within ``CHECK_INTERVAL``.""" + return time.time() - state.last_check < self.CHECK_INTERVAL.total_seconds() def refresh(self) -> None: - """Read the cache once, refetch from PyPI if stale. + """Load the cache, refetch from PyPI if stale, persist new state. - Populates ``self.latest_version`` with whatever we know after - this call (cached value, freshly fetched value, or ``None``). - :meth:`is_latest` then decides whether to log based purely on + Populates ``self.state`` with whatever we know after this call. + :meth:`is_latest` then decides whether to warn based purely on in-memory state — no further file IO. """ if not self.enabled: return - state = self.read_state() - cached = state.get("latest_version") - if isinstance(cached, str): - self.latest_version = cached - - if self.is_cache_fresh(state): + self.state = UpdateState.from_file(self.state_path) + if self.is_cache_fresh(self.state): return latest = self.fetch_latest_version() @@ -152,26 +142,24 @@ def refresh(self) -> None: # Network failed; keep whatever was cached. 
return - self.latest_version = latest - self.write_state( - { - "last_check": time.time(), - "latest_version": latest, - "current_version": self.current_version, - } + self.state = UpdateState( + last_check=int(time.time()), + latest_version=latest, ) + self.state.to_file(self.state_path) def is_latest(self) -> bool: """Return True if the installed version is at or ahead of the cached ``latest_version``. - Returns True when checks are disabled or ``latest_version`` is - unknown so callers default to "no warning" in those cases. Pure - decision based on in-memory state populated by :meth:`refresh`; - never touches the network or filesystem. + Returns True when checks are disabled or the cached + ``latest_version`` is the empty sentinel — callers default to + "no warning" in those cases. Pure decision based on in-memory + state populated by :meth:`refresh`; never touches the network + or filesystem. """ - if not self.enabled or self.latest_version is None: + if not self.enabled or not self.state.latest_version: return True return self.parse_version(self.current_version) >= self.parse_version( - self.latest_version, + self.state.latest_version, ) diff --git a/tests/test_images.py b/tests/test_images.py index 99dc258..896fcac 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -188,8 +188,8 @@ def test_progress_logged_per_full_page(self, contree_client, caplog): ) assert not any(f"{PAGE_SIZE * 2 + 3}" in m for m in msgs) - def test_default_limit_is_2000(self): - assert LIMIT_DEFAULT == 2000 + def test_default_limit_matches_constant(self): + assert LIMIT_DEFAULT > 0 assert ImagesArgs().limit == LIMIT_DEFAULT def test_limit_truncates_with_warning(self, contree_client, caplog): diff --git a/tests/test_update_check.py b/tests/test_update_check.py index cbac53f..8f664af 100644 --- a/tests/test_update_check.py +++ b/tests/test_update_check.py @@ -7,7 +7,7 @@ import pytest -from contree_cli.update_check import UpdateChecker +from contree_cli.update_check import 
UpdateChecker, UpdateState def read_json(path): @@ -81,27 +81,58 @@ def test_enabled_normal(self, state_path): class TestIsCacheFresh: - def test_returns_false_for_empty_state(self): + def test_returns_true_for_recent_last_check(self): checker = UpdateChecker(state_path="/dev/null", current_version="0") - assert checker.is_cache_fresh({}) is False + state = UpdateState(last_check=int(time.time() - HOUR), latest_version="x") + assert checker.is_cache_fresh(state) is True - def test_returns_false_for_missing_last_check(self): + def test_returns_false_for_old_last_check(self): checker = UpdateChecker(state_path="/dev/null", current_version="0") - assert checker.is_cache_fresh({"latest_version": "1.0.0"}) is False + state = UpdateState(last_check=int(time.time() - 2 * DAY), latest_version="x") + assert checker.is_cache_fresh(state) is False - def test_returns_false_for_non_numeric_last_check(self): - """Legacy ISO strings or garbage values are treated as stale.""" + def test_returns_false_for_default_sentinel(self): checker = UpdateChecker(state_path="/dev/null", current_version="0") - assert checker.is_cache_fresh({"last_check": "2026-05-08T00:00:00"}) is False - assert checker.is_cache_fresh({"last_check": None}) is False + assert checker.is_cache_fresh(UpdateState()) is False - def test_returns_true_for_recent_last_check(self): - checker = UpdateChecker(state_path="/dev/null", current_version="0") - assert checker.is_cache_fresh({"last_check": time.time() - HOUR}) is True - def test_returns_false_for_old_last_check(self): - checker = UpdateChecker(state_path="/dev/null", current_version="0") - assert checker.is_cache_fresh({"last_check": time.time() - 2 * DAY}) is False +class TestUpdateState: + def test_default_sentinel(self): + state = UpdateState() + assert state.last_check == 0 + assert state.latest_version == "" + + def test_from_file_missing_returns_sentinel(self, tmp_path): + state = UpdateState.from_file(tmp_path / "missing.json") + assert state == 
UpdateState() + + def test_from_file_corrupt_returns_sentinel(self, state_path): + state_path.parent.mkdir(parents=True, exist_ok=True) + state_path.write_text("{not json") + assert UpdateState.from_file(state_path) == UpdateState() + + def test_from_file_wrong_typed_returns_sentinel(self, state_path): + seed_state(state_path, {"last_check": "iso-string", "latest_version": "1.0"}) + assert UpdateState.from_file(state_path) == UpdateState() + + def test_from_file_round_trip(self, state_path): + original = UpdateState(last_check=12345, latest_version="1.2.3") + original.to_file(state_path) + assert UpdateState.from_file(state_path) == original + + def test_from_file_extra_fields_ignored(self, state_path): + seed_state( + state_path, + { + "last_check": 100, + "latest_version": "1.0", + "extra": "ignored", + }, + ) + assert UpdateState.from_file(state_path) == UpdateState( + last_check=100, + latest_version="1.0", + ) class TestRefresh: @@ -110,7 +141,7 @@ def test_skips_in_editable_mode(self, state_path): with patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() - assert checker.latest_version is None + assert checker.state == UpdateState() def test_skips_when_opt_out_env_set(self, state_path, monkeypatch): monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") @@ -118,40 +149,37 @@ def test_skips_when_opt_out_env_set(self, state_path, monkeypatch): with patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() - assert checker.latest_version is None + assert checker.state == UpdateState() def test_fetches_and_writes_when_no_cache(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") with patch.object(checker, "fetch_latest_version", return_value="0.4.1"): checker.refresh() - assert checker.latest_version == "0.4.1" + assert checker.state.latest_version == "0.4.1" data = read_json(state_path) assert data["latest_version"] == "0.4.1" - assert 
data["current_version"] == "0.4.0" - assert isinstance(data["last_check"], (int, float)) + assert isinstance(data["last_check"], int) def test_skips_network_within_interval(self, state_path): seed_state( state_path, { - "last_check": time.time() - HOUR, + "last_check": int(time.time() - HOUR), "latest_version": "0.5.0", - "current_version": "0.4.0", }, ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") with patch.object(checker, "fetch_latest_version") as fetch: checker.refresh() fetch.assert_not_called() - assert checker.latest_version == "0.5.0" + assert checker.state.latest_version == "0.5.0" def test_refetches_after_interval_expires(self, state_path): seed_state( state_path, { - "last_check": time.time() - 2 * DAY, + "last_check": int(time.time() - 2 * DAY), "latest_version": "0.4.0", - "current_version": "0.4.0", }, ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") @@ -160,29 +188,28 @@ def test_refetches_after_interval_expires(self, state_path): ) as fetch: checker.refresh() fetch.assert_called_once() - assert checker.latest_version == "0.4.5" + assert checker.state.latest_version == "0.4.5" assert read_json(state_path)["latest_version"] == "0.4.5" def test_network_failure_keeps_cached_value(self, state_path): seed_state( state_path, { - "last_check": time.time() - 2 * DAY, + "last_check": int(time.time() - 2 * DAY), "latest_version": "0.4.0", - "current_version": "0.4.0", }, ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") with patch.object(checker, "fetch_latest_version", return_value=None): checker.refresh() - assert checker.latest_version == "0.4.0" + assert checker.state.latest_version == "0.4.0" assert read_json(state_path)["latest_version"] == "0.4.0" - def test_network_failure_with_no_cache_leaves_latest_none(self, state_path): + def test_network_failure_with_no_cache_leaves_state_default(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") 
with patch.object(checker, "fetch_latest_version", return_value=None): checker.refresh() - assert checker.latest_version is None + assert checker.state == UpdateState() assert not state_path.exists() def test_corrupt_cache_is_overwritten(self, state_path): @@ -194,18 +221,17 @@ def test_corrupt_cache_is_overwritten(self, state_path): ) as fetch: checker.refresh() fetch.assert_called_once() - assert checker.latest_version == "0.4.1" + assert checker.state.latest_version == "0.4.1" assert read_json(state_path)["latest_version"] == "0.4.1" def test_legacy_iso_last_check_is_discarded(self, state_path): """An old cache file written before the epoch migration is treated - as missing entirely — read_state rejects the whole record.""" + as missing entirely.""" seed_state( state_path, { "last_check": "2026-05-08T12:00:00+00:00", "latest_version": "0.4.0", - "current_version": "0.4.0", }, ) checker = UpdateChecker(state_path=state_path, current_version="0.4.0") @@ -214,27 +240,8 @@ def test_legacy_iso_last_check_is_discarded(self, state_path): ) as fetch: checker.refresh() fetch.assert_called_once() - # Discarded state -> nothing carried over -> only freshly fetched - # value remains. 
assert read_json(state_path)["latest_version"] == "0.4.5" - assert checker.latest_version == "0.4.5" - - def test_wrong_typed_latest_version_is_discarded(self, state_path): - seed_state( - state_path, - { - "last_check": time.time() - 2 * DAY, - "latest_version": 42, # not a string - "current_version": "0.4.0", - }, - ) - checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - with patch.object( - checker, "fetch_latest_version", return_value="0.4.5" - ) as fetch: - checker.refresh() - fetch.assert_called_once() - assert checker.latest_version == "0.4.5" + assert checker.state.latest_version == "0.4.5" def test_refresh_does_not_log_warning(self, state_path, caplog): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") @@ -249,56 +256,57 @@ def test_refresh_does_not_log_warning(self, state_path, caplog): class TestIsLatest: def test_returns_true_in_editable_mode(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="editable") - checker.latest_version = "9.9.9" + checker.state = UpdateState(last_check=1, latest_version="9.9.9") assert checker.is_latest() is True def test_returns_true_when_opt_out_env_set(self, state_path, monkeypatch): monkeypatch.setenv("CONTREE_NO_UPDATE_CHECK", "1") checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - checker.latest_version = "9.9.9" + checker.state = UpdateState(last_check=1, latest_version="9.9.9") assert checker.is_latest() is True def test_returns_false_when_outdated(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - checker.latest_version = "0.5.0" + checker.state = UpdateState(last_check=1, latest_version="0.5.0") assert checker.is_latest() is False def test_returns_true_when_up_to_date(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.5.0") - checker.latest_version = "0.5.0" + checker.state = UpdateState(last_check=1, latest_version="0.5.0") assert 
checker.is_latest() is True def test_returns_true_when_latest_unknown(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") + # Default sentinel state — latest_version is empty string. assert checker.is_latest() is True def test_returns_true_when_current_is_newer(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.6.0") - checker.latest_version = "0.5.0" + checker.state = UpdateState(last_check=1, latest_version="0.5.0") assert checker.is_latest() is True - def test_returns_true_when_current_is_pre_release_of_same(self, state_path): - """Pre-release of the same release is "older", so warns to upgrade.""" + def test_returns_false_when_current_is_pre_release_of_same(self, state_path): + """Pre-release of the same release is older, so warns to upgrade.""" checker = UpdateChecker(state_path=state_path, current_version="0.5.0a1") - checker.latest_version = "0.5.0" + checker.state = UpdateState(last_check=1, latest_version="0.5.0") assert checker.is_latest() is False def test_returns_true_when_latest_is_pre_release_of_same(self, state_path): """If pypi only knows a pre-release, an installed stable is fine.""" checker = UpdateChecker(state_path=state_path, current_version="0.5.0") - checker.latest_version = "0.5.0a1" + checker.state = UpdateState(last_check=1, latest_version="0.5.0a1") assert checker.is_latest() is True def test_does_not_touch_filesystem(self, state_path): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - checker.latest_version = "0.5.0" - with patch.object(UpdateChecker, "read_state") as read: + checker.state = UpdateState(last_check=1, latest_version="0.5.0") + with patch.object(UpdateState, "from_file") as load: checker.is_latest() - read.assert_not_called() + load.assert_not_called() def test_does_not_log(self, state_path, caplog): checker = UpdateChecker(state_path=state_path, current_version="0.4.0") - checker.latest_version = "0.5.0" + checker.state = 
UpdateState(last_check=1, latest_version="0.5.0") with caplog.at_level(logging.WARNING, logger="contree_cli.update_check"): assert checker.is_latest() is False assert caplog.records == []